diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index b38ab5f1..0cbe0e1a 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -28,6 +28,9 @@ const ( RouteTypeNoProvider AmpRouteType = "NO_PROVIDER" ) +// MappedModelContextKey is the Gin context key for passing mapped model names. +const MappedModelContextKey = "mapped_model" + // logAmpRouting logs the routing decision for an Amp request with structured fields func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { fields := log.Fields{ @@ -141,6 +144,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Mapping found - rewrite the model in request body bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) resolvedModel = mappedModel usedMapping = true diff --git a/internal/api/modules/amp/gemini_bridge.go b/internal/api/modules/amp/gemini_bridge.go index 3b3d8374..d6ad8f79 100644 --- a/internal/api/modules/amp/gemini_bridge.go +++ b/internal/api/modules/amp/gemini_bridge.go @@ -4,7 +4,6 @@ import ( "strings" "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini" ) // createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths @@ -15,16 +14,31 @@ import ( // // This extracts the model+method from the AMP path and sets it as the :action parameter // so the standard Gemini handler can process it. -func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc { +// +// The handler parameter should be a Gemini-compatible handler that expects the :action param. +func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { // Get the full path from the catch-all parameter path := c.Param("path") // Extract model:method from AMP CLI path format // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent - if idx := strings.Index(path, "/models/"); idx >= 0 { - // Extract everything after "/models/" - actionPart := path[idx+8:] // Skip "/models/" + const modelsPrefix = "/models/" + if idx := strings.Index(path, modelsPrefix); idx >= 0 { + // Extract everything after modelsPrefix + actionPart := path[idx+len(modelsPrefix):] + + // Check if model was mapped by FallbackHandler + if mappedModel, exists := c.Get(MappedModelContextKey); exists { + if strModel, ok := mappedModel.(string); ok && strModel != "" { + // Replace the model part in the action + // actionPart is like "model-name:method" + if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 { + method := actionPart[colonIdx:] // ":method" + actionPart = strModel + method + } + } + } // Set this as the :action parameter that the Gemini handler expects c.Params = append(c.Params, gin.Param{ @@ -32,8 +46,8 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl Value: actionPart, }) - // Call the standard Gemini handler - geminiHandler.GeminiHandler(c) + // Call the handler + handler(c) return } diff --git a/internal/api/modules/amp/gemini_bridge_test.go b/internal/api/modules/amp/gemini_bridge_test.go new file mode 100644 index 00000000..347456c3 --- /dev/null +++ b/internal/api/modules/amp/gemini_bridge_test.go @@ -0,0 +1,93 @@ +package amp + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + path string + mappedModel string // empty string means no mapping + expectedAction string + }{ + { + name: "no_mapping_uses_url_model", + path: "/publishers/google/models/gemini-pro:generateContent", + mappedModel: "", + expectedAction: "gemini-pro:generateContent", + }, + { + name: "mapped_model_replaces_url_model", + path: "/publishers/google/models/gemini-exp:generateContent", + mappedModel: "gemini-2.0-flash", + expectedAction: "gemini-2.0-flash:generateContent", + }, + { + name: "mapping_preserves_method", + path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent", + mappedModel: "gemini-flash", + expectedAction: "gemini-flash:streamGenerateContent", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedAction string + + mockGeminiHandler := func(c *gin.Context) { + capturedAction = c.Param("action") + c.JSON(http.StatusOK, gin.H{"captured": capturedAction}) + } + + // Use the actual createGeminiBridgeHandler function + bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler) + + r := gin.New() + if tt.mappedModel != "" { + r.Use(func(c *gin.Context) { + c.Set(MappedModelContextKey, tt.mappedModel) + c.Next() + }) + } + r.POST("/api/provider/google/v1beta1/*path", bridgeHandler) + + req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d", w.Code) + } + if capturedAction != tt.expectedAction { + t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction) + } + }) + } +} + +func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + mockHandler := func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + } + bridgeHandler := createGeminiBridgeHandler(mockHandler) + + r := gin.New() + r.POST("/api/provider/google/v1beta1/*path", bridgeHandler) + + req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for invalid path, got %d", w.Code) + } +} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 57bb5246..6826dbbe 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -9,7 +9,6 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini" @@ -169,30 +168,22 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha // We bridge these to our standard Gemini handler to enable local OAuth. // If no local OAuth is available, falls back to ampcode.com proxy. geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) - geminiBridge := createGeminiBridgeHandler(geminiHandlers) - geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy { + geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) + geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }) + }, m.modelMapper) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) - // Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy. + // Route POST model calls through Gemini bridge with FallbackHandler. + // FallbackHandler checks provider -> mapping -> proxy fallback automatically. // All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior. ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) { if c.Request.Method == "POST" { - // Attempt to extract the model name from the AMP-style path if path := c.Param("path"); strings.Contains(path, "/models/") { - modelPart := path[strings.Index(path, "/models/")+len("/models/"):] - if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 { - modelPart = modelPart[:colonIdx] - } - if modelPart != "" { - normalized, _ := util.NormalizeGeminiThinkingModel(modelPart) - // Only handle locally when we have a provider; otherwise fall back to proxy - if providers := util.GetProviderName(normalized); len(providers) > 0 { - geminiV1Beta1Handler(c) - return - } - } + // POST with /models/ path -> use Gemini bridge with fallback handler + // FallbackHandler will check provider/mapping and proxy if needed + geminiV1Beta1Handler(c) + return } } // Non-POST or no local provider available -> proxy upstream