Implements unified model routing

Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach.

This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings.
It provides a more flexible and maintainable routing mechanism by centralizing routing logic.

The changes include:
- Introducing new `routing` package with core routing logic.
- Creating characterization tests to capture existing behavior.
- Implementing model extraction and rewriting.
- Updating AMP module routes to utilize the new routing wrapper.
- Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
This commit is contained in:
이대희
2026-02-01 16:58:06 +09:00
parent 527a269799
commit 9299897e04
14 changed files with 2105 additions and 31 deletions

View File

@@ -86,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
//
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
// This type is kept for backward compatibility and test purposes.
type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
@@ -94,6 +98,8 @@ type FallbackHandler struct {
// NewFallbackHandler creates a new fallback handler wrapper
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
//
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
return &FallbackHandler{
getProxy: getProxy,
@@ -102,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
}
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
//
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
if forceModelMappings == nil {
forceModelMappings = func() bool { return false }

View File

@@ -0,0 +1,326 @@
package amp
import (
"bytes"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil"
"github.com/stretchr/testify/assert"
)
// Characterization tests for fallback_handlers.go using testutil recorders
// These tests capture existing behavior before refactoring to routing layer
func TestCharacterization_LocalProvider(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{
{ID: "test-model-local"},
})
defer reg.UnregisterClient("char-test-local")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with proxy recorder
// Create a test server to act as the proxy target
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
// Create a reverse proxy that forwards to our test server
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
// Assert: local handler called once
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
// Assert: request body model unchanged
assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged")
}
func TestCharacterization_ModelMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the TARGET model (the mapped-to model)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{
{ID: "gpt-4-local"},
})
defer reg.UnregisterClient("char-test-mapped")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create model mapper with a mapping
mapper := NewModelMapper([]config.AmpModelMapping{
{From: "gpt-4-turbo", To: "gpt-4-local"},
})
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Request with original model that gets mapped
body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with mapper
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
}, mapper, func() bool { return false })
// Execute - use handler that returns model in response for rewriter to work
wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel())
wrapped(c)
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping")
// Assert: local handler called once
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
// Assert: request body model was rewritten to mapped model
assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model")
assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model")
// Assert: context has mapped_model key set
mappedModel, exists := handlerRecorder.GetContextKey("mapped_model")
assert.True(t, exists, "context should have mapped_model key")
assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model")
// Assert: response body model rewritten back to original
// The response writer should rewrite model names in the response
responseBody := w.Body.String()
assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name")
}
func TestCharacterization_AmpCreditsProxy(t *testing.T) {
gin.SetMode(gin.TestMode)
// Setup recorders - NO local provider registered, NO mapping configured
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context with CloseNotifier support (required for ReverseProxy)
w := testutil.NewCloseNotifierRecorder()
c, _ := gin.CreateTestContext(w)
// Request with a model that has no local provider and no mapping
body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: proxy called once
assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping")
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
// Assert: local handler NOT called
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy")
// Assert: body forwarded to proxy is original (no rewrite)
assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying")
}
func TestCharacterization_BodyRestore(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{
{ID: "test-model-body"},
})
defer reg.UnregisterClient("char-test-body")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create a complex request body that will be read by the wrapper for model extraction
originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with proxy recorder
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: local handler called (not proxy, since we have a local provider)
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
// Assert: handler receives complete original body
// This verifies that the body was properly restored after the wrapper read it for model extraction
assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction")
}
// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler
// This is a characterization test for the route gating logic in routes.go
func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model (Gemini format uses path-based model extraction)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{
{ID: "gemini-pro"},
})
defer reg.UnregisterClient("char-test-gemini")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create a test server for the proxy
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
// Create fallback handler
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Create the Gemini bridge handler (simulating what routes.go does)
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
// Create router with the same gating logic as routes.go
r := gin.New()
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
// POST with /models/ path -> use Gemini bridge with fallback handler
geminiV1Beta1Handler(c)
return
}
}
// Non-POST or no /models/ in path -> proxy upstream
proxyRecorder.ServeHTTP(c.Writer, c.Request)
})
// Execute: POST request with /models/ in path
body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// Assert: local Gemini handler called
assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/")
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path")
}
// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy
// This is a characterization test for the route gating logic in routes.go
func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) {
gin.SetMode(gin.TestMode)
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create a test server for the proxy
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
// Create fallback handler
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Create the Gemini bridge handler
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
// Create router with the same gating logic as routes.go
r := gin.New()
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
geminiV1Beta1Handler(c)
return
}
}
proxyRecorder.ServeHTTP(c.Writer, c.Request)
})
// Execute: GET request (even with /models/ in path)
req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// Assert: proxy called
assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests")
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
// Assert: local handler NOT called
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests")
}

View File

