diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index e7132b81..940bd5e8 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } // Normalize model (handles dynamic thinking suffixes) - normalizedModel, _ := util.NormalizeThinkingModel(modelName) + normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName) + thinkingSuffix := "" + if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) { + thinkingSuffix = modelName[len(normalizedModel):] + } + + resolveMappedModel := func() (string, []string) { + if fh.modelMapper == nil { + return "", nil + } + + mappedModel := fh.modelMapper.MapModel(modelName) + if mappedModel == "" { + mappedModel = fh.modelMapper.MapModel(normalizedModel) + } + mappedModel = strings.TrimSpace(mappedModel) + if mappedModel == "" { + return "", nil + } + + // Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target + // already specifies its own thinking suffix. + if thinkingSuffix != "" { + _, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel) + if mappedThinkingMetadata == nil { + mappedModel += thinkingSuffix + } + } + + mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel) + mappedProviders := util.GetProviderName(mappedBaseModel) + if len(mappedProviders) == 0 { + return "", nil + } + + return mappedModel, mappedProviders + } // Track resolved model for logging (may change if mapping is applied) resolvedModel := normalizedModel @@ -147,21 +183,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if forceMappings { // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // This allows users to route Amp requests to their preferred OAuth providers - if fh.modelMapper != nil { - if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - check if we have a provider for the mapped model - mappedProviders := util.GetProviderName(mappedModel) - if len(mappedProviders) > 0 { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders - } - } + if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders } // If no mapping applied, check for local providers @@ -174,21 +204,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if len(providers) == 0 { // No providers configured - check if we have a model mapping - if fh.modelMapper != nil { - if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - check if we have a provider for the mapped model - mappedProviders := util.GetProviderName(mappedModel) - if len(mappedProviders) > 0 { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders - } - } + if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders } } } @@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Log: Model was mapped to another model log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) - rewriter := NewResponseRewriter(c.Writer, normalizedModel) + rewriter := NewResponseRewriter(c.Writer, modelName) c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) handler(c) rewriter.Flush() - log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel) + log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName) } else if len(providers) > 0 { // Log: Using local provider (free) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) diff --git a/internal/api/modules/amp/fallback_handlers_test.go b/internal/api/modules/amp/fallback_handlers_test.go new file mode 100644 index 00000000..a687fd11 --- /dev/null +++ b/internal/api/modules/amp/fallback_handlers_test.go @@ -0,0 +1,73 @@ +package amp + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "net/http/httputil" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{ + {ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"}, + }) + defer reg.UnregisterClient("test-client-amp-fallback") + + mapper := NewModelMapper([]config.AmpModelMapping{ + {From: "gpt-5.2", To: "test/gpt-5.2"}, + }) + + fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil) + + handler := func(c *gin.Context) { + var req struct { + Model string `json:"model"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "model": req.Model, + "seen_model": req.Model, + }) + } + + r := gin.New() + r.POST("/chat/completions", fallback.WrapHandler(handler)) + + reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`) + req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d", w.Code) + } + + var resp struct { + Model string `json:"model"` + SeenModel string `json:"seen_model"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to parse response JSON: %v", err) + } + + if resp.Model != "gpt-5.2(xhigh)" { + t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model) + } + if resp.SeenModel != "test/gpt-5.2(xhigh)" { + t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel) + } +} diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index 87384a80..bc31c4e5 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -59,7 +59,8 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string { } // Verify target model has available providers - providers := util.GetProviderName(targetModel) + normalizedTarget, _ := util.NormalizeThinkingModel(targetModel) + providers := util.GetProviderName(normalizedTarget) if len(providers) == 0 { log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) return "" diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go index 4f4e5a8e..664a17c5 100644 --- a/internal/api/modules/amp/model_mapping_test.go +++ b/internal/api/modules/amp/model_mapping_test.go @@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) { } } +func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{ + {ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"}, + }) + defer reg.UnregisterClient("test-client-thinking") + + mappings := []config.AmpModelMapping{ + {From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("gpt-5.2-alias") + if result != "gpt-5.2(xhigh)" { + t.Errorf("Expected gpt-5.2(xhigh), got %s", result) + } +} + func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) { reg := registry.GetGlobalRegistry() reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{ diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index ba9e13ef..290d5f92 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -136,6 +136,12 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool) updated = rewritten } } + if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() { + updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts") + } + if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() { + updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget") + } return updated } @@ -167,6 +173,12 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo updated = rewritten } } + if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() { + updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts") + } + if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() { + updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget") + } return updated }