diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 394a295e..3cd8cf8e 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -55,17 +55,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) - if upstreamModel == "" { - upstreamModel = strings.TrimSpace(req.Model) - } - - translatedReq, body, err := e.translateRequest(req, opts, false, upstreamModel) + translatedReq, body, err := e.translateRequest(req, opts, false, req.Model) if err != nil { return resp, err } - endpoint := e.buildEndpoint(upstreamModel, body.action, opts.Alt) + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, URL: endpoint, @@ -115,17 +110,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) - if upstreamModel == "" { - upstreamModel = strings.TrimSpace(req.Model) - } - - translatedReq, body, err := e.translateRequest(req, opts, true, upstreamModel) + translatedReq, body, err := e.translateRequest(req, opts, true, req.Model) if err != nil { return nil, err } - endpoint := e.buildEndpoint(upstreamModel, body.action, opts.Alt) + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, URL: endpoint, @@ -266,12 +256,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth // CountTokens counts tokens for the given request using the AI Studio API. 
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) - if upstreamModel == "" { - upstreamModel = strings.TrimSpace(req.Model) - } - - _, body, err := e.translateRequest(req, opts, false, upstreamModel) + _, body, err := e.translateRequest(req, opts, false, req.Model) if err != nil { return cliproxyexecutor.Response{}, err } @@ -280,7 +265,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A body.payload, _ = sjson.DeleteBytes(body.payload, "tools") body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings") - endpoint := e.buildEndpoint(upstreamModel, "countTokens", "") + endpoint := e.buildEndpoint(req.Model, "countTokens", "") wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, URL: endpoint, diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index c2aa4706..950141f0 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -76,11 +76,7 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au // Execute performs a non-streaming request to the Antigravity API. 
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") if isClaude { return e.executeClaudeNonStream(ctx, auth, req, opts) } @@ -98,13 +94,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au from := opts.SourceFormat to := sdktranslator.FromString("antigravity") - translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) + translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel) - translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated) - translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated) - translated = normalizeAntigravityThinking(upstreamModel, translated, isClaude) - translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated) + translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) + translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, isClaude) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -114,7 +110,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au var 
lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -191,20 +187,15 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("antigravity") - translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) + translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel) - translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated) - translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated) - translated = normalizeAntigravityThinking(upstreamModel, translated, true) - translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated) + translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) + translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, true) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -214,7 +205,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * 
var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -530,21 +521,17 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") from := opts.SourceFormat to := sdktranslator.FromString("antigravity") - translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) + translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel) - translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated) - translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated) - translated = normalizeAntigravityThinking(upstreamModel, translated, isClaude) - translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated) + translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) + translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) + translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) + translated = normalizeAntigravityThinking(req.Model, translated, isClaude) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient 
:= newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -554,7 +541,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return nil, err @@ -692,11 +679,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut to := sdktranslator.FromString("antigravity") respCtx := context.WithValue(ctx, "alt", opts.Alt) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -713,10 +696,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut var lastErr error for idx, baseURL := range baseURLs { - payload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - payload = applyThinkingMetadataCLI(payload, req.Metadata, upstreamModel) - payload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, payload) - payload = normalizeAntigravityThinking(upstreamModel, payload, isClaude) + payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) + payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload) + payload = normalizeAntigravityThinking(req.Model, payload, isClaude) payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "request.safetySettings") diff --git 
a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 52c60163..f74dc1e0 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -49,36 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. 
stream := from != to - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), stream) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream) + body, _ = sjson.SetBytes(body, "model", model) // Inject thinking config based on model metadata for thinking variants - body = e.injectThinkingConfig(upstreamModel, req.Metadata, body) + body = e.injectThinkingConfig(model, req.Metadata, body) - if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { + if !strings.HasPrefix(model, "claude-3-5-haiku") { body = checkSystemInstructions(body) } - body = applyPayloadConfig(e.cfg, upstreamModel, body) + body = applyPayloadConfig(e.cfg, model, body) // Disable thinking if tool_choice forces tool use (Anthropic API constraint) body = disableThinkingIfToolChoiceForced(body) // Ensure max_tokens > thinking.budget_tokens when thinking is enabled - body = ensureMaxTokensForThinking(upstreamModel, body) + body = ensureMaxTokensForThinking(model, body) // Extract betas from body and convert to header var extraBetas []string @@ -170,29 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("claude") - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } - } - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), 
true) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + body, _ = sjson.SetBytes(body, "model", model) // Inject thinking config based on model metadata for thinking variants - body = e.injectThinkingConfig(upstreamModel, req.Metadata, body) + body = e.injectThinkingConfig(model, req.Metadata, body) body = checkSystemInstructions(body) - body = applyPayloadConfig(e.cfg, upstreamModel, body) + body = applyPayloadConfig(e.cfg, model, body) // Disable thinking if tool_choice forces tool use (Anthropic API constraint) body = disableThinkingIfToolChoiceForced(body) // Ensure max_tokens > thinking.budget_tokens when thinking is enabled - body = ensureMaxTokensForThinking(upstreamModel, body) + body = ensureMaxTokensForThinking(model, body) // Extract betas from body and convert to header var extraBetas []string @@ -316,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. 
stream := from != to - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } - } - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), stream) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream) + body, _ = sjson.SetBytes(body, "model", model) - if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { + if !strings.HasPrefix(model, "claude-3-5-haiku") { body = checkSystemInstructions(body) } diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 71e36435..98678c4d 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -49,28 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = 
override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false) - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body = NormalizeThinkingConfig(body, model, false) + if errValidate := ValidateThinkingConfig(body, model); errValidate != nil { return resp, errValidate } - body = applyPayloadConfig(e.cfg, upstreamModel, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.SetBytes(body, "stream", true) body, _ = sjson.DeleteBytes(body, "previous_response_id") @@ -156,30 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) 
+ body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) - body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false) - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body = NormalizeThinkingConfig(body, model, false) + if errValidate := ValidateThinkingConfig(body, model); errValidate != nil { return nil, errValidate } - body = applyPayloadConfig(e.cfg, upstreamModel, body) + body = applyPayloadConfig(e.cfg, model, body) body, _ = sjson.DeleteBytes(body, "previous_response_id") - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body, _ = sjson.SetBytes(body, "model", model) url := strings.TrimSuffix(baseURL, "/") + "/responses" httpReq, err := e.cacheHelper(ctx, from, url, req, body) @@ -266,30 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) + body := 
sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) - modelForCounting := upstreamModel - - body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.SetBytes(body, "stream", false) - enc, err := tokenizerForCodexModel(modelForCounting) + enc, err := tokenizerForCodexModel(model) if err != nil { return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err) } diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 0be3bc76..a3b75839 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -75,21 +75,16 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) - if upstreamModel == "" { - upstreamModel = strings.TrimSpace(req.Model) - } - from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") - basePayload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, upstreamModel) - basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, basePayload) - basePayload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, basePayload) - basePayload = util.NormalizeGeminiCLIThinkingBudget(upstreamModel, basePayload) - basePayload = util.StripThinkingConfigIfUnsupported(upstreamModel, basePayload) - basePayload = 
fixGeminiCLIImageAspectRatio(upstreamModel, basePayload) - basePayload = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "gemini", "request", basePayload) + basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) + basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload) + basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload) + basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) + basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload) + basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) + basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload) action := "generateContent" if req.Metadata != nil { @@ -99,9 +94,9 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth } projectID := resolveGeminiProjectID(auth) - models := cliPreviewFallbackOrder(upstreamModel) - if len(models) == 0 || models[0] != upstreamModel { - models = append([]string{upstreamModel}, models...) + models := cliPreviewFallbackOrder(req.Model) + if len(models) == 0 || models[0] != req.Model { + models = append([]string{req.Model}, models...) } httpClient := newHTTPClient(ctx, e.cfg, auth, 0) @@ -115,10 +110,6 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth var lastStatus int var lastBody []byte - // NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be - // based on upstreamModel (resolved via oauth-model-mappings). The loop variable attemptModel - // is only used as the concrete model id sent to the upstream Gemini CLI endpoint (and the - // model label passed into response translation) when iterating fallback variants. for idx, attemptModel := range models { payload := append([]byte(nil), basePayload...) 
if action == "countTokens" { @@ -223,27 +214,22 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) - if upstreamModel == "" { - upstreamModel = strings.TrimSpace(req.Model) - } - from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") - basePayload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) - basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, upstreamModel) - basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, basePayload) - basePayload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, basePayload) - basePayload = util.NormalizeGeminiCLIThinkingBudget(upstreamModel, basePayload) - basePayload = util.StripThinkingConfigIfUnsupported(upstreamModel, basePayload) - basePayload = fixGeminiCLIImageAspectRatio(upstreamModel, basePayload) - basePayload = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "gemini", "request", basePayload) + basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model) + basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload) + basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload) + basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload) + basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload) + basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) + basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload) projectID := resolveGeminiProjectID(auth) - models := cliPreviewFallbackOrder(upstreamModel) - if len(models) == 0 || models[0] != 
upstreamModel { - models = append([]string{upstreamModel}, models...) + models := cliPreviewFallbackOrder(req.Model) + if len(models) == 0 || models[0] != req.Model { + models = append([]string{req.Model}, models...) } httpClient := newHTTPClient(ctx, e.cfg, auth, 0) @@ -257,10 +243,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut var lastStatus int var lastBody []byte - // NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be - // based on upstreamModel (resolved via oauth-model-mappings). The loop variable attemptModel - // is only used as the concrete model id sent to the upstream Gemini CLI endpoint (and the - // model label passed into response translation) when iterating fallback variants. for idx, attemptModel := range models { payload := append([]byte(nil), basePayload...) payload = setJSONField(payload, "project", projectID) @@ -417,14 +399,9 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") - upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) - if upstreamModel == "" { - upstreamModel = strings.TrimSpace(req.Model) - } - - models := cliPreviewFallbackOrder(upstreamModel) - if len(models) == 0 || models[0] != upstreamModel { - models = append([]string{upstreamModel}, models...) + models := cliPreviewFallbackOrder(req.Model) + if len(models) == 0 || models[0] != req.Model { + models = append([]string{req.Model}, models...) } httpClient := newHTTPClient(ctx, e.cfg, auth, 0) @@ -440,19 +417,17 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. var lastStatus int var lastBody []byte - // NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be - // based on upstreamModel (resolved via oauth-model-mappings). 
The loop variable attemptModel - // is only used as the concrete model id sent to the upstream Gemini CLI endpoint when iterating - // fallback variants. + // The loop variable attemptModel is only used as the concrete model id sent to the upstream + // Gemini CLI endpoint when iterating fallback variants. for _, attemptModel := range models { payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) - payload = applyThinkingMetadataCLI(payload, req.Metadata, upstreamModel) - payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, payload) + payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) + payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload) payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "request.safetySettings") - payload = util.StripThinkingConfigIfUnsupported(upstreamModel, payload) - payload = fixGeminiCLIImageAspectRatio(upstreamModel, payload) + payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) + payload = fixGeminiCLIImageAspectRatio(req.Model, payload) tok, errTok := tokenSource.Token() if errTok != nil { diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index da57150d..d69044b8 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -77,26 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := 
e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override } // Official Gemini API via API key or OAuth bearer from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - body = ApplyThinkingMetadata(body, req.Metadata, req.Model) - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + body = ApplyThinkingMetadata(body, req.Metadata, model) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) action := "generateContent" if req.Metadata != nil { @@ -105,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } } baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action) + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -180,28 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := 
util.ResolveOriginalModel(req.Model, req.Metadata) - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - body = ApplyThinkingMetadata(body, req.Metadata, req.Model) - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + body = ApplyThinkingMetadata(body, req.Metadata, model) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -301,29 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req 
cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { apiKey, bearer := geminiCreds(auth) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model) - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model) + translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(model, translatedReq) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", model) baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "countTokens") + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens") requestBody := 
bytes.NewReader(translatedReq) diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 03470bec..f8f4a63a 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -120,27 +120,22 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) budgetOverride = &norm } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body) - body = util.NormalizeGeminiThinkingBudget(upstreamModel, body) - body = util.StripThinkingConfigIfUnsupported(upstreamModel, body) - body = fixGeminiImageAspectRatio(upstreamModel, body) - body = applyPayloadConfig(e.cfg, upstreamModel, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) + body = 
util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + body, _ = sjson.SetBytes(body, "model", req.Model) action := "generateContent" if req.Metadata != nil { @@ -149,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } } baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -223,27 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) 
budgetOverride = &norm } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body) - body = util.NormalizeGeminiThinkingBudget(upstreamModel, body) - body = util.StripThinkingConfigIfUnsupported(upstreamModel, body) - body = fixGeminiImageAspectRatio(upstreamModel, body) - body = applyPayloadConfig(e.cfg, upstreamModel, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) action := "generateContent" if req.Metadata != nil { @@ -256,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action) + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -327,30 +322,25 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { + 
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) budgetOverride = &norm } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body) - body = util.NormalizeGeminiThinkingBudget(upstreamModel, body) - body = util.StripThinkingConfigIfUnsupported(upstreamModel, body) - body = fixGeminiImageAspectRatio(upstreamModel, body) - body = applyPayloadConfig(e.cfg, upstreamModel, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) + body = util.NormalizeGeminiThinkingBudget(req.Model, body) + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + body, _ = sjson.SetBytes(body, "model", req.Model) baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -447,33 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model + model := 
req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) budgetOverride = &norm } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body) - body = util.NormalizeGeminiThinkingBudget(upstreamModel, body) - body = util.StripThinkingConfigIfUnsupported(upstreamModel, body) - body = fixGeminiImageAspectRatio(upstreamModel, body) - body = applyPayloadConfig(e.cfg, upstreamModel, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) // For API key auth, use simpler URL format without project/location if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, 
"streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -564,31 +554,26 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth // countTokensWithServiceAccount counts tokens using service account credentials. func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) budgetOverride = &norm } translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) } - translatedReq = util.StripThinkingConfigIfUnsupported(upstreamModel, translatedReq) - translatedReq = fixGeminiImageAspectRatio(upstreamModel, translatedReq) - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model) respCtx := 
context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens") + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) if errNewReq != nil { @@ -656,24 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context // countTokensWithAPIKey handles token counting using API key credentials. func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { + translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := 
util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) budgetOverride = &norm } translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) } - translatedReq = util.StripThinkingConfigIfUnsupported(upstreamModel, translatedReq) - translatedReq = fixGeminiImageAspectRatio(upstreamModel, translatedReq) - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(model, translatedReq) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", model) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") @@ -683,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "countTokens") + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens") httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) if errNewReq != nil { @@ -826,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau } return tok.AccessToken, nil } + +// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration. +// It matches the requested model alias against configured models and returns the actual upstream name. 
+func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string { + trimmed := strings.TrimSpace(alias) + if trimmed == "" { + return "" + } + + entry := e.resolveVertexConfig(auth) + if entry == nil { + return "" + } + + normalizedModel, metadata := util.NormalizeThinkingModel(trimmed) + + // Candidate names to match against configured aliases/names. + candidates := []string{strings.TrimSpace(normalizedModel)} + if !strings.EqualFold(normalizedModel, trimmed) { + candidates = append(candidates, trimmed) + } + if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) { + candidates = append(candidates, original) + } + + for i := range entry.Models { + model := entry.Models[i] + name := strings.TrimSpace(model.Name) + modelAlias := strings.TrimSpace(model.Alias) + + for _, candidate := range candidates { + if candidate == "" { + continue + } + if modelAlias != "" && strings.EqualFold(modelAlias, candidate) { + if name != "" { + return name + } + return candidate + } + if name != "" && strings.EqualFold(name, candidate) { + return name + } + } + } + return "" +} + +// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth. 
+func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey { + if auth == nil || e.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range e.cfg.VertexCompatAPIKey { + entry := &e.cfg.VertexCompatAPIKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && attrBase != "" { + if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range e.cfg.VertexCompatAPIKey { + entry := &e.cfg.VertexCompatAPIKey[i] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) { + return entry + } + } + } + return nil +} diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 9ac1c9f3..49fd4eb7 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -54,25 +54,18 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if strings.TrimSpace(upstreamModel) == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false) - if 
upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return resp, errValidate } body = applyIFlowThinkingConfig(body) body = preserveReasoningContentInMessages(body) - body = applyPayloadConfig(e.cfg, upstreamModel, body) + body = applyPayloadConfig(e.cfg, req.Model, body) endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint @@ -150,21 +143,14 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if strings.TrimSpace(upstreamModel) == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) + body, _ = sjson.SetBytes(body, "model", 
req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return nil, errValidate } body = applyIFlowThinkingConfig(body) @@ -174,7 +160,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 { body = ensureToolsArray(body) } - body = applyPayloadConfig(e.cfg, upstreamModel, body) + body = applyPayloadConfig(e.cfg, req.Model, body) endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint @@ -257,16 +243,11 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if strings.TrimSpace(upstreamModel) == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - enc, err := tokenizerForModel(upstreamModel) + enc, err := tokenizerForModel(req.Model) if err != nil { return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err) } diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 1c57c9b7..81fc31a1 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) allowCompat := 
e.allowCompatReasoningEffort(req.Model, auth) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" && modelOverride == "" { - translated, _ = sjson.SetBytes(translated, "model", upstreamModel) - } - translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat) - if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil { + translated = NormalizeThinkingConfig(translated, req.Model, allowCompat) + if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil { return resp, errValidate } @@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) allowCompat := e.allowCompatReasoningEffort(req.Model, auth) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" && modelOverride == "" { - translated, _ = sjson.SetBytes(translated, "model", upstreamModel) - } - translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat) - if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil { + translated = NormalizeThinkingConfig(translated, req.Model, allowCompat) + if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil { return nil, errValidate } diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index cf4aa6e3..ff6fa414 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -12,7 +12,6 @@ import ( qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" 
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -48,23 +47,16 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if strings.TrimSpace(upstreamModel) == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) - body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return resp, errValidate } - body = applyPayloadConfig(e.cfg, upstreamModel, body) + body = applyPayloadConfig(e.cfg, req.Model, body) url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) @@ -131,21 +123,14 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut reporter := 
newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if strings.TrimSpace(upstreamModel) == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return nil, errValidate } toolsResult := gjson.GetBytes(body, "tools") @@ -155,7 +140,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`)) } body, _ = sjson.SetBytes(body, "stream_options.include_usage", true) - body = applyPayloadConfig(e.cfg, upstreamModel, body) + body = applyPayloadConfig(e.cfg, req.Model, body) url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, 
bytes.NewReader(body)) @@ -235,18 +220,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut } func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if strings.TrimSpace(upstreamModel) == "" { - upstreamModel = req.Model - } - from := opts.SourceFormat to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) modelName := gjson.GetBytes(body, "model").String() if strings.TrimSpace(modelName) == "" { - modelName = upstreamModel + modelName = req.Model } enc, err := tokenizerForModel(modelName) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index a6eaf3c5..c480d965 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -413,7 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.Execute(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -475,7 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = 
m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -537,7 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { rerr := &Error{Message: errStream.Error()} diff --git a/sdk/cliproxy/auth/model_name_mappings.go b/sdk/cliproxy/auth/model_name_mappings.go index 483cb9c9..f1b31aa5 100644 --- a/sdk/cliproxy/auth/model_name_mappings.go +++ b/sdk/cliproxy/auth/model_name_mappings.go @@ -65,17 +65,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod m.modelNameMappings.Store(table) } -func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any { - original := m.resolveOAuthUpstreamModel(auth, requestedModel) - if original == "" { - return metadata - } - if metadata != nil { - if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok { - if s, okStr := v.(string); okStr && strings.EqualFold(s, original) { - return metadata - } - } +// applyOAuthModelMapping resolves the upstream model from OAuth model mappings +// and returns the resolved model along with updated metadata. If a mapping exists, +// the returned model is the upstream model and the returned metadata is a copy with +// util.ModelMappingOriginalModelMetadataKey set to that upstream model. 
+func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) { + upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel) + if upstreamModel == "" { + return requestedModel, metadata } out := make(map[string]any, 1) if len(metadata) > 0 { @@ -84,8 +81,8 @@ func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel stri out[k] = v } } - out[util.ModelMappingOriginalModelMetadataKey] = original - return out + out[util.ModelMappingOriginalModelMetadataKey] = upstreamModel + return upstreamModel, out } func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {