fix(thinking): fall back to upstream model for thinking support when alias not in registry

This commit is contained in:
hkfires
2025-12-31 18:07:13 +08:00
parent d00e3ea973
commit 8bf3305b2b
5 changed files with 89 additions and 38 deletions

View File

@@ -96,7 +96,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude) translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
@@ -191,7 +191,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, true) translated = normalizeAntigravityThinking(req.Model, translated, true)
@@ -527,7 +527,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude) translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
@@ -697,7 +697,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, payload) payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, payload)
payload = normalizeAntigravityThinking(req.Model, payload, isClaude) payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "project")

View File

@@ -78,7 +78,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli") to := sdktranslator.FromString("gemini-cli")
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload) basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload) basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
@@ -217,7 +217,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli") to := sdktranslator.FromString("gemini-cli")
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload) basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload) basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
@@ -421,7 +421,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
// Gemini CLI endpoint when iterating fallback variants. // Gemini CLI endpoint when iterating fallback variants.
for _, attemptModel := range models { for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload) payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "model")

View File

@@ -17,35 +17,51 @@ func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string
// Use the alias from metadata if available, as it's registered in the global registry // Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered. // with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata) lookupModel := util.ResolveOriginalModel(model, metadata)
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(lookupModel, metadata)
// Determine which model to use for the thinking support check.
// If the alias (lookupModel) does not report thinking support (e.g. it is not in the registry)
// but the upstream model does, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) { if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload return payload
} }
if !util.ModelSupportsThinking(lookupModel) { if !util.ModelSupportsThinking(thinkingModel) {
return payload return payload
} }
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(lookupModel, *budgetOverride) norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride) return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
} }
// applyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192)) // ApplyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking. // for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte { func ApplyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
// Use the alias from metadata if available, as it's registered in the global registry // Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered. // with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata) lookupModel := util.ResolveOriginalModel(model, metadata)
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(lookupModel, metadata)
// Determine which model to use for the thinking support check.
// If the alias (lookupModel) does not report thinking support (e.g. it is not in the registry)
// but the upstream model does, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) { if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload return payload
} }
if !util.ModelSupportsThinking(lookupModel) { if !util.ModelSupportsThinking(thinkingModel) {
return payload return payload
} }
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(lookupModel, *budgetOverride) norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride) return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)

View File

@@ -12,9 +12,18 @@ func ModelSupportsThinking(model string) bool {
if model == "" { if model == "" {
return false return false
} }
// First check the global dynamic registry
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil { if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil {
return info.Thinking != nil return info.Thinking != nil
} }
// Fallback: check static model definitions
if info := registry.LookupStaticModelInfo(model); info != nil {
return info.Thinking != nil
}
// Fallback: check Antigravity static config
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil {
return cfg.Thinking != nil
}
return false return false
} }
@@ -63,11 +72,19 @@ func thinkingRangeFromRegistry(model string) (found bool, min int, max int, zero
if model == "" { if model == "" {
return false, 0, 0, false, false return false, 0, 0, false, false
} }
info := registry.GetGlobalRegistry().GetModelInfo(model) // First check global dynamic registry
if info == nil || info.Thinking == nil { if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil && info.Thinking != nil {
return false, 0, 0, false, false return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
} }
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed // Fallback: check static model definitions
if info := registry.LookupStaticModelInfo(model); info != nil && info.Thinking != nil {
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
}
// Fallback: check Antigravity static config
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil && cfg.Thinking != nil {
return true, cfg.Thinking.Min, cfg.Thinking.Max, cfg.Thinking.ZeroAllowed, cfg.Thinking.DynamicAllowed
}
return false, 0, 0, false, false
} }
// GetModelThinkingLevels returns the discrete reasoning effort levels for the model. // GetModelThinkingLevels returns the discrete reasoning effort levels for the model.

View File