@@ -276,6 +276,22 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
return result
}
// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice.
// Safe for concurrent use.
func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]config.AmpModelMapping, 0, len(m.mappings))
for from, to := range m.mappings {
result = append(result, config.AmpModelMapping{
From: from,
To: to,
})
}
return result
}
type regexMapping struct {
re *regexp.Regexp
to string

View File

@@ -5,11 +5,12 @@ import (
"errors"
"net"
"net/http"
"net/http/httputil"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
@@ -234,19 +235,20 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
// If no local OAuth is available, falls back to ampcode.com proxy.
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
}, m.modelMapper, m.forceModelMappings)
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
// Route POST model calls through Gemini bridge with FallbackHandler.
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
// T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper
// Create a dedicated routing wrapper for the Gemini bridge
geminiBridgeWrapper := m.createModelRoutingWrapper()
geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge)
// Route POST model calls through Gemini bridge with ModelRoutingWrapper.
// ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically.
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
// POST with /models/ path -> use Gemini bridge with fallback handler
// FallbackHandler will check provider/mapping and proxy if needed
// POST with /models/ path -> use Gemini bridge with unified routing wrapper
// ModelRoutingWrapper will check provider/mapping and proxy if needed
geminiV1Beta1Handler(c)
return
}
@@ -256,6 +258,41 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
})
}
// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing.
// This is used for testing the new routing implementation (T-021 onwards).
func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper {
// Create a registry - in production this would be populated with actual providers
registry := routing.NewRegistry()
// Create a minimal config with just AmpCode settings
// The Router only needs AmpCode.ModelMappings and OAuthModelAlias
cfg := &config.Config{
AmpCode: func() config.AmpCode {
if m.modelMapper != nil {
return config.AmpCode{
ModelMappings: m.modelMapper.GetMappingsAsConfig(),
}
}
return config.AmpCode{}
}(),
}
// Create router with registry and config
router := routing.NewRouter(registry, cfg)
// Create wrapper with proxy function
proxyFunc := func(c *gin.Context) {
proxy := m.getProxy()
if proxy != nil {
proxy.ServeHTTP(c.Writer, c.Request)
} else {
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
}
}
return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc)
}
// registerProviderAliases registers /api/provider/{provider}/... routes
// These allow Amp CLI to route requests like:
//
@@ -269,12 +306,9 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
// Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
}, m.modelMapper, m.forceModelMappings)
// Create unified routing wrapper (T-021 onwards)
// Replaces FallbackHandler with Router-based unified routing
routingWrapper := m.createModelRoutingWrapper()
// Provider-specific routes under /api/provider/:provider
ampProviders := engine.Group("/api/provider")
@@ -302,33 +336,36 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
}
// Root-level routes (for providers that omit /v1, like groq/cerebras)
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
// T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
// /v1 routes (OpenAI/Claude-compatible endpoints)
v1Amp := provider.Group("/v1")
{
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
// OpenAI-compatible endpoints with fallback
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
// OpenAI-compatible endpoints with ModelRoutingWrapper
// T-021, T-022: Migrated to unified routing wrapper
v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
// Claude/Anthropic-compatible endpoints with fallback
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
// Claude/Anthropic-compatible endpoints with ModelRoutingWrapper
// T-023: Migrated Claude routes to unified routing wrapper
v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages))
v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens))
}
// /v1beta routes (Gemini native API)
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
// T-024: Migrated Gemini v1beta routes to unified routing wrapper
v1betaAmp := provider.Group("/v1beta")
{
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler))
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
}
}

View File

@@ -0,0 +1,59 @@
package routing
import (
"strings"
"github.com/tidwall/gjson"
)
// ModelExtractor extracts model names from request data.
type ModelExtractor interface {
// Extract returns the model name from the request body and gin parameters.
// The ginParams map contains route parameters like "action" and "path".
Extract(body []byte, ginParams map[string]string) (string, error)
}
// DefaultModelExtractor is the standard implementation of ModelExtractor.
type DefaultModelExtractor struct{}
// NewModelExtractor creates a new DefaultModelExtractor.
func NewModelExtractor() *DefaultModelExtractor {
return &DefaultModelExtractor{}
}
// Extract extracts the model name from the request.
// It checks in order:
// 1. JSON body "model" field (OpenAI, Claude format)
// 2. "action" parameter for Gemini standard format (e.g., "gemini-pro:generateContent")
// 3. "path" parameter for AMP CLI Gemini format (e.g., "/publishers/google/models/gemini-3-pro:streamGenerateContent")
func (e *DefaultModelExtractor) Extract(body []byte, ginParams map[string]string) (string, error) {
// First try to parse from JSON body (OpenAI, Claude, etc.)
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
return result.String(), nil
}
// For Gemini requests, model is in the URL path
// Standard format: /models/{model}:generateContent -> :action parameter
if action, ok := ginParams["action"]; ok && action != "" {
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
parts := strings.Split(action, ":")
if len(parts) > 0 && parts[0] != "" {
return parts[0], nil
}
}
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
if path, ok := ginParams["path"]; ok && path != "" {
// Look for /models/{model}:method pattern
if idx := strings.Index(path, "/models/"); idx >= 0 {
modelPart := path[idx+8:] // Skip "/models/"
// Split by colon to get model name
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
return modelPart[:colonIdx], nil
}
}
}
return "", nil
}

View File

