diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go new file mode 100644 index 00000000..d0ccac56 --- /dev/null +++ b/internal/api/modules/amp/fallback_handlers.go @@ -0,0 +1,105 @@ +package amp + +import ( + "bytes" + "encoding/json" + "io" + "net/http/httputil" + "strings" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +// FallbackHandler wraps a standard handler with fallback logic to ampcode.com +// when the model's provider is not available in CLIProxyAPI +type FallbackHandler struct { + getProxy func() *httputil.ReverseProxy +} + +// NewFallbackHandler creates a new fallback handler wrapper +// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) +func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { + return &FallbackHandler{ + getProxy: getProxy, + } +} + +// WrapHandler wraps a gin.HandlerFunc with fallback logic +// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com +func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + // Read the request body to extract the model name + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + log.Errorf("amp fallback: failed to read request body: %v", err) + handler(c) + return + } + + // Restore the body for the handler to read + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Try to extract model from request body or URL path (for Gemini) + modelName := extractModelFromRequest(bodyBytes, c) + if modelName == "" { + // Can't determine model, proceed with normal handler + handler(c) + return + } + + // Normalize model (handles Gemini thinking suffixes) + normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName) + + // Check if we have providers for this model + providers := util.GetProviderName(normalizedModel) + + if len(providers) == 0 { + // No providers configured - check if we have a proxy for fallback + proxy := fh.getProxy() + if proxy != nil { + // Fallback to ampcode.com + log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName) + + // Restore body again for the proxy + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Forward to ampcode.com + proxy.ServeHTTP(c.Writer, c.Request) + return + } + + // No proxy available, let the normal handler return the error + log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName) + } + + // Providers available or no proxy for fallback, restore body and use normal handler + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + } +} + +// extractModelFromRequest attempts to extract the model name from various request formats +func extractModelFromRequest(body []byte, c *gin.Context) string { + // First try to parse from JSON body (OpenAI, Claude, etc.) + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err == nil { + // Check common model field names + if model, ok := payload["model"].(string); ok { + return model + } + } + + // For Gemini requests, model is in the URL path: /models/{model}:generateContent + // Extract from :action parameter (e.g., "gemini-pro:generateContent") + if action := c.Param("action"); action != "" { + // Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro") + parts := strings.Split(action, ":") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + } + + return "" +} diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index f952de8d..5231ec86 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -2,6 +2,7 @@ package amp import ( "net" + "net/http/httputil" "strings" "github.com/gin-gonic/gin" @@ -17,7 +18,7 @@ import ( func localhostOnlyMiddleware() gin.HandlerFunc { return func(c *gin.Context) { clientIP := c.ClientIP() - + // Parse the IP to handle both IPv4 and IPv6 ip := net.ParseIP(clientIP) if ip == nil { @@ -27,7 +28,7 @@ func localhostOnlyMiddleware() gin.HandlerFunc { }) return } - + // Check if IP is loopback (127.0.0.1 or ::1) if !ip.IsLoopback() { log.Warnf("Amp management: non-localhost IP %s attempted access, denying", clientIP) @@ -36,7 +37,7 @@ func localhostOnlyMiddleware() gin.HandlerFunc { }) return } - + c.Next() } } @@ -50,13 +51,13 @@ func noCORSMiddleware() gin.HandlerFunc { c.Header("Access-Control-Allow-Methods", "") c.Header("Access-Control-Allow-Headers", "") c.Header("Access-Control-Allow-Credentials", "") - + // For OPTIONS preflight, deny with 403 if c.Request.Method == "OPTIONS" { c.AbortWithStatus(403) return } - + c.Next() } } @@ -66,10 +67,10 @@ func noCORSMiddleware() gin.HandlerFunc { // If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1. func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) { ampAPI := engine.Group("/api") - + // Always disable CORS for management routes to prevent browser-based attacks ampAPI.Use(noCORSMiddleware()) - + // Apply localhost-only restriction if configured if restrictToLocalhost { ampAPI.Use(localhostOnlyMiddleware()) @@ -112,6 +113,12 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler) openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) + // Create fallback handler wrapper that forwards to ampcode.com when provider not found + // Uses lazy evaluation to access proxy (which is created after routes are registered) + fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy { + return m.proxy + }) + // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") if auth != nil { @@ -136,31 +143,33 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han } // Root-level routes (for providers that omit /v1, like groq/cerebras) - provider.GET("/models", ampModelsHandler) - provider.POST("/chat/completions", openaiHandlers.ChatCompletions) - provider.POST("/completions", openaiHandlers.Completions) - provider.POST("/responses", openaiResponsesHandlers.Responses) + // Wrap handlers with fallback logic to forward to ampcode.com when provider not found + provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check) + provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions)) + provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions)) + provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses)) // /v1 routes (OpenAI/Claude-compatible endpoints) v1Amp := provider.Group("/v1") { - v1Amp.GET("/models", ampModelsHandler) + v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback - // OpenAI-compatible endpoints - v1Amp.POST("/chat/completions", openaiHandlers.ChatCompletions) - v1Amp.POST("/completions", openaiHandlers.Completions) - v1Amp.POST("/responses", openaiResponsesHandlers.Responses) + // OpenAI-compatible endpoints with fallback + v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions)) + v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions)) + v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses)) - // Claude/Anthropic-compatible endpoints - v1Amp.POST("/messages", claudeCodeHandlers.ClaudeMessages) - v1Amp.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) + // Claude/Anthropic-compatible endpoints with fallback + v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages)) + v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens)) } // /v1beta routes (Gemini native API) + // Note: Gemini handler extracts model from URL path, so fallback logic needs special handling v1betaAmp := provider.Group("/v1beta") { v1betaAmp.GET("/models", geminiHandlers.GeminiModels) - v1betaAmp.POST("/models/:action", geminiHandlers.GeminiHandler) + v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler)) v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler) } } diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index 33c9edcf..80c1730a 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -62,6 +62,23 @@ func ParseGeminiThinkingSuffix(model string) (string, *int, *bool, bool) { return base, &budgetValue, nil, true } +func NormalizeGeminiThinkingModel(modelName string) (string, map[string]any) { + baseModel, budget, include, matched := ParseGeminiThinkingSuffix(modelName) + if !matched { + return baseModel, nil + } + metadata := map[string]any{ + GeminiOriginalModelMetadataKey: modelName, + } + if budget != nil { + metadata[GeminiThinkingBudgetMetadataKey] = *budget + } + if include != nil { + metadata[GeminiIncludeThoughtsMetadataKey] = *include + } + return baseModel, metadata +} + func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte { if budget == nil && includeThoughts == nil { return body diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 07edad11..7c64aaba 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -359,20 +359,7 @@ func cloneBytes(src []byte) []byte { } func normalizeModelMetadata(modelName string) (string, map[string]any) { - baseModel, budget, include, matched := util.ParseGeminiThinkingSuffix(modelName) - if !matched { - return baseModel, nil - } - metadata := map[string]any{ - util.GeminiOriginalModelMetadataKey: modelName, - } - if budget != nil { - metadata[util.GeminiThinkingBudgetMetadataKey] = *budget - } - if include != nil { - metadata[util.GeminiIncludeThoughtsMetadataKey] = *include - } - return baseModel, metadata + return util.NormalizeGeminiThinkingModel(modelName) } func cloneMetadata(src map[string]any) map[string]any {