diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index b38ab5f1..0cbe0e1a 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -28,6 +28,9 @@ const ( RouteTypeNoProvider AmpRouteType = "NO_PROVIDER" ) +// MappedModelContextKey is the Gin context key for passing mapped model names. +const MappedModelContextKey = "mapped_model" + // logAmpRouting logs the routing decision for an Amp request with structured fields func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { fields := log.Fields{ @@ -141,6 +144,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Mapping found - rewrite the model in request body bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) resolvedModel = mappedModel usedMapping = true diff --git a/internal/api/modules/amp/gemini_bridge.go b/internal/api/modules/amp/gemini_bridge.go index 3b3d8374..cc31b685 100644 --- a/internal/api/modules/amp/gemini_bridge.go +++ b/internal/api/modules/amp/gemini_bridge.go @@ -26,6 +26,18 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl // Extract everything after "/models/" actionPart := path[idx+8:] // Skip "/models/" + // Check if model was mapped by FallbackHandler + if mappedModel, exists := c.Get(MappedModelContextKey); exists { + if strModel, ok := mappedModel.(string); ok && strModel != "" { + // Replace the model part in the action + // actionPart is like "model-name:method" + if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 { + method := actionPart[colonIdx:] // ":method" + actionPart = strModel + method + } + } + } + // Set this as the :action parameter that the Gemini handler expects c.Params = append(c.Params, gin.Param{ Key: "action", diff --git a/internal/api/modules/amp/gemini_bridge_test.go b/internal/api/modules/amp/gemini_bridge_test.go new file mode 100644 index 00000000..88accbf4 --- /dev/null +++ b/internal/api/modules/amp/gemini_bridge_test.go @@ -0,0 +1,120 @@ +package amp + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini" +) + +func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + path string + mappedModel string // empty string means no mapping + expectedAction string + }{ + { + name: "no_mapping_uses_url_model", + path: "/publishers/google/models/gemini-pro:generateContent", + mappedModel: "", + expectedAction: "gemini-pro:generateContent", + }, + { + name: "mapped_model_replaces_url_model", + path: "/publishers/google/models/gemini-exp:generateContent", + mappedModel: "gemini-2.0-flash", + expectedAction: "gemini-2.0-flash:generateContent", + }, + { + name: "mapping_preserves_method", + path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent", + mappedModel: "gemini-flash", + expectedAction: "gemini-flash:streamGenerateContent", + }, + { + name: "empty_mapped_model_ignored", + path: "/publishers/google/models/gemini-pro:generateContent", + mappedModel: "", + expectedAction: "gemini-pro:generateContent", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedAction string + + mockGeminiHandler := func(c *gin.Context) { + capturedAction = c.Param("action") + c.JSON(http.StatusOK, gin.H{"captured": capturedAction}) + } + + // Mirror the bridge logic from gemini_bridge.go + bridgeHandler := func(c *gin.Context) { + path := c.Param("path") + if idx := strings.Index(path, "/models/"); idx >= 0 { + actionPart := path[idx+8:] + + if mappedModel, exists := c.Get(MappedModelContextKey); exists { + if strModel, ok := mappedModel.(string); ok && strModel != "" { + if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 { + method := actionPart[colonIdx:] + actionPart = strModel + method + } + } + } + + c.Params = append(c.Params, gin.Param{Key: "action", Value: actionPart}) + mockGeminiHandler(c) + return + } + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid path"}) + } + + r := gin.New() + if tt.mappedModel != "" { + r.Use(func(c *gin.Context) { + c.Set(MappedModelContextKey, tt.mappedModel) + c.Next() + }) + } + r.POST("/api/provider/google/v1beta1/*path", bridgeHandler) + + req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d", w.Code) + } + if capturedAction != tt.expectedAction { + t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction) + } + }) + } +} + +func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + base := &handlers.BaseAPIHandler{} + geminiHandlers := gemini.NewGeminiAPIHandler(base) + bridgeHandler := createGeminiBridgeHandler(geminiHandlers) + + r := gin.New() + r.POST("/api/provider/google/v1beta1/*path", bridgeHandler) + + req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for invalid path, got %d", w.Code) + } +} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 57bb5246..1f61d5ac 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -170,9 +170,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha // If no local OAuth is available, falls back to ampcode.com proxy. geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) geminiBridge := createGeminiBridgeHandler(geminiHandlers) - geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy { + geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }) + }, m.modelMapper) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) // Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy. @@ -187,8 +187,18 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha } if modelPart != "" { normalized, _ := util.NormalizeGeminiThinkingModel(modelPart) - // Only handle locally when we have a provider; otherwise fall back to proxy + // Only handle locally when we have a provider or a valid mapping; otherwise fall back to proxy + hasProvider := false if providers := util.GetProviderName(normalized); len(providers) > 0 { + hasProvider = true + } else if m.modelMapper != nil { + // Check if mapped model has provider (MapModel returns target only if it has providers) + if mapped := m.modelMapper.MapModel(normalized); mapped != "" { + hasProvider = true + } + } + + if hasProvider { geminiV1Beta1Handler(c) return }