diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 3b32d25e..f46af1c0 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -86,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI +// +// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper. +// Use routing.NewModelRoutingWrapper() instead for unified routing logic. +// This type is kept for backward compatibility and test purposes. type FallbackHandler struct { getProxy func() *httputil.ReverseProxy modelMapper ModelMapper @@ -94,6 +98,8 @@ type FallbackHandler struct { // NewFallbackHandler creates a new fallback handler wrapper // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) +// +// Deprecated: Use routing.NewModelRoutingWrapper() instead. func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { return &FallbackHandler{ getProxy: getProxy, @@ -102,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler } // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support +// +// Deprecated: Use routing.NewModelRoutingWrapper() instead. func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler { if forceModelMappings == nil { forceModelMappings = func() bool { return false } diff --git a/internal/api/modules/amp/fallback_handlers_characterization_test.go b/internal/api/modules/amp/fallback_handlers_characterization_test.go new file mode 100644 index 00000000..e52bc5ce --- /dev/null +++ b/internal/api/modules/amp/fallback_handlers_characterization_test.go @@ -0,0 +1,326 @@ +package amp + +import ( + "bytes" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil" + "github.com/stretchr/testify/assert" +) + +// Characterization tests for fallback_handlers.go using testutil recorders +// These tests capture existing behavior before refactoring to routing layer + +func TestCharacterization_LocalProvider(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the test model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{ + {ID: "test-model-local"}, + }) + defer reg.UnregisterClient("char-test-local") + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create gin context + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Create fallback handler with proxy recorder + // Create a test server to act as the proxy target + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + // Create a reverse proxy that forwards to our test server + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Execute + wrapped := fh.WrapHandler(handlerRecorder.GinHandler()) + wrapped(c) + + // Assert: proxy NOT called + assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider") + + // Assert: local handler called once + assert.True(t, handlerRecorder.WasCalled(), "local handler should be called") + assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once") + + // Assert: request body model unchanged + assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged") +} + +func TestCharacterization_ModelMapping(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the TARGET model (the mapped-to model) + reg := registry.GetGlobalRegistry() + reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{ + {ID: "gpt-4-local"}, + }) + defer reg.UnregisterClient("char-test-mapped") + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create model mapper with a mapping + mapper := NewModelMapper([]config.AmpModelMapping{ + {From: "gpt-4-turbo", To: "gpt-4-local"}, + }) + + // Create gin context + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Request with original model that gets mapped + body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Create fallback handler with mapper + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }, mapper, func() bool { return false }) + + // Execute - use handler that returns model in response for rewriter to work + wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel()) + wrapped(c) + + // Assert: proxy NOT called + assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping") + + // Assert: local handler called once + assert.True(t, handlerRecorder.WasCalled(), "local handler should be called") + assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once") + + // Assert: request body model was rewritten to mapped model + assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model") + assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model") + + // Assert: context has mapped_model key set + mappedModel, exists := handlerRecorder.GetContextKey("mapped_model") + assert.True(t, exists, "context should have mapped_model key") + assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model") + + // Assert: response body model rewritten back to original + // The response writer should rewrite model names in the response + responseBody := w.Body.String() + assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name") +} + +func TestCharacterization_AmpCreditsProxy(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Setup recorders - NO local provider registered, NO mapping configured + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create gin context with CloseNotifier support (required for ReverseProxy) + w := testutil.NewCloseNotifierRecorder() + c, _ := gin.CreateTestContext(w) + + // Request with a model that has no local provider and no mapping + body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Create fallback handler + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Execute + wrapped := fh.WrapHandler(handlerRecorder.GinHandler()) + wrapped(c) + + // Assert: proxy called once + assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping") + assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once") + + // Assert: local handler NOT called + assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy") + + // Assert: body forwarded to proxy is original (no rewrite) + assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying") +} + +func TestCharacterization_BodyRestore(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the test model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{ + {ID: "test-model-body"}, + }) + defer reg.UnregisterClient("char-test-body") + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create gin context + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Create a complex request body that will be read by the wrapper for model extraction + originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Create fallback handler with proxy recorder + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Execute + wrapped := fh.WrapHandler(handlerRecorder.GinHandler()) + wrapped(c) + + // Assert: local handler called (not proxy, since we have a local provider) + assert.True(t, handlerRecorder.WasCalled(), "local handler should be called") + assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider") + + // Assert: handler receives complete original body + // This verifies that the body was properly restored after the wrapper read it for model extraction + assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction") +} + +// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler +// This is a characterization test for the route gating logic in routes.go +func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the test model (Gemini format uses path-based model extraction) + reg := registry.GetGlobalRegistry() + reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{ + {ID: "gemini-pro"}, + }) + defer reg.UnregisterClient("char-test-gemini") + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create a test server for the proxy + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + // Create fallback handler + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Create the Gemini bridge handler (simulating what routes.go does) + geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler()) + geminiV1Beta1Handler := fh.WrapHandler(geminiBridge) + + // Create router with the same gating logic as routes.go + r := gin.New() + r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) { + if c.Request.Method == "POST" { + if path := c.Param("path"); strings.Contains(path, "/models/") { + // POST with /models/ path -> use Gemini bridge with fallback handler + geminiV1Beta1Handler(c) + return + } + } + // Non-POST or no /models/ in path -> proxy upstream + proxyRecorder.ServeHTTP(c.Writer, c.Request) + }) + + // Execute: POST request with /models/ in path + body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Assert: local Gemini handler called + assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/") + + // Assert: proxy NOT called + assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path") +} + +// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy +// This is a characterization test for the route gating logic in routes.go +func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create a test server for the proxy + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + // Create fallback handler + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Create the Gemini bridge handler + geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler()) + geminiV1Beta1Handler := fh.WrapHandler(geminiBridge) + + // Create router with the same gating logic as routes.go + r := gin.New() + r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) { + if c.Request.Method == "POST" { + if path := c.Param("path"); strings.Contains(path, "/models/") { + geminiV1Beta1Handler(c) + return + } + } + proxyRecorder.ServeHTTP(c.Writer, c.Request) + }) + + // Execute: GET request (even with /models/ in path) + req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Assert: proxy called + assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests") + assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once") + + // Assert: local handler NOT called + assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests") +} diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index 24b8cdcc..b8d47432 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -276,6 +276,22 @@ func (m *DefaultModelMapper) GetMappings() map[string]string { return result } +// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice. +// Safe for concurrent use. +func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]config.AmpModelMapping, 0, len(m.mappings)) + for from, to := range m.mappings { + result = append(result, config.AmpModelMapping{ + From: from, + To: to, + }) + } + return result +} + type regexMapping struct { re *regexp.Regexp to string diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 456a50ac..790a3cce 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -5,11 +5,12 @@ import ( "errors" "net" "net/http" - "net/http/httputil" "strings" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini" @@ -234,19 +235,20 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha // If no local OAuth is available, falls back to ampcode.com proxy. geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) - geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { - return m.getProxy() - }, m.modelMapper, m.forceModelMappings) - geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) - // Route POST model calls through Gemini bridge with FallbackHandler. - // FallbackHandler checks provider -> mapping -> proxy fallback automatically. + // T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper + // Create a dedicated routing wrapper for the Gemini bridge + geminiBridgeWrapper := m.createModelRoutingWrapper() + geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge) + + // Route POST model calls through Gemini bridge with ModelRoutingWrapper. + // ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically. // All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior. ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) { if c.Request.Method == "POST" { if path := c.Param("path"); strings.Contains(path, "/models/") { - // POST with /models/ path -> use Gemini bridge with fallback handler - // FallbackHandler will check provider/mapping and proxy if needed + // POST with /models/ path -> use Gemini bridge with unified routing wrapper + // ModelRoutingWrapper will check provider/mapping and proxy if needed geminiV1Beta1Handler(c) return } @@ -256,6 +258,41 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha }) } +// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing. +// This is used for testing the new routing implementation (T-021 onwards). +func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper { + // Create a registry - in production this would be populated with actual providers + registry := routing.NewRegistry() + + // Create a minimal config with just AmpCode settings + // The Router only needs AmpCode.ModelMappings and OAuthModelAlias + cfg := &config.Config{ + AmpCode: func() config.AmpCode { + if m.modelMapper != nil { + return config.AmpCode{ + ModelMappings: m.modelMapper.GetMappingsAsConfig(), + } + } + return config.AmpCode{} + }(), + } + + // Create router with registry and config + router := routing.NewRouter(registry, cfg) + + // Create wrapper with proxy function + proxyFunc := func(c *gin.Context) { + proxy := m.getProxy() + if proxy != nil { + proxy.ServeHTTP(c.Writer, c.Request) + } else { + c.JSON(503, gin.H{"error": "amp upstream proxy not available"}) + } + } + + return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc) +} + // registerProviderAliases registers /api/provider/{provider}/... routes // These allow Amp CLI to route requests like: // @@ -269,12 +306,9 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler) openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) - // Create fallback handler wrapper that forwards to ampcode.com when provider not found - // Uses m.getProxy() for hot-reload support (proxy can be updated at runtime) - // Also includes model mapping support for routing unavailable models to alternatives - fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { - return m.getProxy() - }, m.modelMapper, m.forceModelMappings) + // Create unified routing wrapper (T-021 onwards) + // Replaces FallbackHandler with Router-based unified routing + routingWrapper := m.createModelRoutingWrapper() // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") @@ -302,33 +336,36 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han } // Root-level routes (for providers that omit /v1, like groq/cerebras) - // Wrap handlers with fallback logic to forward to ampcode.com when provider not found + // T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check) - provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions)) - provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions)) - provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses)) + provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions)) + provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions)) + provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses)) // /v1 routes (OpenAI/Claude-compatible endpoints) v1Amp := provider.Group("/v1") { v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback - // OpenAI-compatible endpoints with fallback - v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions)) - v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions)) - v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses)) + // OpenAI-compatible endpoints with ModelRoutingWrapper + // T-021, T-022: Migrated to unified routing wrapper + v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions)) + v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions)) + v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses)) - // Claude/Anthropic-compatible endpoints with fallback - v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages)) - v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens)) + // Claude/Anthropic-compatible endpoints with ModelRoutingWrapper + // T-023: Migrated Claude routes to unified routing wrapper + v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages)) + v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens)) } // /v1beta routes (Gemini native API) // Note: Gemini handler extracts model from URL path, so fallback logic needs special handling + // T-024: Migrated Gemini v1beta routes to unified routing wrapper v1betaAmp := provider.Group("/v1beta") { v1betaAmp.GET("/models", geminiHandlers.GeminiModels) - v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler)) + v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler)) v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler) } } diff --git a/internal/routing/extractor.go b/internal/routing/extractor.go new file mode 100644 index 00000000..94fe969a --- /dev/null +++ b/internal/routing/extractor.go @@ -0,0 +1,59 @@ +package routing + +import ( + "strings" + + "github.com/tidwall/gjson" +) + +// ModelExtractor extracts model names from request data. +type ModelExtractor interface { + // Extract returns the model name from the request body and gin parameters. + // The ginParams map contains route parameters like "action" and "path". + Extract(body []byte, ginParams map[string]string) (string, error) +} + +// DefaultModelExtractor is the standard implementation of ModelExtractor. +type DefaultModelExtractor struct{} + +// NewModelExtractor creates a new DefaultModelExtractor. +func NewModelExtractor() *DefaultModelExtractor { + return &DefaultModelExtractor{} +} + +// Extract extracts the model name from the request. +// It checks in order: +// 1. JSON body "model" field (OpenAI, Claude format) +// 2. "action" parameter for Gemini standard format (e.g., "gemini-pro:generateContent") +// 3. "path" parameter for AMP CLI Gemini format (e.g., "/publishers/google/models/gemini-3-pro:streamGenerateContent") +func (e *DefaultModelExtractor) Extract(body []byte, ginParams map[string]string) (string, error) { + // First try to parse from JSON body (OpenAI, Claude, etc.) + if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String { + return result.String(), nil + } + + // For Gemini requests, model is in the URL path + // Standard format: /models/{model}:generateContent -> :action parameter + if action, ok := ginParams["action"]; ok && action != "" { + // Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro") + parts := strings.Split(action, ":") + if len(parts) > 0 && parts[0] != "" { + return parts[0], nil + } + } + + // AMP CLI format: /publishers/google/models/{model}:method -> *path parameter + // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent + if path, ok := ginParams["path"]; ok && path != "" { + // Look for /models/{model}:method pattern + if idx := strings.Index(path, "/models/"); idx >= 0 { + modelPart := path[idx+8:] // Skip "/models/" + // Split by colon to get model name + if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 { + return modelPart[:colonIdx], nil + } + } + } + + return "", nil +} diff --git a/internal/routing/extractor_test.go b/internal/routing/extractor_test.go new file mode 100644 index 00000000..485b4831 --- /dev/null +++ b/internal/routing/extractor_test.go @@ -0,0 +1,214 @@ +package routing + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestModelExtractor_ExtractFromJSONBody(t *testing.T) { + extractor := NewModelExtractor() + + tests := []struct { + name string + body []byte + want string + wantErr bool + }{ + { + name: "extract from JSON body with model field", + body: []byte(`{"model":"gpt-4.1"}`), + want: "gpt-4.1", + }, + { + name: "extract claude model from JSON body", + body: []byte(`{"model":"claude-3-5-sonnet-20241022"}`), + want: "claude-3-5-sonnet-20241022", + }, + { + name: "extract with additional fields", + body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`), + want: "gpt-4", + }, + { + name: "empty body returns empty", + body: []byte{}, + want: "", + }, + { + name: "no model field returns empty", + body: []byte(`{"messages":[]}`), + want: "", + }, + { + name: "model is not string returns empty", + body: []byte(`{"model":123}`), + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractor.Extract(tt.body, nil) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestModelExtractor_ExtractFromGeminiActionParam(t *testing.T) { + extractor := NewModelExtractor() + + tests := []struct { + name string + body []byte + ginParams map[string]string + want string + }{ + { + name: "extract from action parameter - gemini-pro", + body: []byte(`{}`), + ginParams: map[string]string{"action": "gemini-pro:generateContent"}, + want: "gemini-pro", + }, + { + name: "extract from action parameter - gemini-ultra", + body: []byte(`{}`), + ginParams: map[string]string{"action": "gemini-ultra:chat"}, + want: "gemini-ultra", + }, + { + name: "empty action returns empty", + body: []byte(`{}`), + ginParams: map[string]string{"action": ""}, + want: "", + }, + { + name: "action without colon returns full value", + body: []byte(`{}`), + ginParams: map[string]string{"action": "gemini-model"}, + want: "gemini-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractor.Extract(tt.body, tt.ginParams) + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestModelExtractor_ExtractFromGeminiV1Beta1Path(t *testing.T) { + extractor := NewModelExtractor() + + tests := []struct { + name string + body []byte + ginParams map[string]string + want string + }{ + { + name: "extract from v1beta1 path - gemini-3-pro", + body: []byte(`{}`), + ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro:streamGenerateContent"}, + want: "gemini-3-pro", + }, + { + name: "extract from v1beta1 path with preview", + body: []byte(`{}`), + ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro-preview:generateContent"}, + want: "gemini-3-pro-preview", + }, + { + name: "path without models segment returns empty", + body: []byte(`{}`), + ginParams: map[string]string{"path": "/publishers/google/gemini-3-pro:streamGenerateContent"}, + want: "", + }, + { + name: "empty path returns empty", + body: []byte(`{}`), + ginParams: map[string]string{"path": ""}, + want: "", + }, + { + name: "path with /models/ but no colon returns empty", + body: []byte(`{}`), + ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro"}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractor.Extract(tt.body, tt.ginParams) + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestModelExtractor_ExtractPriority(t *testing.T) { + extractor := NewModelExtractor() + + // JSON body takes priority over gin params + t.Run("JSON body takes priority over action param", func(t *testing.T) { + body := []byte(`{"model":"gpt-4"}`) + params := map[string]string{"action": "gemini-pro:generateContent"} + got, err := extractor.Extract(body, params) + assert.NoError(t, err) + assert.Equal(t, "gpt-4", got) + }) + + // Action param takes priority over path param + t.Run("action param takes priority over path param", func(t *testing.T) { + body := []byte(`{}`) + params := map[string]string{ + "action": "gemini-action:generate", + "path": "/publishers/google/models/gemini-path:streamGenerateContent", + } + got, err := extractor.Extract(body, params) + assert.NoError(t, err) + assert.Equal(t, "gemini-action", got) + }) +} + +func TestModelExtractor_NoModelFound(t *testing.T) { + extractor := NewModelExtractor() + + tests := []struct { + name string + body []byte + ginParams map[string]string + }{ + { + name: "empty body and no params", + body: []byte{}, + ginParams: nil, + }, + { + name: "body without model and no params", + body: []byte(`{"messages":[]}`), + ginParams: map[string]string{}, + }, + { + name: "irrelevant params only", + body: []byte(`{}`), + ginParams: map[string]string{"other": "value"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractor.Extract(tt.body, tt.ginParams) + assert.NoError(t, err) + assert.Empty(t, got) + }) + } +} diff --git a/internal/routing/rewriter.go b/internal/routing/rewriter.go new file mode 100644 index 00000000..d0c02771 --- /dev/null +++ b/internal/routing/rewriter.go @@ -0,0 +1,159 @@ +package routing + +import ( + "bytes" + "net/http" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + log "github.com/sirupsen/logrus" +) + +// ModelRewriter handles model name rewriting in requests and responses. +type ModelRewriter interface { + // RewriteRequestBody rewrites the model field in a JSON request body. + // Returns the modified body or the original if no rewrite was needed. + RewriteRequestBody(body []byte, newModel string) ([]byte, error) + + // WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response. + // Returns the wrapped writer and a cleanup function that must be called after the response is complete. + WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) +} + +// DefaultModelRewriter is the standard implementation of ModelRewriter. +type DefaultModelRewriter struct{} + +// NewModelRewriter creates a new DefaultModelRewriter. +func NewModelRewriter() *DefaultModelRewriter { + return &DefaultModelRewriter{} +} + +// RewriteRequestBody replaces the model name in a JSON request body. +func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) { + if !gjson.GetBytes(body, "model").Exists() { + return body, nil + } + result, err := sjson.SetBytes(body, "model", newModel) + if err != nil { + return body, err + } + return result, nil +} + +// WrapResponseWriter wraps a response writer to rewrite model names. +// The cleanup function must be called after the handler completes to flush any buffered data. +func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) { + rw := &responseRewriter{ + ResponseWriter: w, + body: &bytes.Buffer{}, + requestedModel: requestedModel, + resolvedModel: resolvedModel, + } + return rw, func() { rw.flush() } +} + +// responseRewriter wraps http.ResponseWriter to intercept and modify the response body. +type responseRewriter struct { + http.ResponseWriter + body *bytes.Buffer + requestedModel string + resolvedModel string + isStreaming bool + wroteHeader bool + flushed bool +} + +// Write intercepts response writes and buffers them for model name replacement. +func (rw *responseRewriter) Write(data []byte) (int, error) { + // Ensure header is written + if !rw.wroteHeader { + rw.WriteHeader(http.StatusOK) + } + + // Detect streaming on first write + if rw.body.Len() == 0 && !rw.isStreaming { + contentType := rw.Header().Get("Content-Type") + rw.isStreaming = strings.Contains(contentType, "text/event-stream") || + strings.Contains(contentType, "stream") + } + + if rw.isStreaming { + n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) + if err == nil { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + } + return n, err + } + return rw.body.Write(data) +} + +// WriteHeader captures the status code and delegates to the underlying writer. +func (rw *responseRewriter) WriteHeader(code int) { + if !rw.wroteHeader { + rw.wroteHeader = true + rw.ResponseWriter.WriteHeader(code) + } +} + +// flush writes the buffered response with model names rewritten. +func (rw *responseRewriter) flush() { + if rw.flushed { + return + } + rw.flushed = true + + if rw.isStreaming { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + return + } + if rw.body.Len() > 0 { + data := rw.rewriteModelInResponse(rw.body.Bytes()) + if _, err := rw.ResponseWriter.Write(data); err != nil { + log.Warnf("response rewriter: failed to write rewritten response: %v", err) + } + } +} + +// modelFieldPaths lists all JSON paths where model name may appear. +var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"} + +// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model. +func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte { + if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel { + return data + } + + for _, path := range modelFieldPaths { + if gjson.GetBytes(data, path).Exists() { + data, _ = sjson.SetBytes(data, path, rw.requestedModel) + } + } + return data +} + +// rewriteStreamChunk rewrites model names in SSE stream chunks. +func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte { + if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel { + return chunk + } + + // SSE format: "data: {json}\n\n" + lines := bytes.Split(chunk, []byte("\n")) + for i, line := range lines { + if bytes.HasPrefix(line, []byte("data: ")) { + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) > 0 && jsonData[0] == '{' { + // Rewrite JSON in the data line + rewritten := rw.rewriteModelInResponse(jsonData) + lines[i] = append([]byte("data: "), rewritten...) + } + } + } + + return bytes.Join(lines, []byte("\n")) +} diff --git a/internal/routing/rewriter_test.go b/internal/routing/rewriter_test.go new file mode 100644 index 00000000..d628f710 --- /dev/null +++ b/internal/routing/rewriter_test.go @@ -0,0 +1,342 @@ +package routing + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestModelRewriter_RewriteRequestBody(t *testing.T) { + rewriter := NewModelRewriter() + + tests := []struct { + name string + body []byte + newModel string + wantModel string + wantChange bool + }{ + { + name: "rewrites model field in JSON body", + body: []byte(`{"model":"gpt-4.1","messages":[]}`), + newModel: "claude-local", + wantModel: "claude-local", + wantChange: true, + }, + { + name: "rewrites with empty body returns empty", + body: []byte{}, + newModel: "gpt-4", + wantModel: "", + wantChange: false, + }, + { + name: "handles missing model field gracefully", + body: []byte(`{"messages":[{"role":"user"}]}`), + newModel: "gpt-4", + wantModel: "", + wantChange: false, + }, + { + name: "preserves other fields when rewriting", + body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`), + newModel: "new-model", + wantModel: "new-model", + wantChange: true, + }, + { + name: "handles nested JSON structure", + body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`), + newModel: "claude-3-opus", + wantModel: "claude-3-opus", + wantChange: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel) + require.NoError(t, err) + + if tt.wantChange { + assert.NotEqual(t, string(tt.body), string(result), "body should have been modified") + } + + if tt.wantModel != "" { + // Parse result and check model field + model, _ := NewModelExtractor().Extract(result, nil) + assert.Equal(t, tt.wantModel, model) + } + }) + } +} + +func TestModelRewriter_WrapResponseWriter(t *testing.T) { + rewriter := NewModelRewriter() + + t.Run("response writer wraps without error", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + require.NotNil(t, wrapped) + require.NotNil(t, cleanup) + defer cleanup() + }) + + t.Run("rewrites model in non-streaming response", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + // Write a response with the resolved model + response := []byte(`{"model":"claude-local","content":"hello"}`) + wrapped.Header().Set("Content-Type", "application/json") + _, err := wrapped.Write(response) + require.NoError(t, err) + + // Cleanup triggers the rewrite + cleanup() + + // Check the response was rewritten to the requested model + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"model":"gpt-4"`) + assert.NotContains(t, string(body), `"model":"claude-local"`) + }) + + t.Run("no-op when requested equals resolved", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4") + + response := []byte(`{"model":"gpt-4","content":"hello"}`) + wrapped.Header().Set("Content-Type", "application/json") + _, err := wrapped.Write(response) + require.NoError(t, err) + + cleanup() + + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"model":"gpt-4"`) + }) + + t.Run("rewrites modelVersion field", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"modelVersion":"claude-local","content":"hello"}`) + wrapped.Header().Set("Content-Type", "application/json") + _, err := wrapped.Write(response) + require.NoError(t, err) + + cleanup() + + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"modelVersion":"gpt-4"`) + }) + + t.Run("handles streaming responses", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + // Set streaming content type + wrapped.Header().Set("Content-Type", "text/event-stream") + + // Write SSE chunks with resolved model + chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n") + _, err := wrapped.Write(chunk1) + require.NoError(t, err) + + chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n") + _, err = wrapped.Write(chunk2) + require.NoError(t, err) + + cleanup() + + // For streaming, data is written immediately with rewrites + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"model":"gpt-4"`) + assert.NotContains(t, string(body), `"model":"claude-local"`) + }) + + t.Run("empty body handled gracefully", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + wrapped.Header().Set("Content-Type", "application/json") + // Don't write anything + + cleanup() + + body := recorder.Body.Bytes() + assert.Empty(t, body) + }) + + t.Run("preserves other JSON fields", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`) + wrapped.Header().Set("Content-Type", "application/json") + _, err := wrapped.Write(response) + require.NoError(t, err) + + cleanup() + + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"temperature":0.7`) + assert.Contains(t, string(body), `"prompt_tokens":10`) + }) +} + +func TestResponseRewriter_ImplementsInterfaces(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + defer cleanup() + + // Should implement http.ResponseWriter + assert.Implements(t, (*http.ResponseWriter)(nil), wrapped) + + // Should preserve header access + wrapped.Header().Set("X-Custom", "value") + assert.Equal(t, "value", recorder.Header().Get("X-Custom")) + + // Should write status + wrapped.WriteHeader(http.StatusCreated) + assert.Equal(t, http.StatusCreated, recorder.Code) +} + +func TestResponseRewriter_Flush(t *testing.T) { + t.Run("flush writes buffered content", func(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"model":"claude-local","content":"test"}`) + wrapped.Header().Set("Content-Type", "application/json") + wrapped.Write(response) + + // Before cleanup, response should be empty (buffered) + assert.Empty(t, recorder.Body.Bytes()) + + // After cleanup, response should be written + cleanup() + assert.NotEmpty(t, recorder.Body.Bytes()) + }) + + t.Run("multiple flush calls are safe", func(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"model":"claude-local"}`) + wrapped.Header().Set("Content-Type", "application/json") + wrapped.Write(response) + + // First cleanup + cleanup() + firstBody := recorder.Body.Bytes() + + // Second cleanup should not write again + cleanup() + secondBody := recorder.Body.Bytes() + + assert.Equal(t, firstBody, secondBody) + }) +} + +func TestResponseRewriter_StreamingWithDataLines(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + wrapped.Header().Set("Content-Type", "text/event-stream") + + // SSE format with multiple data lines + chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n") + wrapped.Write(chunk) + + cleanup() + + body := recorder.Body.Bytes() + // Both data lines should have model rewritten + assert.Contains(t, string(body), `"model":"gpt-4"`) + assert.NotContains(t, string(body), `"model":"claude-local"`) +} + +func TestModelRewriter_RoundTrip(t *testing.T) { + // Simulate a full request -> response cycle with model rewriting + rewriter := NewModelRewriter() + + // Step 1: Rewrite request body + originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`) + rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local") + require.NoError(t, err) + + // Verify request was rewritten + extractor := NewModelExtractor() + requestModel, _ := extractor.Extract(rewrittenRequest, nil) + assert.Equal(t, "claude-local", requestModel) + + // Step 2: Simulate response with resolved model + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`) + wrapped.Header().Set("Content-Type", "application/json") + wrapped.Write(response) + cleanup() + + // Verify response was rewritten back + body, _ := io.ReadAll(recorder.Result().Body) + responseModel, _ := extractor.Extract(body, nil) + assert.Equal(t, "gpt-4", responseModel) +} + +func TestModelRewriter_NonJSONBody(t *testing.T) { + rewriter := NewModelRewriter() + + // Binary/non-JSON body should be returned unchanged + body := []byte{0x00, 0x01, 0x02, 0x03} + result, err := rewriter.RewriteRequestBody(body, "gpt-4") + require.NoError(t, err) + assert.Equal(t, body, result) +} + +func TestModelRewriter_InvalidJSON(t *testing.T) { + rewriter := NewModelRewriter() + + // Invalid JSON without model field should be returned unchanged + body := []byte(`not valid json`) + result, err := rewriter.RewriteRequestBody(body, "gpt-4") + require.NoError(t, err) + assert.Equal(t, body, result) +} + +func TestResponseRewriter_StatusCodePreserved(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + wrapped.WriteHeader(http.StatusAccepted) + wrapped.Write([]byte(`{"model":"claude-local"}`)) + cleanup() + + assert.Equal(t, http.StatusAccepted, recorder.Code) +} + +func TestResponseRewriter_HeaderFlushed(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + wrapped.Header().Set("Content-Type", "application/json") + wrapped.Header().Set("X-Request-ID", "abc123") + wrapped.Write([]byte(`{"model":"claude-local"}`)) + cleanup() + + result := recorder.Result() + assert.Equal(t, "application/json", result.Header.Get("Content-Type")) + assert.Equal(t, "abc123", result.Header.Get("X-Request-ID")) +} diff --git a/internal/routing/router.go b/internal/routing/router.go index db74ef3c..30548ab1 100644 --- a/internal/routing/router.go +++ b/internal/routing/router.go @@ -31,15 +31,17 @@ func NewRouter(registry *Registry, cfg *config.Config) *Router { return r } -// RoutingDecision contains the resolved routing information. -type RoutingDecision struct { +// LegacyRoutingDecision contains the resolved routing information. +// Deprecated: Will be replaced by RoutingDecision from types.go in T-013. +type LegacyRoutingDecision struct { RequestedModel string // Original model from request ResolvedModel string // After model-mappings Candidates []ProviderCandidate // Ordered list of providers to try } // Resolve determines the routing decision for the requested model. -func (r *Router) Resolve(requestedModel string) *RoutingDecision { +// Deprecated: Will be updated to use RoutingRequest and return *RoutingDecision in T-013. +func (r *Router) Resolve(requestedModel string) *LegacyRoutingDecision { // 1. Extract thinking suffix suffixResult := thinking.ParseSuffix(requestedModel) baseModel := suffixResult.ModelName @@ -60,13 +62,151 @@ func (r *Router) Resolve(requestedModel string) *RoutingDecision { return candidates[i].Provider.Priority() < candidates[j].Provider.Priority() }) - return &RoutingDecision{ + return &LegacyRoutingDecision{ RequestedModel: requestedModel, ResolvedModel: targetModel, Candidates: candidates, } } +// ResolveV2 determines the routing decision for a routing request. +// It uses the new RoutingRequest and RoutingDecision types. +func (r *Router) ResolveV2(req RoutingRequest) *RoutingDecision { + // 1. Extract thinking suffix + suffixResult := thinking.ParseSuffix(req.RequestedModel) + baseModel := suffixResult.ModelName + thinkingSuffix := "" + if suffixResult.HasSuffix { + thinkingSuffix = "(" + suffixResult.RawSuffix + ")" + } + + // 2. Check for local providers + localCandidates := r.findLocalCandidates(baseModel, suffixResult) + + // 3. Apply model-mappings if needed + mappedModel := r.applyMappings(baseModel) + mappingCandidates := r.findLocalCandidates(mappedModel, suffixResult) + + // 4. Determine route type based on preferences and availability + var decision *RoutingDecision + + if req.ForceModelMapping && mappedModel != baseModel && len(mappingCandidates) > 0 { + // FORCE MODE: Use mapping even if local provider exists + decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:]) + } else if req.PreferLocalProvider && len(localCandidates) > 0 { + // DEFAULT MODE with local preference: Use local provider first + decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix) + } else if len(localCandidates) > 0 { + // DEFAULT MODE: Local provider available + decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix) + } else if mappedModel != baseModel && len(mappingCandidates) > 0 { + // DEFAULT MODE: No local provider, but mapping available + decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:]) + } else { + // No local provider, no mapping - use amp credits proxy + decision = &RoutingDecision{ + RouteType: RouteTypeAmpCredits, + ResolvedModel: req.RequestedModel, + ShouldProxy: true, + } + } + + return decision +} + +// findLocalCandidates finds local provider candidates for a model. +func (r *Router) findLocalCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate { + var candidates []ProviderCandidate + + for _, p := range r.registry.All() { + if !p.SupportsModel(model) { + continue + } + + // Apply thinking suffix if needed + actualModel := model + if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix { + actualModel = model + "(" + suffixResult.RawSuffix + ")" + } + + if p.Available(actualModel) { + candidates = append(candidates, ProviderCandidate{ + Provider: p, + Model: actualModel, + }) + } + } + + // Sort by priority + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].Provider.Priority() < candidates[j].Provider.Priority() + }) + + return candidates +} + +// buildLocalProviderDecision creates a decision for local provider routing. +func (r *Router) buildLocalProviderDecision(requestedModel string, candidates []ProviderCandidate, thinkingSuffix string) *RoutingDecision { + resolvedModel := requestedModel + if thinkingSuffix != "" { + // Ensure thinking suffix is preserved + sr := thinking.ParseSuffix(requestedModel) + if !sr.HasSuffix { + resolvedModel = requestedModel + thinkingSuffix + } + } + + var fallbackModels []string + if len(candidates) > 1 { + for _, c := range candidates[1:] { + fallbackModels = append(fallbackModels, c.Model) + } + } + + return &RoutingDecision{ + RouteType: RouteTypeLocalProvider, + ResolvedModel: resolvedModel, + ProviderName: candidates[0].Provider.Name(), + FallbackModels: fallbackModels, + ShouldProxy: false, + } +} + +// buildMappingDecision creates a decision for model mapping routing. +func (r *Router) buildMappingDecision(requestedModel, mappedModel string, candidates []ProviderCandidate, thinkingSuffix string, fallbackCandidates []ProviderCandidate) *RoutingDecision { + // Apply thinking suffix to resolved model if needed + resolvedModel := mappedModel + if thinkingSuffix != "" { + sr := thinking.ParseSuffix(mappedModel) + if !sr.HasSuffix { + resolvedModel = mappedModel + thinkingSuffix + } + } + + var fallbackModels []string + for _, c := range fallbackCandidates { + fallbackModels = append(fallbackModels, c.Model) + } + + // Also add oauth aliases as fallbacks + baseMapped := thinking.ParseSuffix(mappedModel).ModelName + for _, alias := range r.oauthAliases[strings.ToLower(baseMapped)] { + // Check if this alias has providers + aliasCandidates := r.findLocalCandidates(alias, thinking.SuffixResult{ModelName: alias}) + for _, c := range aliasCandidates { + fallbackModels = append(fallbackModels, c.Model) + } + } + + return &RoutingDecision{ + RouteType: RouteTypeModelMapping, + ResolvedModel: resolvedModel, + ProviderName: candidates[0].Provider.Name(), + FallbackModels: fallbackModels, + ShouldProxy: false, + } +} + // applyMappings applies model-mappings configuration. func (r *Router) applyMappings(model string) string { key := strings.ToLower(strings.TrimSpace(model)) diff --git a/internal/routing/router_v2_test.go b/internal/routing/router_v2_test.go new file mode 100644 index 00000000..903b7aa8 --- /dev/null +++ b/internal/routing/router_v2_test.go @@ -0,0 +1,245 @@ +package routing + +import ( + "context" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/stretchr/testify/assert" +) + +func TestRouter_DefaultMode_PrefersLocal(t *testing.T) { + // Setup: Create a router with a mock provider that supports "gpt-4" + registry := NewRegistry() + mockProvider := &MockProvider{ + name: "openai", + supportedModels: []string{"gpt-4"}, + available: true, + priority: 1, + } + registry.Register(mockProvider) + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "claude-local"}, + }, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request gpt-4 when local provider exists + req := RoutingRequest{ + RequestedModel: "gpt-4", + PreferLocalProvider: true, + ForceModelMapping: false, + } + + decision := router.ResolveV2(req) + + // Assert: Should return LOCAL_PROVIDER, not MODEL_MAPPING + assert.Equal(t, RouteTypeLocalProvider, decision.RouteType) + assert.Equal(t, "gpt-4", decision.ResolvedModel) + assert.Equal(t, "openai", decision.ProviderName) + assert.False(t, decision.ShouldProxy) +} + +func TestRouter_DefaultMode_MapsWhenNoLocal(t *testing.T) { + // Setup: Create a router with NO provider for "gpt-4" but a mapping to "claude-local" + // which has a provider + registry := NewRegistry() + mockProvider := &MockProvider{ + name: "anthropic", + supportedModels: []string{"claude-local"}, + available: true, + priority: 1, + } + registry.Register(mockProvider) + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "claude-local"}, + }, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request gpt-4 when no local provider exists, but mapping exists + req := RoutingRequest{ + RequestedModel: "gpt-4", + PreferLocalProvider: true, + ForceModelMapping: false, + } + + decision := router.ResolveV2(req) + + // Assert: Should return MODEL_MAPPING + assert.Equal(t, RouteTypeModelMapping, decision.RouteType) + assert.Equal(t, "claude-local", decision.ResolvedModel) + assert.Equal(t, "anthropic", decision.ProviderName) + assert.False(t, decision.ShouldProxy) +} + +func TestRouter_DefaultMode_AmpCreditsWhenNoLocalOrMapping(t *testing.T) { + // Setup: Create a router with no providers and no mappings + registry := NewRegistry() + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{}, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request a model with no local provider and no mapping + req := RoutingRequest{ + RequestedModel: "unknown-model", + PreferLocalProvider: true, + ForceModelMapping: false, + } + + decision := router.ResolveV2(req) + + // Assert: Should return AMP_CREDITS with ShouldProxy=true + assert.Equal(t, RouteTypeAmpCredits, decision.RouteType) + assert.Equal(t, "unknown-model", decision.ResolvedModel) + assert.True(t, decision.ShouldProxy) + assert.Empty(t, decision.ProviderName) +} + +func TestRouter_ForceMode_MapsEvenWithLocal(t *testing.T) { + // Setup: Create a router with BOTH a local provider for "gpt-4" AND a mapping from "gpt-4" to "claude-local" + // The mapping target "claude-local" also has a provider + registry := NewRegistry() + + // Local provider for gpt-4 + openaiProvider := &MockProvider{ + name: "openai", + supportedModels: []string{"gpt-4"}, + available: true, + priority: 1, + } + registry.Register(openaiProvider) + + // Local provider for the mapped model + anthropicProvider := &MockProvider{ + name: "anthropic", + supportedModels: []string{"claude-local"}, + available: true, + priority: 2, + } + registry.Register(anthropicProvider) + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "claude-local"}, + }, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request gpt-4 with ForceModelMapping=true + // Even though gpt-4 has a local provider, mapping should take precedence + req := RoutingRequest{ + RequestedModel: "gpt-4", + PreferLocalProvider: false, + ForceModelMapping: true, + } + + decision := router.ResolveV2(req) + + // Assert: Should return MODEL_MAPPING, not LOCAL_PROVIDER + assert.Equal(t, RouteTypeModelMapping, decision.RouteType) + assert.Equal(t, "claude-local", decision.ResolvedModel) + assert.Equal(t, "anthropic", decision.ProviderName) + assert.False(t, decision.ShouldProxy) +} + +func TestRouter_ThinkingSuffix_Preserved(t *testing.T) { + // Setup: Create a router with mapping and provider for mapped model + registry := NewRegistry() + + mockProvider := &MockProvider{ + name: "anthropic", + supportedModels: []string{"claude-local"}, + available: true, + priority: 1, + } + registry.Register(mockProvider) + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "claude-3-5-sonnet", To: "claude-local"}, + }, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request claude-3-5-sonnet with thinking suffix + req := RoutingRequest{ + RequestedModel: "claude-3-5-sonnet(thinking:foo)", + PreferLocalProvider: true, + ForceModelMapping: false, + } + + decision := router.ResolveV2(req) + + // Assert: Thinking suffix should be preserved in resolved model + assert.Equal(t, RouteTypeModelMapping, decision.RouteType) + assert.Equal(t, "claude-local(thinking:foo)", decision.ResolvedModel) + assert.Equal(t, "anthropic", decision.ProviderName) +} + +// MockProvider is a mock implementation of Provider for testing +type MockProvider struct { + name string + providerType ProviderType + supportedModels []string + available bool + priority int +} + +func (m *MockProvider) Name() string { + return m.name +} + +func (m *MockProvider) Type() ProviderType { + if m.providerType == "" { + return ProviderTypeOAuth + } + return m.providerType +} + +func (m *MockProvider) SupportsModel(model string) bool { + for _, supported := range m.supportedModels { + if supported == model { + return true + } + } + return false +} + +func (m *MockProvider) Available(model string) bool { + return m.available +} + +func (m *MockProvider) Priority() int { + return m.priority +} + +func (m *MockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + return executor.Response{}, nil +} + +func (m *MockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) { + return nil, nil +} diff --git a/internal/routing/testutil/fake_handler.go b/internal/routing/testutil/fake_handler.go new file mode 100644 index 00000000..160aaad8 --- /dev/null +++ b/internal/routing/testutil/fake_handler.go @@ -0,0 +1,113 @@ +package testutil + +import ( + "io" + "net/http" + + "github.com/gin-gonic/gin" +) + +// FakeHandlerRecorder records handler invocations for testing. +type FakeHandlerRecorder struct { + Called bool + CallCount int + RequestBody []byte + RequestHeader http.Header + ContextKeys map[string]interface{} + ResponseStatus int + ResponseBody []byte +} + +// NewFakeHandlerRecorder creates a new fake handler recorder. +func NewFakeHandlerRecorder() *FakeHandlerRecorder { + return &FakeHandlerRecorder{ + ContextKeys: make(map[string]interface{}), + ResponseStatus: http.StatusOK, + ResponseBody: []byte(`{"status":"handled"}`), + } +} + +// GinHandler returns a gin.HandlerFunc that records the invocation. +func (f *FakeHandlerRecorder) GinHandler() gin.HandlerFunc { + return func(c *gin.Context) { + f.record(c) + c.Data(f.ResponseStatus, "application/json", f.ResponseBody) + } +} + +// GinHandlerWithModel returns a gin.HandlerFunc that records the invocation and returns the model from context. +// Useful for testing response rewriting in model mapping scenarios. +func (f *FakeHandlerRecorder) GinHandlerWithModel() gin.HandlerFunc { + return func(c *gin.Context) { + f.record(c) + // Return a response with the model field that would be in the actual API response + // If ResponseBody was explicitly set (not default), use that; otherwise generate from context + var body []byte + if mappedModel, exists := c.Get("mapped_model"); exists { + body = []byte(`{"model":"` + mappedModel.(string) + `","status":"handled"}`) + } else { + body = f.ResponseBody + } + c.Data(f.ResponseStatus, "application/json", body) + } +} + +// HTTPHandler returns an http.HandlerFunc that records the invocation. +func (f *FakeHandlerRecorder) HTTPHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + f.Called = true + f.CallCount++ + f.RequestBody = body + f.RequestHeader = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(f.ResponseStatus) + w.Write(f.ResponseBody) + } +} + +// record captures the request details from gin context. +func (f *FakeHandlerRecorder) record(c *gin.Context) { + f.Called = true + f.CallCount++ + + body, _ := io.ReadAll(c.Request.Body) + f.RequestBody = body + f.RequestHeader = c.Request.Header.Clone() + + // Capture common context keys used by routing + if val, exists := c.Get("mapped_model"); exists { + f.ContextKeys["mapped_model"] = val + } + if val, exists := c.Get("fallback_models"); exists { + f.ContextKeys["fallback_models"] = val + } + if val, exists := c.Get("route_type"); exists { + f.ContextKeys["route_type"] = val + } +} + +// Reset clears the recorder state. +func (f *FakeHandlerRecorder) Reset() { + f.Called = false + f.CallCount = 0 + f.RequestBody = nil + f.RequestHeader = nil + f.ContextKeys = make(map[string]interface{}) +} + +// GetContextKey returns a captured context key value. +func (f *FakeHandlerRecorder) GetContextKey(key string) (interface{}, bool) { + val, ok := f.ContextKeys[key] + return val, ok +} + +// WasCalled returns true if the handler was called. +func (f *FakeHandlerRecorder) WasCalled() bool { + return f.Called +} + +// GetCallCount returns the number of times the handler was called. +func (f *FakeHandlerRecorder) GetCallCount() int { + return f.CallCount +} diff --git a/internal/routing/testutil/fake_proxy.go b/internal/routing/testutil/fake_proxy.go new file mode 100644 index 00000000..3deea5a5 --- /dev/null +++ b/internal/routing/testutil/fake_proxy.go @@ -0,0 +1,83 @@ +package testutil + +import ( + "io" + "net/http" + "net/http/httptest" +) + +// CloseNotifierRecorder wraps httptest.ResponseRecorder with CloseNotify support. +// This is needed because ReverseProxy requires http.CloseNotifier. +type CloseNotifierRecorder struct { + *httptest.ResponseRecorder + closeChan chan bool +} + +// NewCloseNotifierRecorder creates a ResponseRecorder that implements CloseNotifier. +func NewCloseNotifierRecorder() *CloseNotifierRecorder { + return &CloseNotifierRecorder{ + ResponseRecorder: httptest.NewRecorder(), + closeChan: make(chan bool, 1), + } +} + +// CloseNotify implements http.CloseNotifier. +func (c *CloseNotifierRecorder) CloseNotify() <-chan bool { + return c.closeChan +} + +// FakeProxyRecorder records proxy invocations for testing. +type FakeProxyRecorder struct { + Called bool + CallCount int + RequestBody []byte + RequestHeaders http.Header + ResponseStatus int + ResponseBody []byte +} + +// NewFakeProxyRecorder creates a new fake proxy recorder. +func NewFakeProxyRecorder() *FakeProxyRecorder { + return &FakeProxyRecorder{ + ResponseStatus: http.StatusOK, + ResponseBody: []byte(`{"status":"proxied"}`), + } +} + +// ServeHTTP implements http.Handler to act as a reverse proxy. +func (f *FakeProxyRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) { + f.Called = true + f.CallCount++ + f.RequestHeaders = r.Header.Clone() + + body, err := io.ReadAll(r.Body) + if err == nil { + f.RequestBody = body + } + + w.WriteHeader(f.ResponseStatus) + w.Write(f.ResponseBody) +} + +// GetCallCount returns the number of times the proxy was called. +func (f *FakeProxyRecorder) GetCallCount() int { + return f.CallCount +} + +// Reset clears the recorder state. +func (f *FakeProxyRecorder) Reset() { + f.Called = false + f.CallCount = 0 + f.RequestBody = nil + f.RequestHeaders = nil +} + +// ToHandler returns the recorder as an http.Handler for use with httptest. +func (f *FakeProxyRecorder) ToHandler() http.Handler { + return http.HandlerFunc(f.ServeHTTP) +} + +// CreateTestServer creates an httptest server with this fake proxy. +func (f *FakeProxyRecorder) CreateTestServer() *httptest.Server { + return httptest.NewServer(f.ToHandler()) +} diff --git a/internal/routing/types.go b/internal/routing/types.go new file mode 100644 index 00000000..30c50610 --- /dev/null +++ b/internal/routing/types.go @@ -0,0 +1,62 @@ +package routing + +// RouteType represents the type of routing decision made for a request. +type RouteType string + +const ( + // RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free). + RouteTypeLocalProvider RouteType = "LOCAL_PROVIDER" + // RouteTypeModelMapping indicates the request was remapped to another available model (free). + RouteTypeModelMapping RouteType = "MODEL_MAPPING" + // RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits). + RouteTypeAmpCredits RouteType = "AMP_CREDITS" + // RouteTypeNoProvider indicates no provider or fallback available. + RouteTypeNoProvider RouteType = "NO_PROVIDER" +) + +// RoutingRequest contains the information needed to make a routing decision. +type RoutingRequest struct { + // RequestedModel is the model name from the incoming request. + RequestedModel string + // PreferLocalProvider indicates whether to prefer local providers over mappings. + // When true, check local providers first before applying model mappings. + PreferLocalProvider bool + // ForceModelMapping indicates whether to force model mapping even if local provider exists. + // When true, apply model mappings first and skip local provider checks. + ForceModelMapping bool +} + +// RoutingDecision contains the result of a routing decision. +type RoutingDecision struct { + // RouteType indicates the type of routing decision. + RouteType RouteType + // ResolvedModel is the final model name after any mappings. + ResolvedModel string + // ProviderName is the name of the selected provider (if any). + ProviderName string + // FallbackModels is a list of alternative models to try if the primary fails. + FallbackModels []string + // ShouldProxy indicates whether the request should be proxied to ampcode.com. + ShouldProxy bool +} + +// NewRoutingDecision creates a new RoutingDecision with the given parameters. +func NewRoutingDecision(routeType RouteType, resolvedModel, providerName string, fallbackModels []string, shouldProxy bool) *RoutingDecision { + return &RoutingDecision{ + RouteType: routeType, + ResolvedModel: resolvedModel, + ProviderName: providerName, + FallbackModels: fallbackModels, + ShouldProxy: shouldProxy, + } +} + +// IsLocal returns true if the decision routes to a local provider. +func (d *RoutingDecision) IsLocal() bool { + return d.RouteType == RouteTypeLocalProvider || d.RouteType == RouteTypeModelMapping +} + +// HasFallbacks returns true if there are fallback models available. +func (d *RoutingDecision) HasFallbacks() bool { + return len(d.FallbackModels) > 0 +} diff --git a/internal/routing/wrapper.go b/internal/routing/wrapper.go new file mode 100644 index 00000000..90d10eea --- /dev/null +++ b/internal/routing/wrapper.go @@ -0,0 +1,270 @@ +package routing + +import ( + "bufio" + "bytes" + "io" + "net" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys" + "github.com/sirupsen/logrus" +) + +// ProxyFunc is the function type for proxying requests. +type ProxyFunc func(c *gin.Context) + +// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic. +// It replaces the FallbackHandler logic with a Router-based approach. +type ModelRoutingWrapper struct { + router *Router + extractor ModelExtractor + rewriter ModelRewriter + proxyFunc ProxyFunc + logger *logrus.Logger +} + +// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies. +// If extractor is nil, a DefaultModelExtractor is used. +// If rewriter is nil, a DefaultModelRewriter is used. +// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead. +func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper { + if extractor == nil { + extractor = NewModelExtractor() + } + if rewriter == nil { + rewriter = NewModelRewriter() + } + return &ModelRoutingWrapper{ + router: router, + extractor: extractor, + rewriter: rewriter, + proxyFunc: proxyFunc, + logger: logrus.New(), + } +} + +// SetLogger sets the logger for the wrapper. +func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) { + w.logger = logger +} + +// Wrap wraps a gin.HandlerFunc with model routing logic. +// The returned handler will: +// 1. Extract the model from the request +// 2. Get a routing decision from the Router +// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS) +func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + // Read request body + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + w.logger.Errorf("routing wrapper: failed to read request body: %v", err) + handler(c) + return + } + + // Extract model from request + ginParams := map[string]string{ + "action": c.Param("action"), + "path": c.Param("path"), + } + modelName, err := w.extractor.Extract(bodyBytes, ginParams) + if err != nil { + w.logger.Warnf("routing wrapper: failed to extract model: %v", err) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + return + } + + if modelName == "" { + // No model found, proceed with original handler + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + return + } + + // Get routing decision + req := RoutingRequest{ + RequestedModel: modelName, + PreferLocalProvider: true, + ForceModelMapping: false, // TODO: Get from config + } + decision := w.router.ResolveV2(req) + + // Store decision in context for downstream handlers + c.Set(string(ctxkeys.RoutingDecision), decision) + + // Handle based on route type + switch decision.RouteType { + case RouteTypeLocalProvider: + w.handleLocalProvider(c, handler, bodyBytes, decision) + case RouteTypeModelMapping: + w.handleModelMapping(c, handler, bodyBytes, decision) + case RouteTypeAmpCredits: + w.handleAmpCredits(c, handler, bodyBytes) + default: + // No provider available + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + } + } +} + +// handleLocalProvider handles the LOCAL_PROVIDER route type. +func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) { + // Filter Anthropic-Beta header for local provider + filterAnthropicBetaHeader(c) + + // Restore body with original content + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Call handler + handler(c) +} + +// handleModelMapping handles the MODEL_MAPPING route type. +func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) { + // Rewrite request body with mapped model + rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel) + if err != nil { + w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err) + rewrittenBody = bodyBytes + } + _ = rewrittenBody + + // Store mapped model in context + c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel) + + // Store fallback models in context if present + if len(decision.FallbackModels) > 0 { + c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels) + } + + // Filter Anthropic-Beta header for local provider + filterAnthropicBetaHeader(c) + + // Restore body with rewritten content + c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody)) + + // Wrap response writer to rewrite model back + wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel) + c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer} + + // Call handler + handler(c) + + // Cleanup (flush response rewriting) + cleanup() +} + +// handleAmpCredits handles the AMP_CREDITS route type. +// It calls the proxy function directly if available, otherwise passes to handler. +// Does NOT filter headers or rewrite body - proxy handles everything. +func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) { + // Restore body with original content (no rewriting for proxy) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Call proxy function if available, otherwise fall back to handler + if w.proxyFunc != nil { + w.proxyFunc(c) + } else { + handler(c) + } +} + +// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers. +func filterAnthropicBetaHeader(c *gin.Context) { + if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" { + filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07") + if filtered != "" { + c.Request.Header.Set("Anthropic-Beta", filtered) + } else { + c.Request.Header.Del("Anthropic-Beta") + } + } +} + +// filterBetaFeatures removes specified beta features from the header. +func filterBetaFeatures(betaHeader, featureToRemove string) string { + // Simple implementation - can be enhanced + if betaHeader == featureToRemove { + return "" + } + return betaHeader +} + +// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter. +type ginResponseWriterAdapter struct { + http.ResponseWriter + original gin.ResponseWriter +} + +func (a *ginResponseWriterAdapter) WriteHeader(code int) { + a.ResponseWriter.WriteHeader(code) +} + +func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) { + return a.ResponseWriter.Write(data) +} + +func (a *ginResponseWriterAdapter) Header() http.Header { + return a.ResponseWriter.Header() +} + +// CloseNotify implements http.CloseNotifier. +func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool { + if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok { + return notifier.CloseNotify() + } + return a.original.CloseNotify() +} + +// Flush implements http.Flusher. +func (a *ginResponseWriterAdapter) Flush() { + if flusher, ok := a.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +// Hijack implements http.Hijacker. +func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return a.original.Hijack() +} + +// Status returns the HTTP status code. +func (a *ginResponseWriterAdapter) Status() int { + return a.original.Status() +} + +// Size returns the number of bytes already written into the response http body. +func (a *ginResponseWriterAdapter) Size() int { + return a.original.Size() +} + +// Written returns whether or not the response for this context has been written. +func (a *ginResponseWriterAdapter) Written() bool { + return a.original.Written() +} + +// WriteHeaderNow forces WriteHeader to be called. +func (a *ginResponseWriterAdapter) WriteHeaderNow() { + a.original.WriteHeaderNow() +} + +// WriteString writes the given string into the response body. +func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) { + return a.Write([]byte(s)) +} + +// Pusher returns the http.Pusher for server push. +func (a *ginResponseWriterAdapter) Pusher() http.Pusher { + if pusher, ok := a.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +}