package amp

import (
	"bytes"
	"errors"
	"io"
	"net/http"
	"net/http/httputil"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// AmpRouteType represents the type of routing decision made for an Amp request.
// Its string value is what gets written to the "route_type" log field.
type AmpRouteType string

const (
	// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free).
	RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER"
	// RouteTypeModelMapping indicates the request was remapped to another available model (free).
	RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING"
	// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits).
	RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS"
	// RouteTypeNoProvider indicates no provider or fallback is available.
	RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
)

// MappedModelContextKey is the Gin context key for passing mapped model names.
// It is the string form of ctxkeys.MappedModel.
//
// Deprecated: Use ctxkeys.MappedModel instead.
const MappedModelContextKey = string(ctxkeys.MappedModel)

// FallbackModelsContextKey is the Gin context key for passing fallback model names.
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
// It is the string form of ctxkeys.FallbackModels.
//
// Deprecated: Use ctxkeys.FallbackModels instead.
const FallbackModelsContextKey = string(ctxkeys.FallbackModels)

// logAmpRouting emits one structured log entry describing how an Amp request
// was routed.
//
// Every entry carries the component name, the route type, the requested model,
// the request path and an RFC 3339 timestamp. resolvedModel is recorded only
// when it is non-empty and differs from requestedModel; provider is recorded
// only when non-empty. Route types other than the four known constants produce
// no log output.
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
	entry := log.Fields{}
	entry["component"] = "amp-routing"
	entry["route_type"] = string(routeType)
	entry["requested_model"] = requestedModel
	entry["path"] = path
	entry["timestamp"] = time.Now().Format(time.RFC3339)

	if mapped := resolvedModel != "" && resolvedModel != requestedModel; mapped {
		entry["resolved_model"] = resolvedModel
	}
	if provider != "" {
		entry["provider"] = provider
	}

	// billing records who pays for the request and where it is served from.
	billing := func(cost, source string) {
		entry["cost"] = cost
		entry["source"] = source
	}

	switch routeType {
	case RouteTypeLocalProvider:
		billing("free", "local_oauth")
		log.WithFields(entry).Debugf("amp using local provider for model: %s", requestedModel)

	case RouteTypeModelMapping:
		billing("free", "local_oauth")
		entry["mapping"] = requestedModel + " -> " + resolvedModel
		// Intentionally not emitted: the mapper already logs the mapping, so
		// logging here would duplicate it.

	case RouteTypeAmpCredits:
		billing("amp_credits", "ampcode.com")
		// model_id is repeated explicitly so it is easy to copy into the config.
		entry["model_id"] = requestedModel
		log.WithFields(entry).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"\"}]", requestedModel, requestedModel)

	case RouteTypeNoProvider:
		billing("none", "error")
		// model_id is repeated explicitly so it is easy to copy into the config.
		entry["model_id"] = requestedModel
		log.WithFields(entry).Warnf("no provider available for model_id: %s", requestedModel)
	}
}

// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
//
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
// This type is kept for backward compatibility and test purposes.
type FallbackHandler struct { getProxy func() *httputil.ReverseProxy modelMapper ModelMapper forceModelMappings func() bool } // NewFallbackHandler creates a new fallback handler wrapper // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) // // Deprecated: Use routing.NewModelRoutingWrapper() instead. func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { return &FallbackHandler{ getProxy: getProxy, forceModelMappings: func() bool { return false }, } } // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support // // Deprecated: Use routing.NewModelRoutingWrapper() instead. func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler { if forceModelMappings == nil { forceModelMappings = func() bool { return false } } return &FallbackHandler{ getProxy: getProxy, modelMapper: mapper, forceModelMappings: forceModelMappings, } } // SetModelMapper sets the model mapper for this handler (allows late binding) func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) { fh.modelMapper = mapper } // WrapHandler wraps a gin.HandlerFunc with fallback logic // If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { // Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces. // ReverseProxy raises this panic when the client connection is closed prematurely // (e.g., user cancels request, network disconnect) or when ServeHTTP is called // with a ResponseWriter that doesn't implement http.CloseNotifier. // This is an expected error condition, not a bug, so we handle it gracefully. 
defer func() { if rec := recover(); rec != nil { if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) { return } panic(rec) } }() requestPath := c.Request.URL.Path // Read the request body to extract the model name bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { log.Errorf("amp fallback: failed to read request body: %v", err) handler(c) return } // Restore the body for the handler to read c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Try to extract model from request body or URL path (for Gemini) modelName := extractModelFromRequest(bodyBytes, c) if modelName == "" { // Can't determine model, proceed with normal handler handler(c) return } // Normalize model (handles dynamic thinking suffixes) suffixResult := thinking.ParseSuffix(modelName) normalizedModel := suffixResult.ModelName thinkingSuffix := "" if suffixResult.HasSuffix { thinkingSuffix = "(" + suffixResult.RawSuffix + ")" } // resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one. 
resolveMappedModels := func() ([]string, []string) { if fh.modelMapper == nil { return nil, nil } mapper, ok := fh.modelMapper.(*DefaultModelMapper) if !ok { // Fallback to single model for non-DefaultModelMapper mappedModel := fh.modelMapper.MapModel(modelName) if mappedModel == "" { mappedModel = fh.modelMapper.MapModel(normalizedModel) } if mappedModel == "" { return nil, nil } mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName mappedProviders := util.GetProviderName(mappedBaseModel) if len(mappedProviders) == 0 { return nil, nil } return []string{mappedModel}, mappedProviders } // Use MapModelWithFallbacks for DefaultModelMapper mappedModels := mapper.MapModelWithFallbacks(modelName) if len(mappedModels) == 0 { mappedModels = mapper.MapModelWithFallbacks(normalizedModel) } if len(mappedModels) == 0 { return nil, nil } // Apply thinking suffix if needed for i, model := range mappedModels { if thinkingSuffix != "" { suffixResult := thinking.ParseSuffix(model) if !suffixResult.HasSuffix { mappedModels[i] = model + thinkingSuffix } } } // Get providers for the first model firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName providers := util.GetProviderName(firstBaseModel) if len(providers) == 0 { return nil, nil } return mappedModels, providers } // Track resolved model for logging (may change if mapping is applied) resolvedModel := normalizedModel usedMapping := false var providers []string // Helper to apply model mapping and update state applyMapping := func(mappedModels []string, mappedProviders []string) { bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0]) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Set(string(ctxkeys.MappedModel), mappedModels[0]) if len(mappedModels) > 1 { c.Set(string(ctxkeys.FallbackModels), mappedModels[1:]) } resolvedModel = mappedModels[0] usedMapping = true providers = mappedProviders } // Check if model mappings should be forced ahead of local API keys forceMappings := 
fh.forceModelMappings != nil && fh.forceModelMappings() if forceMappings { // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // This allows users to route Amp requests to their preferred OAuth providers if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 { applyMapping(mappedModels, mappedProviders) } // If no mapping applied, check for local providers if !usedMapping { providers = util.GetProviderName(normalizedModel) } } else { // DEFAULT MODE: Check local providers first, then mappings as fallback providers = util.GetProviderName(normalizedModel) if len(providers) == 0 { // No providers configured - check if we have a model mapping if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 { applyMapping(mappedModels, mappedProviders) } } } // If no providers available, fallback to ampcode.com if len(providers) == 0 { proxy := fh.getProxy() if proxy != nil { // Log: Forwarding to ampcode.com (uses Amp credits) logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath) // Restore body again for the proxy c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Forward to ampcode.com proxy.ServeHTTP(c.Writer, c.Request) return } // No proxy available, let the normal handler return the error logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) } // Log the routing decision providerName := "" if len(providers) > 0 { providerName = providers[0] } if usedMapping { // Log: Model was mapped to another model log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) rewriter := NewResponseRewriter(c.Writer, modelName) c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) handler(c) rewriter.Flush() log.Debugf("amp model mapping: response %s 
-> %s", resolvedModel, modelName) } else if len(providers) > 0 { // Log: Using local provider (free) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) handler(c) } else { // No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) handler(c) } } } // filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription // This is needed when using local providers (bypassing the Amp proxy) func filterAntropicBetaHeader(c *gin.Context) { if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" { if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" { c.Request.Header.Set("Anthropic-Beta", filtered) } else { c.Request.Header.Del("Anthropic-Beta") } } } // rewriteModelInRequest replaces the model name in a JSON request body func rewriteModelInRequest(body []byte, newModel string) []byte { if !gjson.GetBytes(body, "model").Exists() { return body } result, err := sjson.SetBytes(body, "model", newModel) if err != nil { log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err) return body } return result } // extractModelFromRequest attempts to extract the model name from various request formats func extractModelFromRequest(body []byte, c *gin.Context) string { // First try to parse from JSON body (OpenAI, Claude, etc.) 
// Check common model field names if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String { return result.String() } // For Gemini requests, model is in the URL path // Standard format: /models/{model}:generateContent -> :action parameter if action := c.Param("action"); action != "" { // Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro") parts := strings.Split(action, ":") if len(parts) > 0 && parts[0] != "" { return parts[0] } } // AMP CLI format: /publishers/google/models/{model}:method -> *path parameter // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent if path := c.Param("path"); path != "" { // Look for /models/{model}:method pattern if idx := strings.Index(path, "/models/"); idx >= 0 { modelPart := path[idx+8:] // Skip "/models/" // Split by colon to get model name if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 { return modelPart[:colonIdx] } } } return "" }