mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 20:30:51 +08:00
feat(amp): enhance model mapping and Gemini thinking configuration
This commit introduces several improvements to the AMP (Advanced Model Proxy) module: - **Model Mapping Logic:** The `FallbackHandler` now uses a more robust approach for model mapping. It includes the extraction and preservation of dynamic "thinking suffixes" (e.g., `(xhigh)`) during mapping, ensuring that these configurations are correctly applied to the mapped model. A new `resolveMappedModel` function centralizes this logic for cleaner code. - **ModelMapper Verification:** The `ModelMapper` in `model_mapping.go` now verifies that the target model of a mapping has available providers *after* normalizing it. This prevents mappings to non-existent or unresolvable models. - **Gemini Thinking Configuration Cleanup:** In `gemini_thinking.go`, unnecessary `generationConfig.thinkingConfig.include_thoughts` and `generationConfig.thinkingConfig.thinkingBudget` fields are now deleted from the request body when applying Gemini thinking levels. This prevents potential conflicts or redundant configurations. - **Testing:** A new test case `TestModelMapper_MapModel_TargetWithThinkingSuffix` has been added to `model_mapping_test.go` to specifically cover the preservation of thinking suffixes during model mapping.
This commit is contained in:
@@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Normalize model (handles dynamic thinking suffixes)
|
// Normalize model (handles dynamic thinking suffixes)
|
||||||
normalizedModel, _ := util.NormalizeThinkingModel(modelName)
|
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
|
||||||
|
thinkingSuffix := ""
|
||||||
|
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
|
||||||
|
thinkingSuffix = modelName[len(normalizedModel):]
|
||||||
|
}
|
||||||
|
|
||||||
|
resolveMappedModel := func() (string, []string) {
|
||||||
|
if fh.modelMapper == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||||
|
if mappedModel == "" {
|
||||||
|
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||||
|
}
|
||||||
|
mappedModel = strings.TrimSpace(mappedModel)
|
||||||
|
if mappedModel == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||||
|
// already specifies its own thinking suffix.
|
||||||
|
if thinkingSuffix != "" {
|
||||||
|
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
|
||||||
|
if mappedThinkingMetadata == nil {
|
||||||
|
mappedModel += thinkingSuffix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
|
||||||
|
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||||
|
if len(mappedProviders) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return mappedModel, mappedProviders
|
||||||
|
}
|
||||||
|
|
||||||
// Track resolved model for logging (may change if mapping is applied)
|
// Track resolved model for logging (may change if mapping is applied)
|
||||||
resolvedModel := normalizedModel
|
resolvedModel := normalizedModel
|
||||||
@@ -147,21 +183,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
if forceMappings {
|
if forceMappings {
|
||||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||||
// This allows users to route Amp requests to their preferred OAuth providers
|
// This allows users to route Amp requests to their preferred OAuth providers
|
||||||
if fh.modelMapper != nil {
|
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
// Mapping found and provider available - rewrite the model in request body
|
||||||
// Mapping found - check if we have a provider for the mapped model
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||||
mappedProviders := util.GetProviderName(mappedModel)
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
if len(mappedProviders) > 0 {
|
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||||
// Mapping found and provider available - rewrite the model in request body
|
c.Set(MappedModelContextKey, mappedModel)
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
resolvedModel = mappedModel
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
usedMapping = true
|
||||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
providers = mappedProviders
|
||||||
c.Set(MappedModelContextKey, mappedModel)
|
|
||||||
resolvedModel = mappedModel
|
|
||||||
usedMapping = true
|
|
||||||
providers = mappedProviders
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no mapping applied, check for local providers
|
// If no mapping applied, check for local providers
|
||||||
@@ -174,21 +204,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
|
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
// No providers configured - check if we have a model mapping
|
// No providers configured - check if we have a model mapping
|
||||||
if fh.modelMapper != nil {
|
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
// Mapping found and provider available - rewrite the model in request body
|
||||||
// Mapping found - check if we have a provider for the mapped model
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||||
mappedProviders := util.GetProviderName(mappedModel)
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
if len(mappedProviders) > 0 {
|
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||||
// Mapping found and provider available - rewrite the model in request body
|
c.Set(MappedModelContextKey, mappedModel)
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
resolvedModel = mappedModel
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
usedMapping = true
|
||||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
providers = mappedProviders
|
||||||
c.Set(MappedModelContextKey, mappedModel)
|
|
||||||
resolvedModel = mappedModel
|
|
||||||
usedMapping = true
|
|
||||||
providers = mappedProviders
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
// Log: Model was mapped to another model
|
// Log: Model was mapped to another model
|
||||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||||
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
c.Writer = rewriter
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
handler(c)
|
handler(c)
|
||||||
rewriter.Flush()
|
rewriter.Flush()
|
||||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel)
|
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
||||||
} else if len(providers) > 0 {
|
} else if len(providers) > 0 {
|
||||||
// Log: Using local provider (free)
|
// Log: Using local provider (free)
|
||||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
|||||||
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/http/httputil"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
||||||
|
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("test-client-amp-fallback")
|
||||||
|
|
||||||
|
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||||
|
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
||||||
|
})
|
||||||
|
|
||||||
|
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
||||||
|
|
||||||
|
handler := func(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"model": req.Model,
|
||||||
|
"seen_model": req.Model,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
||||||
|
|
||||||
|
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
SeenModel string `json:"seen_model"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("Failed to parse response JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Model != "gpt-5.2(xhigh)" {
|
||||||
|
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
||||||
|
}
|
||||||
|
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
||||||
|
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -59,7 +59,8 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify target model has available providers
|
// Verify target model has available providers
|
||||||
providers := util.GetProviderName(targetModel)
|
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
|
||||||
|
providers := util.GetProviderName(normalizedTarget)
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
|
||||||
|
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("test-client-thinking")
|
||||||
|
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
result := mapper.MapModel("gpt-5.2-alias")
|
||||||
|
if result != "gpt-5.2(xhigh)" {
|
||||||
|
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
||||||
reg := registry.GetGlobalRegistry()
|
reg := registry.GetGlobalRegistry()
|
||||||
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
||||||
|
|||||||
@@ -136,6 +136,12 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool)
|
|||||||
updated = rewritten
|
updated = rewritten
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
|
||||||
|
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
}
|
||||||
|
if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
|
||||||
|
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget")
|
||||||
|
}
|
||||||
return updated
|
return updated
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,6 +173,12 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo
|
|||||||
updated = rewritten
|
updated = rewritten
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
|
||||||
|
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
}
|
||||||
|
if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
|
||||||
|
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||||
|
}
|
||||||
return updated
|
return updated
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user