@@ -0,0 +1,214 @@
package routing
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestModelExtractor_ExtractFromJSONBody(t *testing.T) {
extractor := NewModelExtractor()
tests := []struct {
name string
body []byte
want string
wantErr bool
}{
{
name: "extract from JSON body with model field",
body: []byte(`{"model":"gpt-4.1"}`),
want: "gpt-4.1",
},
{
name: "extract claude model from JSON body",
body: []byte(`{"model":"claude-3-5-sonnet-20241022"}`),
want: "claude-3-5-sonnet-20241022",
},
{
name: "extract with additional fields",
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`),
want: "gpt-4",
},
{
name: "empty body returns empty",
body: []byte{},
want: "",
},
{
name: "no model field returns empty",
body: []byte(`{"messages":[]}`),
want: "",
},
{
name: "model is not string returns empty",
body: []byte(`{"model":123}`),
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractor.Extract(tt.body, nil)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestModelExtractor_ExtractFromGeminiActionParam(t *testing.T) {
extractor := NewModelExtractor()
tests := []struct {
name string
body []byte
ginParams map[string]string
want string
}{
{
name: "extract from action parameter - gemini-pro",
body: []byte(`{}`),
ginParams: map[string]string{"action": "gemini-pro:generateContent"},
want: "gemini-pro",
},
{
name: "extract from action parameter - gemini-ultra",
body: []byte(`{}`),
ginParams: map[string]string{"action": "gemini-ultra:chat"},
want: "gemini-ultra",
},
{
name: "empty action returns empty",
body: []byte(`{}`),
ginParams: map[string]string{"action": ""},
want: "",
},
{
name: "action without colon returns full value",
body: []byte(`{}`),
ginParams: map[string]string{"action": "gemini-model"},
want: "gemini-model",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractor.Extract(tt.body, tt.ginParams)
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestModelExtractor_ExtractFromGeminiV1Beta1Path(t *testing.T) {
extractor := NewModelExtractor()
tests := []struct {
name string
body []byte
ginParams map[string]string
want string
}{
{
name: "extract from v1beta1 path - gemini-3-pro",
body: []byte(`{}`),
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro:streamGenerateContent"},
want: "gemini-3-pro",
},
{
name: "extract from v1beta1 path with preview",
body: []byte(`{}`),
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro-preview:generateContent"},
want: "gemini-3-pro-preview",
},
{
name: "path without models segment returns empty",
body: []byte(`{}`),
ginParams: map[string]string{"path": "/publishers/google/gemini-3-pro:streamGenerateContent"},
want: "",
},
{
name: "empty path returns empty",
body: []byte(`{}`),
ginParams: map[string]string{"path": ""},
want: "",
},
{
name: "path with /models/ but no colon returns empty",
body: []byte(`{}`),
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro"},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractor.Extract(tt.body, tt.ginParams)
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestModelExtractor_ExtractPriority(t *testing.T) {
extractor := NewModelExtractor()
// JSON body takes priority over gin params
t.Run("JSON body takes priority over action param", func(t *testing.T) {
body := []byte(`{"model":"gpt-4"}`)
params := map[string]string{"action": "gemini-pro:generateContent"}
got, err := extractor.Extract(body, params)
assert.NoError(t, err)
assert.Equal(t, "gpt-4", got)
})
// Action param takes priority over path param
t.Run("action param takes priority over path param", func(t *testing.T) {
body := []byte(`{}`)
params := map[string]string{
"action": "gemini-action:generate",
"path": "/publishers/google/models/gemini-path:streamGenerateContent",
}
got, err := extractor.Extract(body, params)
assert.NoError(t, err)
assert.Equal(t, "gemini-action", got)
})
}
func TestModelExtractor_NoModelFound(t *testing.T) {
extractor := NewModelExtractor()
tests := []struct {
name string
body []byte
ginParams map[string]string
}{
{
name: "empty body and no params",
body: []byte{},
ginParams: nil,
},
{
name: "body without model and no params",
body: []byte(`{"messages":[]}`),
ginParams: map[string]string{},
},
{
name: "irrelevant params only",
body: []byte(`{}`),
ginParams: map[string]string{"other": "value"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractor.Extract(tt.body, tt.ginParams)
assert.NoError(t, err)
assert.Empty(t, got)
})
}
}

View File

@@ -0,0 +1,159 @@
package routing
import (
"bytes"
"net/http"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
log "github.com/sirupsen/logrus"
)
// ModelRewriter handles model name rewriting in requests and responses.
type ModelRewriter interface {
// RewriteRequestBody rewrites the model field in a JSON request body.
// Returns the modified body or the original if no rewrite was needed.
RewriteRequestBody(body []byte, newModel string) ([]byte, error)
// WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response.
// Returns the wrapped writer and a cleanup function that must be called after the response is complete.
WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
}
// DefaultModelRewriter is the standard implementation of ModelRewriter.
type DefaultModelRewriter struct{}
// NewModelRewriter creates a new DefaultModelRewriter.
func NewModelRewriter() *DefaultModelRewriter {
return &DefaultModelRewriter{}
}
// RewriteRequestBody replaces the model name in a JSON request body.
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
if !gjson.GetBytes(body, "model").Exists() {
return body, nil
}
result, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body, err
}
return result, nil
}
// WrapResponseWriter wraps a response writer to rewrite model names.
// The cleanup function must be called after the handler completes to flush any buffered data.
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
rw := &responseRewriter{
ResponseWriter: w,
body: &bytes.Buffer{},
requestedModel: requestedModel,
resolvedModel: resolvedModel,
}
return rw, func() { rw.flush() }
}
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
type responseRewriter struct {
http.ResponseWriter
body *bytes.Buffer
requestedModel string
resolvedModel string
isStreaming bool
wroteHeader bool
flushed bool
}
// Write intercepts response writes and buffers them for model name replacement.
func (rw *responseRewriter) Write(data []byte) (int, error) {
// Ensure header is written
if !rw.wroteHeader {
rw.WriteHeader(http.StatusOK)
}
// Detect streaming on first write
if rw.body.Len() == 0 && !rw.isStreaming {
contentType := rw.Header().Get("Content-Type")
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
strings.Contains(contentType, "stream")
}
if rw.isStreaming {
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
if err == nil {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
return n, err
}
return rw.body.Write(data)
}
// WriteHeader captures the status code and delegates to the underlying writer.
func (rw *responseRewriter) WriteHeader(code int) {
if !rw.wroteHeader {
rw.wroteHeader = true
rw.ResponseWriter.WriteHeader(code)
}
}
// flush writes the buffered response with model names rewritten.
func (rw *responseRewriter) flush() {
if rw.flushed {
return
}
rw.flushed = true
if rw.isStreaming {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
return
}
if rw.body.Len() > 0 {
data := rw.rewriteModelInResponse(rw.body.Bytes())
if _, err := rw.ResponseWriter.Write(data); err != nil {
log.Warnf("response rewriter: failed to write rewritten response: %v", err)
}
}
}
// modelFieldPaths lists all JSON paths where model name may appear.
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model.
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
return data
}
for _, path := range modelFieldPaths {
if gjson.GetBytes(data, path).Exists() {
data, _ = sjson.SetBytes(data, path, rw.requestedModel)
}
}
return data
}
// rewriteStreamChunk rewrites model names in SSE stream chunks.
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
return chunk
}
// SSE format: "data: {json}\n\n"
lines := bytes.Split(chunk, []byte("\n"))
for i, line := range lines {
if bytes.HasPrefix(line, []byte("data: ")) {
jsonData := bytes.TrimPrefix(line, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
// Rewrite JSON in the data line
rewritten := rw.rewriteModelInResponse(jsonData)
lines[i] = append([]byte("data: "), rewritten...)
}
}
}
return bytes.Join(lines, []byte("\n"))
}

View File

@@ -0,0 +1,342 @@
package routing
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestModelRewriter_RewriteRequestBody(t *testing.T) {
rewriter := NewModelRewriter()
tests := []struct {
name string
body []byte
newModel string
wantModel string
wantChange bool
}{
{
name: "rewrites model field in JSON body",
body: []byte(`{"model":"gpt-4.1","messages":[]}`),
newModel: "claude-local",
wantModel: "claude-local",
wantChange: true,
},
{
name: "rewrites with empty body returns empty",
body: []byte{},
newModel: "gpt-4",
wantModel: "",
wantChange: false,
},
{
name: "handles missing model field gracefully",
body: []byte(`{"messages":[{"role":"user"}]}`),
newModel: "gpt-4",
wantModel: "",
wantChange: false,
},
{
name: "preserves other fields when rewriting",
body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`),
newModel: "new-model",
wantModel: "new-model",
wantChange: true,
},
{
name: "handles nested JSON structure",
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`),
newModel: "claude-3-opus",
wantModel: "claude-3-opus",
wantChange: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel)
require.NoError(t, err)
if tt.wantChange {
assert.NotEqual(t, string(tt.body), string(result), "body should have been modified")
}
if tt.wantModel != "" {
// Parse result and check model field
model, _ := NewModelExtractor().Extract(result, nil)
assert.Equal(t, tt.wantModel, model)
}
})
}
}
func TestModelRewriter_WrapResponseWriter(t *testing.T) {
rewriter := NewModelRewriter()
t.Run("response writer wraps without error", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
require.NotNil(t, wrapped)
require.NotNil(t, cleanup)
defer cleanup()
})
t.Run("rewrites model in non-streaming response", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
// Write a response with the resolved model
response := []byte(`{"model":"claude-local","content":"hello"}`)
wrapped.Header().Set("Content-Type", "application/json")
_, err := wrapped.Write(response)
require.NoError(t, err)
// Cleanup triggers the rewrite
cleanup()
// Check the response was rewritten to the requested model
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"model":"gpt-4"`)
assert.NotContains(t, string(body), `"model":"claude-local"`)
})
t.Run("no-op when requested equals resolved", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4")
response := []byte(`{"model":"gpt-4","content":"hello"}`)
wrapped.Header().Set("Content-Type", "application/json")
_, err := wrapped.Write(response)
require.NoError(t, err)
cleanup()
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"model":"gpt-4"`)
})
t.Run("rewrites modelVersion field", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"modelVersion":"claude-local","content":"hello"}`)
wrapped.Header().Set("Content-Type", "application/json")
_, err := wrapped.Write(response)
require.NoError(t, err)
cleanup()
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"modelVersion":"gpt-4"`)
})
t.Run("handles streaming responses", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
// Set streaming content type
wrapped.Header().Set("Content-Type", "text/event-stream")
// Write SSE chunks with resolved model
chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n")
_, err := wrapped.Write(chunk1)
require.NoError(t, err)
chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n")
_, err = wrapped.Write(chunk2)
require.NoError(t, err)
cleanup()
// For streaming, data is written immediately with rewrites
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"model":"gpt-4"`)
assert.NotContains(t, string(body), `"model":"claude-local"`)
})
t.Run("empty body handled gracefully", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
wrapped.Header().Set("Content-Type", "application/json")
// Don't write anything
cleanup()
body := recorder.Body.Bytes()
assert.Empty(t, body)
})
t.Run("preserves other JSON fields", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`)
wrapped.Header().Set("Content-Type", "application/json")
_, err := wrapped.Write(response)
require.NoError(t, err)
cleanup()
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"temperature":0.7`)
assert.Contains(t, string(body), `"prompt_tokens":10`)
})
}
func TestResponseRewriter_ImplementsInterfaces(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
defer cleanup()
// Should implement http.ResponseWriter
assert.Implements(t, (*http.ResponseWriter)(nil), wrapped)
// Should preserve header access
wrapped.Header().Set("X-Custom", "value")
assert.Equal(t, "value", recorder.Header().Get("X-Custom"))
// Should write status
wrapped.WriteHeader(http.StatusCreated)
assert.Equal(t, http.StatusCreated, recorder.Code)
}
func TestResponseRewriter_Flush(t *testing.T) {
t.Run("flush writes buffered content", func(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"model":"claude-local","content":"test"}`)
wrapped.Header().Set("Content-Type", "application/json")
wrapped.Write(response)
// Before cleanup, response should be empty (buffered)
assert.Empty(t, recorder.Body.Bytes())
// After cleanup, response should be written
cleanup()
assert.NotEmpty(t, recorder.Body.Bytes())
})
t.Run("multiple flush calls are safe", func(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"model":"claude-local"}`)
wrapped.Header().Set("Content-Type", "application/json")
wrapped.Write(response)
// First cleanup
cleanup()
firstBody := recorder.Body.Bytes()
// Second cleanup should not write again
cleanup()
secondBody := recorder.Body.Bytes()
assert.Equal(t, firstBody, secondBody)
})
}
func TestResponseRewriter_StreamingWithDataLines(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
wrapped.Header().Set("Content-Type", "text/event-stream")
// SSE format with multiple data lines
chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n")
wrapped.Write(chunk)
cleanup()
body := recorder.Body.Bytes()
// Both data lines should have model rewritten
assert.Contains(t, string(body), `"model":"gpt-4"`)
assert.NotContains(t, string(body), `"model":"claude-local"`)
}
func TestModelRewriter_RoundTrip(t *testing.T) {
// Simulate a full request -> response cycle with model rewriting
rewriter := NewModelRewriter()
// Step 1: Rewrite request body
originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`)
rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local")
require.NoError(t, err)
// Verify request was rewritten
extractor := NewModelExtractor()
requestModel, _ := extractor.Extract(rewrittenRequest, nil)
assert.Equal(t, "claude-local", requestModel)
// Step 2: Simulate response with resolved model
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`)
wrapped.Header().Set("Content-Type", "application/json")
wrapped.Write(response)
cleanup()
// Verify response was rewritten back
body, _ := io.ReadAll(recorder.Result().Body)
responseModel, _ := extractor.Extract(body, nil)
assert.Equal(t, "gpt-4", responseModel)
}
func TestModelRewriter_NonJSONBody(t *testing.T) {
rewriter := NewModelRewriter()
// Binary/non-JSON body should be returned unchanged
body := []byte{0x00, 0x01, 0x02, 0x03}
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestModelRewriter_InvalidJSON(t *testing.T) {
rewriter := NewModelRewriter()
// Invalid JSON without model field should be returned unchanged
body := []byte(`not valid json`)
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestResponseRewriter_StatusCodePreserved(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
wrapped.WriteHeader(http.StatusAccepted)
wrapped.Write([]byte(`{"model":"claude-local"}`))
cleanup()
assert.Equal(t, http.StatusAccepted, recorder.Code)
}
func TestResponseRewriter_HeaderFlushed(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
wrapped.Header().Set("Content-Type", "application/json")
wrapped.Header().Set("X-Request-ID", "abc123")
wrapped.Write([]byte(`{"model":"claude-local"}`))
cleanup()
result := recorder.Result()
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
assert.Equal(t, "abc123", result.Header.Get("X-Request-ID"))
}

View File

@@ -31,15 +31,17 @@ func NewRouter(registry *Registry, cfg *config.Config) *Router {
return r
}
// RoutingDecision contains the resolved routing information.
type RoutingDecision struct {
// LegacyRoutingDecision contains the resolved routing information.
// Deprecated: Will be replaced by RoutingDecision from types.go in T-013.
type LegacyRoutingDecision struct {
RequestedModel string // Original model from request
ResolvedModel string // After model-mappings
Candidates []ProviderCandidate // Ordered list of providers to try
}
// Resolve determines the routing decision for the requested model.
func (r *Router) Resolve(requestedModel string) *RoutingDecision {
// Deprecated: Will be updated to use RoutingRequest and return *RoutingDecision in T-013.
func (r *Router) Resolve(requestedModel string) *LegacyRoutingDecision {
// 1. Extract thinking suffix
suffixResult := thinking.ParseSuffix(requestedModel)
baseModel := suffixResult.ModelName
@@ -60,13 +62,151 @@ func (r *Router) Resolve(requestedModel string) *RoutingDecision {
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
})
return &RoutingDecision{
return &LegacyRoutingDecision{
RequestedModel: requestedModel,
ResolvedModel: targetModel,
Candidates: candidates,
}
}
// ResolveV2 determines the routing decision for a routing request.
// It uses the new RoutingRequest and RoutingDecision types.
func (r *Router) ResolveV2(req RoutingRequest) *RoutingDecision {
// 1. Extract thinking suffix
suffixResult := thinking.ParseSuffix(req.RequestedModel)
baseModel := suffixResult.ModelName
thinkingSuffix := ""
if suffixResult.HasSuffix {
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
}
// 2. Check for local providers
localCandidates := r.findLocalCandidates(baseModel, suffixResult)
// 3. Apply model-mappings if needed
mappedModel := r.applyMappings(baseModel)
mappingCandidates := r.findLocalCandidates(mappedModel, suffixResult)
// 4. Determine route type based on preferences and availability
var decision *RoutingDecision
if req.ForceModelMapping && mappedModel != baseModel && len(mappingCandidates) > 0 {
// FORCE MODE: Use mapping even if local provider exists
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
} else if req.PreferLocalProvider && len(localCandidates) > 0 {
// DEFAULT MODE with local preference: Use local provider first
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
} else if len(localCandidates) > 0 {
// DEFAULT MODE: Local provider available
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
} else if mappedModel != baseModel && len(mappingCandidates) > 0 {
// DEFAULT MODE: No local provider, but mapping available
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
} else {
// No local provider, no mapping - use amp credits proxy
decision = &RoutingDecision{
RouteType: RouteTypeAmpCredits,
ResolvedModel: req.RequestedModel,
ShouldProxy: true,
}
}
return decision
}
// findLocalCandidates finds local provider candidates for a model.
func (r *Router) findLocalCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
var candidates []ProviderCandidate
for _, p := range r.registry.All() {
if !p.SupportsModel(model) {
continue
}
// Apply thinking suffix if needed
actualModel := model
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
actualModel = model + "(" + suffixResult.RawSuffix + ")"
}
if p.Available(actualModel) {
candidates = append(candidates, ProviderCandidate{
Provider: p,
Model: actualModel,
})
}
}
// Sort by priority
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
})
return candidates
}
// buildLocalProviderDecision creates a decision for local provider routing.
func (r *Router) buildLocalProviderDecision(requestedModel string, candidates []ProviderCandidate, thinkingSuffix string) *RoutingDecision {
resolvedModel := requestedModel
if thinkingSuffix != "" {
// Ensure thinking suffix is preserved
sr := thinking.ParseSuffix(requestedModel)
if !sr.HasSuffix {
resolvedModel = requestedModel + thinkingSuffix
}
}
var fallbackModels []string
if len(candidates) > 1 {
for _, c := range candidates[1:] {
fallbackModels = append(fallbackModels, c.Model)
}
}
return &RoutingDecision{
RouteType: RouteTypeLocalProvider,
ResolvedModel: resolvedModel,
ProviderName: candidates[0].Provider.Name(),
FallbackModels: fallbackModels,
ShouldProxy: false,
}
}
// buildMappingDecision creates a decision for model mapping routing.
func (r *Router) buildMappingDecision(requestedModel, mappedModel string, candidates []ProviderCandidate, thinkingSuffix string, fallbackCandidates []ProviderCandidate) *RoutingDecision {
// Apply thinking suffix to resolved model if needed
resolvedModel := mappedModel
if thinkingSuffix != "" {
sr := thinking.ParseSuffix(mappedModel)
if !sr.HasSuffix {
resolvedModel = mappedModel + thinkingSuffix
}
}
var fallbackModels []string
for _, c := range fallbackCandidates {
fallbackModels = append(fallbackModels, c.Model)
}
// Also add oauth aliases as fallbacks
baseMapped := thinking.ParseSuffix(mappedModel).ModelName
for _, alias := range r.oauthAliases[strings.ToLower(baseMapped)] {
// Check if this alias has providers
aliasCandidates := r.findLocalCandidates(alias, thinking.SuffixResult{ModelName: alias})
for _, c := range aliasCandidates {
fallbackModels = append(fallbackModels, c.Model)
}
}
return &RoutingDecision{
RouteType: RouteTypeModelMapping,
ResolvedModel: resolvedModel,
ProviderName: candidates[0].Provider.Name(),
FallbackModels: fallbackModels,
ShouldProxy: false,
}
}
// applyMappings applies model-mappings configuration.
func (r *Router) applyMappings(model string) string {
key := strings.ToLower(strings.TrimSpace(model))

View File

@@ -0,0 +1,245 @@
package routing
import (
"context"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/stretchr/testify/assert"
)
func TestRouter_DefaultMode_PrefersLocal(t *testing.T) {
// Setup: Create a router with a mock provider that supports "gpt-4"
registry := NewRegistry()
mockProvider := &MockProvider{
name: "openai",
supportedModels: []string{"gpt-4"},
available: true,
priority: 1,
}
registry.Register(mockProvider)
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "gpt-4", To: "claude-local"},
},
},
}
router := NewRouter(registry, cfg)
// Test: Request gpt-4 when local provider exists
req := RoutingRequest{
RequestedModel: "gpt-4",
PreferLocalProvider: true,
ForceModelMapping: false,
}
decision := router.ResolveV2(req)
// Assert: Should return LOCAL_PROVIDER, not MODEL_MAPPING
assert.Equal(t, RouteTypeLocalProvider, decision.RouteType)
assert.Equal(t, "gpt-4", decision.ResolvedModel)
assert.Equal(t, "openai", decision.ProviderName)
assert.False(t, decision.ShouldProxy)
}
func TestRouter_DefaultMode_MapsWhenNoLocal(t *testing.T) {
// Setup: Create a router with NO provider for "gpt-4" but a mapping to "claude-local"
// which has a provider
registry := NewRegistry()
mockProvider := &MockProvider{
name: "anthropic",
supportedModels: []string{"claude-local"},
available: true,
priority: 1,
}
registry.Register(mockProvider)
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "gpt-4", To: "claude-local"},
},
},
}
router := NewRouter(registry, cfg)
// Test: Request gpt-4 when no local provider exists, but mapping exists
req := RoutingRequest{
RequestedModel: "gpt-4",
PreferLocalProvider: true,
ForceModelMapping: false,
}
decision := router.ResolveV2(req)
// Assert: Should return MODEL_MAPPING
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
assert.Equal(t, "claude-local", decision.ResolvedModel)
assert.Equal(t, "anthropic", decision.ProviderName)
assert.False(t, decision.ShouldProxy)
}
func TestRouter_DefaultMode_AmpCreditsWhenNoLocalOrMapping(t *testing.T) {
// Setup: Create a router with no providers and no mappings
registry := NewRegistry()
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{},
},
}
router := NewRouter(registry, cfg)
// Test: Request a model with no local provider and no mapping
req := RoutingRequest{
RequestedModel: "unknown-model",
PreferLocalProvider: true,
ForceModelMapping: false,
}
decision := router.ResolveV2(req)
// Assert: Should return AMP_CREDITS with ShouldProxy=true
assert.Equal(t, RouteTypeAmpCredits, decision.RouteType)
assert.Equal(t, "unknown-model", decision.ResolvedModel)
assert.True(t, decision.ShouldProxy)
assert.Empty(t, decision.ProviderName)
}
func TestRouter_ForceMode_MapsEvenWithLocal(t *testing.T) {
// Setup: Create a router with BOTH a local provider for "gpt-4" AND a mapping from "gpt-4" to "claude-local"
// The mapping target "claude-local" also has a provider
registry := NewRegistry()
// Local provider for gpt-4
openaiProvider := &MockProvider{
name: "openai",
supportedModels: []string{"gpt-4"},
available: true,
priority: 1,
}
registry.Register(openaiProvider)
// Local provider for the mapped model
anthropicProvider := &MockProvider{
name: "anthropic",
supportedModels: []string{"claude-local"},
available: true,
priority: 2,
}
registry.Register(anthropicProvider)
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "gpt-4", To: "claude-local"},
},
},
}
router := NewRouter(registry, cfg)
// Test: Request gpt-4 with ForceModelMapping=true
// Even though gpt-4 has a local provider, mapping should take precedence
req := RoutingRequest{
RequestedModel: "gpt-4",
PreferLocalProvider: false,
ForceModelMapping: true,
}
decision := router.ResolveV2(req)
// Assert: Should return MODEL_MAPPING, not LOCAL_PROVIDER
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
assert.Equal(t, "claude-local", decision.ResolvedModel)
assert.Equal(t, "anthropic", decision.ProviderName)
assert.False(t, decision.ShouldProxy)
}
func TestRouter_ThinkingSuffix_Preserved(t *testing.T) {
// Setup: Create a router with mapping and provider for mapped model
registry := NewRegistry()
mockProvider := &MockProvider{
name: "anthropic",
supportedModels: []string{"claude-local"},
available: true,
priority: 1,
}
registry.Register(mockProvider)
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "claude-3-5-sonnet", To: "claude-local"},
},
},
}
router := NewRouter(registry, cfg)
// Test: Request claude-3-5-sonnet with thinking suffix
req := RoutingRequest{
RequestedModel: "claude-3-5-sonnet(thinking:foo)",
PreferLocalProvider: true,
ForceModelMapping: false,
}
decision := router.ResolveV2(req)
// Assert: Thinking suffix should be preserved in resolved model
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
assert.Equal(t, "claude-local(thinking:foo)", decision.ResolvedModel)
assert.Equal(t, "anthropic", decision.ProviderName)
}
// MockProvider is a mock implementation of Provider for testing
type MockProvider struct {
name string
providerType ProviderType
supportedModels []string
available bool
priority int
}
func (m *MockProvider) Name() string {
return m.name
}
func (m *MockProvider) Type() ProviderType {
if m.providerType == "" {
return ProviderTypeOAuth
}
return m.providerType
}
func (m *MockProvider) SupportsModel(model string) bool {
for _, supported := range m.supportedModels {
if supported == model {
return true
}
}
return false
}
func (m *MockProvider) Available(model string) bool {
return m.available
}
func (m *MockProvider) Priority() int {
return m.priority
}
func (m *MockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
return executor.Response{}, nil
}
func (m *MockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
return nil, nil
}