@@ -3,6 +3,7 @@ package test
import ( import (
"testing" "testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -80,8 +81,9 @@ func TestModelAliasThinkingSuffix(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Step 1: Parse model suffix // Step 1: Parse model suffix (simulates SDK layer normalization)
baseModel, metadata := util.NormalizeThinkingModel(tt.requestModel) // For "gp(1000)" -> requestedModel="gp", metadata={thinking_budget: 1000}
requestedModel, metadata := util.NormalizeThinkingModel(tt.requestModel)
// Verify suffix was parsed // Verify suffix was parsed
if metadata == nil && (tt.suffixType == "numeric" || tt.suffixType == "level") { if metadata == nil && (tt.suffixType == "numeric" || tt.suffixType == "level") {
@@ -89,12 +91,13 @@ func TestModelAliasThinkingSuffix(t *testing.T) {
return return
} }
// Step 2: For aliases, simulate the model mapping by adding upstream model info // Step 2: Simulate OAuth model mapping
// Real flow: applyOAuthModelMapping stores requestedModel (the alias) in metadata
if tt.isAlias { if tt.isAlias {
if metadata == nil { if metadata == nil {
metadata = make(map[string]any) metadata = make(map[string]any)
} }
metadata[util.ModelMappingOriginalModelMetadataKey] = baseModel metadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
} }
// Step 3: Verify metadata extraction // Step 3: Verify metadata extraction
@@ -151,12 +154,15 @@ func TestModelAliasThinkingSuffix(t *testing.T) {
if tt.expectedField == "thinkingLevel" && util.IsGemini3Model(tt.upstreamModel) { if tt.expectedField == "thinkingLevel" && util.IsGemini3Model(tt.upstreamModel) {
body := []byte(`{"request":{"contents":[]}}`) body := []byte(`{"request":{"contents":[]}}`)
// Build metadata for the function // Build metadata simulating real OAuth flow:
// - requestedModel (alias like "gf") is stored in model_mapping_original_model
// - upstreamModel is passed as the model parameter
testMetadata := make(map[string]any) testMetadata := make(map[string]any)
if tt.isAlias { if tt.isAlias {
testMetadata[util.ModelMappingOriginalModelMetadataKey] = tt.upstreamModel // Real flow: applyOAuthModelMapping stores requestedModel (the alias)
testMetadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
} }
// Copy parsed metadata // Copy parsed metadata (thinking_budget, reasoning_effort, etc.)
for k, v := range metadata { for k, v := range metadata {
testMetadata[k] = v testMetadata[k] = v
} }
@@ -172,20 +178,32 @@ func TestModelAliasThinkingSuffix(t *testing.T) {
} }
} }
// Step 5: Test Gemini 2.5 thinkingBudget application // Step 5: Test Gemini 2.5 thinkingBudget application using real ApplyThinkingMetadataCLI flow
if tt.expectedField == "thinkingBudget" && util.IsGemini25Model(tt.upstreamModel) { if tt.expectedField == "thinkingBudget" && util.IsGemini25Model(tt.upstreamModel) {
budget, _, _, _ := util.ThinkingFromMetadata(metadata) body := []byte(`{"request":{"contents":[]}}`)
if budget != nil {
body := []byte(`{"request":{"contents":[]}}`)
result := util.ApplyGeminiCLIThinkingConfig(body, budget, nil)
budgetVal := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget")
expectedBudget := tt.expectedValue.(int) // Build metadata simulating real OAuth flow:
if !budgetVal.Exists() { // - requestedModel (alias like "gp") is stored in model_mapping_original_model
t.Errorf("Case #%d: expected thinkingBudget in result", tt.id) // - upstreamModel is passed as the model parameter
} else if int(budgetVal.Int()) != expectedBudget { testMetadata := make(map[string]any)
t.Errorf("Case #%d: thinkingBudget = %d, want %d", tt.id, int(budgetVal.Int()), expectedBudget) if tt.isAlias {
} // Real flow: applyOAuthModelMapping stores requestedModel (the alias)
testMetadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
}
// Copy parsed metadata (thinking_budget, reasoning_effort, etc.)
for k, v := range metadata {
testMetadata[k] = v
}
// Use the exported ApplyThinkingMetadataCLI which includes the fallback logic
result := executor.ApplyThinkingMetadataCLI(body, testMetadata, tt.upstreamModel)
budgetVal := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget")
expectedBudget := tt.expectedValue.(int)
if !budgetVal.Exists() {
t.Errorf("Case #%d: expected thinkingBudget in result", tt.id)
} else if int(budgetVal.Int()) != expectedBudget {
t.Errorf("Case #%d: thinkingBudget = %d, want %d", tt.id, int(budgetVal.Int()), expectedBudget)
} }
} }
}) })