View File

@@ -0,0 +1,113 @@
package testutil
import (
"io"
"net/http"
"github.com/gin-gonic/gin"
)
// FakeHandlerRecorder records handler invocations for testing.
type FakeHandlerRecorder struct {
Called bool
CallCount int
RequestBody []byte
RequestHeader http.Header
ContextKeys map[string]interface{}
ResponseStatus int
ResponseBody []byte
}
// NewFakeHandlerRecorder creates a new fake handler recorder.
func NewFakeHandlerRecorder() *FakeHandlerRecorder {
return &FakeHandlerRecorder{
ContextKeys: make(map[string]interface{}),
ResponseStatus: http.StatusOK,
ResponseBody: []byte(`{"status":"handled"}`),
}
}
// GinHandler returns a gin.HandlerFunc that records the invocation.
func (f *FakeHandlerRecorder) GinHandler() gin.HandlerFunc {
return func(c *gin.Context) {
f.record(c)
c.Data(f.ResponseStatus, "application/json", f.ResponseBody)
}
}
// GinHandlerWithModel returns a gin.HandlerFunc that records the invocation and returns the model from context.
// Useful for testing response rewriting in model mapping scenarios.
func (f *FakeHandlerRecorder) GinHandlerWithModel() gin.HandlerFunc {
return func(c *gin.Context) {
f.record(c)
// Return a response with the model field that would be in the actual API response
// If ResponseBody was explicitly set (not default), use that; otherwise generate from context
var body []byte
if mappedModel, exists := c.Get("mapped_model"); exists {
body = []byte(`{"model":"` + mappedModel.(string) + `","status":"handled"}`)
} else {
body = f.ResponseBody
}
c.Data(f.ResponseStatus, "application/json", body)
}
}
// HTTPHandler returns an http.HandlerFunc that records the invocation.
func (f *FakeHandlerRecorder) HTTPHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
f.Called = true
f.CallCount++
f.RequestBody = body
f.RequestHeader = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(f.ResponseStatus)
w.Write(f.ResponseBody)
}
}
// record captures the request details from gin context.
func (f *FakeHandlerRecorder) record(c *gin.Context) {
f.Called = true
f.CallCount++
body, _ := io.ReadAll(c.Request.Body)
f.RequestBody = body
f.RequestHeader = c.Request.Header.Clone()
// Capture common context keys used by routing
if val, exists := c.Get("mapped_model"); exists {
f.ContextKeys["mapped_model"] = val
}
if val, exists := c.Get("fallback_models"); exists {
f.ContextKeys["fallback_models"] = val
}
if val, exists := c.Get("route_type"); exists {
f.ContextKeys["route_type"] = val
}
}
// Reset clears the recorder state.
func (f *FakeHandlerRecorder) Reset() {
f.Called = false
f.CallCount = 0
f.RequestBody = nil
f.RequestHeader = nil
f.ContextKeys = make(map[string]interface{})
}
// GetContextKey returns a captured context key value.
func (f *FakeHandlerRecorder) GetContextKey(key string) (interface{}, bool) {
val, ok := f.ContextKeys[key]
return val, ok
}
// WasCalled returns true if the handler was called.
func (f *FakeHandlerRecorder) WasCalled() bool {
return f.Called
}
// GetCallCount returns the number of times the handler was called.
func (f *FakeHandlerRecorder) GetCallCount() int {
return f.CallCount
}

View File

@@ -0,0 +1,83 @@
package testutil
import (
"io"
"net/http"
"net/http/httptest"
)
// CloseNotifierRecorder wraps httptest.ResponseRecorder with CloseNotify support.
// This is needed because ReverseProxy requires http.CloseNotifier.
type CloseNotifierRecorder struct {
*httptest.ResponseRecorder
closeChan chan bool
}
// NewCloseNotifierRecorder creates a ResponseRecorder that implements CloseNotifier.
func NewCloseNotifierRecorder() *CloseNotifierRecorder {
return &CloseNotifierRecorder{
ResponseRecorder: httptest.NewRecorder(),
closeChan: make(chan bool, 1),
}
}
// CloseNotify implements http.CloseNotifier.
func (c *CloseNotifierRecorder) CloseNotify() <-chan bool {
return c.closeChan
}
// FakeProxyRecorder records proxy invocations for testing.
type FakeProxyRecorder struct {
Called bool
CallCount int
RequestBody []byte
RequestHeaders http.Header
ResponseStatus int
ResponseBody []byte
}
// NewFakeProxyRecorder creates a new fake proxy recorder.
func NewFakeProxyRecorder() *FakeProxyRecorder {
return &FakeProxyRecorder{
ResponseStatus: http.StatusOK,
ResponseBody: []byte(`{"status":"proxied"}`),
}
}
// ServeHTTP implements http.Handler to act as a reverse proxy.
func (f *FakeProxyRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
f.Called = true
f.CallCount++
f.RequestHeaders = r.Header.Clone()
body, err := io.ReadAll(r.Body)
if err == nil {
f.RequestBody = body
}
w.WriteHeader(f.ResponseStatus)
w.Write(f.ResponseBody)
}
// GetCallCount returns the number of times the proxy was called.
func (f *FakeProxyRecorder) GetCallCount() int {
return f.CallCount
}
// Reset clears the recorder state.
func (f *FakeProxyRecorder) Reset() {
f.Called = false
f.CallCount = 0
f.RequestBody = nil
f.RequestHeaders = nil
}
// ToHandler returns the recorder as an http.Handler for use with httptest.
func (f *FakeProxyRecorder) ToHandler() http.Handler {
return http.HandlerFunc(f.ServeHTTP)
}
// CreateTestServer creates an httptest server with this fake proxy.
func (f *FakeProxyRecorder) CreateTestServer() *httptest.Server {
return httptest.NewServer(f.ToHandler())
}

62
internal/routing/types.go Normal file
View File

@@ -0,0 +1,62 @@
package routing
// RouteType represents the type of routing decision made for a request.
type RouteType string
const (
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free).
RouteTypeLocalProvider RouteType = "LOCAL_PROVIDER"
// RouteTypeModelMapping indicates the request was remapped to another available model (free).
RouteTypeModelMapping RouteType = "MODEL_MAPPING"
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits).
RouteTypeAmpCredits RouteType = "AMP_CREDITS"
// RouteTypeNoProvider indicates no provider or fallback available.
RouteTypeNoProvider RouteType = "NO_PROVIDER"
)
// RoutingRequest contains the information needed to make a routing decision.
type RoutingRequest struct {
// RequestedModel is the model name from the incoming request.
RequestedModel string
// PreferLocalProvider indicates whether to prefer local providers over mappings.
// When true, check local providers first before applying model mappings.
PreferLocalProvider bool
// ForceModelMapping indicates whether to force model mapping even if local provider exists.
// When true, apply model mappings first and skip local provider checks.
ForceModelMapping bool
}
// RoutingDecision contains the result of a routing decision.
type RoutingDecision struct {
// RouteType indicates the type of routing decision.
RouteType RouteType
// ResolvedModel is the final model name after any mappings.
ResolvedModel string
// ProviderName is the name of the selected provider (if any).
ProviderName string
// FallbackModels is a list of alternative models to try if the primary fails.
FallbackModels []string
// ShouldProxy indicates whether the request should be proxied to ampcode.com.
ShouldProxy bool
}
// NewRoutingDecision creates a new RoutingDecision with the given parameters.
func NewRoutingDecision(routeType RouteType, resolvedModel, providerName string, fallbackModels []string, shouldProxy bool) *RoutingDecision {
return &RoutingDecision{
RouteType: routeType,
ResolvedModel: resolvedModel,
ProviderName: providerName,
FallbackModels: fallbackModels,
ShouldProxy: shouldProxy,
}
}
// IsLocal returns true if the decision routes to a local provider.
func (d *RoutingDecision) IsLocal() bool {
return d.RouteType == RouteTypeLocalProvider || d.RouteType == RouteTypeModelMapping
}
// HasFallbacks returns true if there are fallback models available.
func (d *RoutingDecision) HasFallbacks() bool {
return len(d.FallbackModels) > 0
}

270
internal/routing/wrapper.go Normal file
View File

@@ -0,0 +1,270 @@
package routing
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
"github.com/sirupsen/logrus"
)
// ProxyFunc is the function type for proxying requests.
type ProxyFunc func(c *gin.Context)
// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic.
// It replaces the FallbackHandler logic with a Router-based approach.
type ModelRoutingWrapper struct {
router *Router
extractor ModelExtractor
rewriter ModelRewriter
proxyFunc ProxyFunc
logger *logrus.Logger
}
// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies.
// If extractor is nil, a DefaultModelExtractor is used.
// If rewriter is nil, a DefaultModelRewriter is used.
// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead.
func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper {
if extractor == nil {
extractor = NewModelExtractor()
}
if rewriter == nil {
rewriter = NewModelRewriter()
}
return &ModelRoutingWrapper{
router: router,
extractor: extractor,
rewriter: rewriter,
proxyFunc: proxyFunc,
logger: logrus.New(),
}
}
// SetLogger sets the logger for the wrapper.
func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) {
w.logger = logger
}
// Wrap wraps a gin.HandlerFunc with model routing logic.
// The returned handler will:
// 1. Extract the model from the request
// 2. Get a routing decision from the Router
// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS)
func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
// Read request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
w.logger.Errorf("routing wrapper: failed to read request body: %v", err)
handler(c)
return
}
// Extract model from request
ginParams := map[string]string{
"action": c.Param("action"),
"path": c.Param("path"),
}
modelName, err := w.extractor.Extract(bodyBytes, ginParams)
if err != nil {
w.logger.Warnf("routing wrapper: failed to extract model: %v", err)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
return
}
if modelName == "" {
// No model found, proceed with original handler
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
return
}
// Get routing decision
req := RoutingRequest{
RequestedModel: modelName,
PreferLocalProvider: true,
ForceModelMapping: false, // TODO: Get from config
}
decision := w.router.ResolveV2(req)
// Store decision in context for downstream handlers
c.Set(string(ctxkeys.RoutingDecision), decision)
// Handle based on route type
switch decision.RouteType {
case RouteTypeLocalProvider:
w.handleLocalProvider(c, handler, bodyBytes, decision)
case RouteTypeModelMapping:
w.handleModelMapping(c, handler, bodyBytes, decision)
case RouteTypeAmpCredits:
w.handleAmpCredits(c, handler, bodyBytes)
default:
// No provider available
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
}
}
}
// handleLocalProvider handles the LOCAL_PROVIDER route type.
func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
// Filter Anthropic-Beta header for local provider
filterAnthropicBetaHeader(c)
// Restore body with original content
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Call handler
handler(c)
}
// handleModelMapping handles the MODEL_MAPPING route type.
func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
// Rewrite request body with mapped model
rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel)
if err != nil {
w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err)
rewrittenBody = bodyBytes
}
_ = rewrittenBody
// Store mapped model in context
c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel)
// Store fallback models in context if present
if len(decision.FallbackModels) > 0 {
c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels)
}
// Filter Anthropic-Beta header for local provider
filterAnthropicBetaHeader(c)
// Restore body with rewritten content
c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody))
// Wrap response writer to rewrite model back
wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel)
c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer}
// Call handler
handler(c)
// Cleanup (flush response rewriting)
cleanup()
}
// handleAmpCredits handles the AMP_CREDITS route type.
// It calls the proxy function directly if available, otherwise passes to handler.
// Does NOT filter headers or rewrite body - proxy handles everything.
func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) {
// Restore body with original content (no rewriting for proxy)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Call proxy function if available, otherwise fall back to handler
if w.proxyFunc != nil {
w.proxyFunc(c)
} else {
handler(c)
}
}
// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers.
func filterAnthropicBetaHeader(c *gin.Context) {
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
if filtered != "" {
c.Request.Header.Set("Anthropic-Beta", filtered)
} else {
c.Request.Header.Del("Anthropic-Beta")
}
}
}
// filterBetaFeatures removes specified beta features from the header.
func filterBetaFeatures(betaHeader, featureToRemove string) string {
// Simple implementation - can be enhanced
if betaHeader == featureToRemove {
return ""
}
return betaHeader
}
// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter.
type ginResponseWriterAdapter struct {
http.ResponseWriter
original gin.ResponseWriter
}
func (a *ginResponseWriterAdapter) WriteHeader(code int) {
a.ResponseWriter.WriteHeader(code)
}
func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) {
return a.ResponseWriter.Write(data)
}
func (a *ginResponseWriterAdapter) Header() http.Header {
return a.ResponseWriter.Header()
}
// CloseNotify implements http.CloseNotifier.
func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool {
if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok {
return notifier.CloseNotify()
}
return a.original.CloseNotify()
}
// Flush implements http.Flusher.
func (a *ginResponseWriterAdapter) Flush() {
if flusher, ok := a.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
// Hijack implements http.Hijacker.
func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok {
return hijacker.Hijack()
}
return a.original.Hijack()
}
// Status returns the HTTP status code.
func (a *ginResponseWriterAdapter) Status() int {
return a.original.Status()
}
// Size returns the number of bytes already written into the response http body.
func (a *ginResponseWriterAdapter) Size() int {
return a.original.Size()
}
// Written returns whether or not the response for this context has been written.
func (a *ginResponseWriterAdapter) Written() bool {
return a.original.Written()
}
// WriteHeaderNow forces WriteHeader to be called.
func (a *ginResponseWriterAdapter) WriteHeaderNow() {
a.original.WriteHeaderNow()
}
// WriteString writes the given string into the response body.
func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) {
return a.Write([]byte(s))
}
// Pusher returns the http.Pusher for server push.
func (a *ginResponseWriterAdapter) Pusher() http.Pusher {
if pusher, ok := a.ResponseWriter.(http.Pusher); ok {
return pusher
}
return nil
}