refactor(executor): resolve upstream model at conductor level before execution

This commit is contained in:
hkfires
2025-12-30 19:31:54 +08:00
parent b055e00c1a
commit 96340bf136
12 changed files with 341 additions and 432 deletions

View File

@@ -55,17 +55,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) translatedReq, body, err := e.translateRequest(req, opts, false, req.Model)
if upstreamModel == "" {
upstreamModel = strings.TrimSpace(req.Model)
}
translatedReq, body, err := e.translateRequest(req, opts, false, upstreamModel)
if err != nil { if err != nil {
return resp, err return resp, err
} }
endpoint := e.buildEndpoint(upstreamModel, body.action, opts.Alt) endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{ wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost, Method: http.MethodPost,
URL: endpoint, URL: endpoint,
@@ -115,17 +110,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) translatedReq, body, err := e.translateRequest(req, opts, true, req.Model)
if upstreamModel == "" {
upstreamModel = strings.TrimSpace(req.Model)
}
translatedReq, body, err := e.translateRequest(req, opts, true, upstreamModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
endpoint := e.buildEndpoint(upstreamModel, body.action, opts.Alt) endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{ wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost, Method: http.MethodPost,
URL: endpoint, URL: endpoint,
@@ -266,12 +256,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
// CountTokens counts tokens for the given request using the AI Studio API. // CountTokens counts tokens for the given request using the AI Studio API.
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) _, body, err := e.translateRequest(req, opts, false, req.Model)
if upstreamModel == "" {
upstreamModel = strings.TrimSpace(req.Model)
}
_, body, err := e.translateRequest(req, opts, false, upstreamModel)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, err return cliproxyexecutor.Response{}, err
} }
@@ -280,7 +265,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
body.payload, _ = sjson.DeleteBytes(body.payload, "tools") body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings") body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings")
endpoint := e.buildEndpoint(upstreamModel, "countTokens", "") endpoint := e.buildEndpoint(req.Model, "countTokens", "")
wsReq := &wsrelay.HTTPRequest{ wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost, Method: http.MethodPost,
URL: endpoint, URL: endpoint,

View File

@@ -76,11 +76,7 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
// Execute performs a non-streaming request to the Antigravity API. // Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
if isClaude { if isClaude {
return e.executeClaudeNonStream(ctx, auth, req, opts) return e.executeClaudeNonStream(ctx, auth, req, opts)
} }
@@ -98,13 +94,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(upstreamModel, translated, isClaude) translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -114,7 +110,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return resp, err return resp, err
@@ -191,20 +187,15 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(upstreamModel, translated, true) translated = normalizeAntigravityThinking(req.Model, translated, true)
translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -214,7 +205,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return resp, err return resp, err
@@ -530,21 +521,17 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(upstreamModel, translated, isClaude) translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -554,7 +541,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return nil, err return nil, err
@@ -692,11 +679,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -713,10 +696,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
payload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, upstreamModel) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, payload) payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
payload = normalizeAntigravityThinking(upstreamModel, payload, isClaude) payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings") payload = deleteJSONField(payload, "request.safetySettings")

View File

@@ -49,36 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} }
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude. // Use streaming translation to preserve function calling, except for claude.
stream := from != to stream := from != to
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), stream) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants // Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(upstreamModel, req.Metadata, body) body = e.injectThinkingConfig(model, req.Metadata, body)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
} }
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, model, body)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint) // Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body) body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled // Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(upstreamModel, body) body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header // Extract betas from body and convert to header
var extraBetas []string var extraBetas []string
@@ -170,29 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
} }
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
upstreamModel = modelOverride body, _ = sjson.SetBytes(body, "model", model)
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
// Inject thinking config based on model metadata for thinking variants // Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(upstreamModel, req.Metadata, body) body = e.injectThinkingConfig(model, req.Metadata, body)
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, model, body)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint) // Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body) body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled // Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(upstreamModel, body) body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header // Extract betas from body and convert to header
var extraBetas []string var extraBetas []string
@@ -316,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude. // Use streaming translation to preserve function calling, except for claude.
stream := from != to stream := from != to
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
} }
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
upstreamModel = modelOverride body, _ = sjson.SetBytes(body, "model", model)
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
} }

View File

@@ -49,28 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false) body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.SetBytes(body, "stream", true) body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
@@ -156,30 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false) body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
url := strings.TrimSuffix(baseURL, "/") + "/responses" url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body) httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -266,30 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
} }
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
modelForCounting := upstreamModel body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", model)
body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "stream", false) body, _ = sjson.SetBytes(body, "stream", false)
enc, err := tokenizerForCodexModel(modelForCounting) enc, err := tokenizerForCodexModel(model)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
} }

View File

@@ -75,21 +75,16 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
if upstreamModel == "" {
upstreamModel = strings.TrimSpace(req.Model)
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli") to := sdktranslator.FromString("gemini-cli")
basePayload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, upstreamModel) basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, basePayload) basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, basePayload) basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(upstreamModel, basePayload) basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(upstreamModel, basePayload) basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(upstreamModel, basePayload) basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "gemini", "request", basePayload) basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -99,9 +94,9 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
} }
projectID := resolveGeminiProjectID(auth) projectID := resolveGeminiProjectID(auth)
models := cliPreviewFallbackOrder(upstreamModel) models := cliPreviewFallbackOrder(req.Model)
if len(models) == 0 || models[0] != upstreamModel { if len(models) == 0 || models[0] != req.Model {
models = append([]string{upstreamModel}, models...) models = append([]string{req.Model}, models...)
} }
httpClient := newHTTPClient(ctx, e.cfg, auth, 0) httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -115,10 +110,6 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
var lastStatus int var lastStatus int
var lastBody []byte var lastBody []byte
// NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be
// based on upstreamModel (resolved via oauth-model-mappings). The loop variable attemptModel
// is only used as the concrete model id sent to the upstream Gemini CLI endpoint (and the
// model label passed into response translation) when iterating fallback variants.
for idx, attemptModel := range models { for idx, attemptModel := range models {
payload := append([]byte(nil), basePayload...) payload := append([]byte(nil), basePayload...)
if action == "countTokens" { if action == "countTokens" {
@@ -223,27 +214,22 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
if upstreamModel == "" {
upstreamModel = strings.TrimSpace(req.Model)
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli") to := sdktranslator.FromString("gemini-cli")
basePayload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, upstreamModel) basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, basePayload) basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, basePayload) basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(upstreamModel, basePayload) basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(upstreamModel, basePayload) basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(upstreamModel, basePayload) basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "gemini", "request", basePayload) basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
projectID := resolveGeminiProjectID(auth) projectID := resolveGeminiProjectID(auth)
models := cliPreviewFallbackOrder(upstreamModel) models := cliPreviewFallbackOrder(req.Model)
if len(models) == 0 || models[0] != upstreamModel { if len(models) == 0 || models[0] != req.Model {
models = append([]string{upstreamModel}, models...) models = append([]string{req.Model}, models...)
} }
httpClient := newHTTPClient(ctx, e.cfg, auth, 0) httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -257,10 +243,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var lastStatus int var lastStatus int
var lastBody []byte var lastBody []byte
// NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be
// based on upstreamModel (resolved via oauth-model-mappings). The loop variable attemptModel
// is only used as the concrete model id sent to the upstream Gemini CLI endpoint (and the
// model label passed into response translation) when iterating fallback variants.
for idx, attemptModel := range models { for idx, attemptModel := range models {
payload := append([]byte(nil), basePayload...) payload := append([]byte(nil), basePayload...)
payload = setJSONField(payload, "project", projectID) payload = setJSONField(payload, "project", projectID)
@@ -417,14 +399,9 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli") to := sdktranslator.FromString("gemini-cli")
upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)) models := cliPreviewFallbackOrder(req.Model)
if upstreamModel == "" { if len(models) == 0 || models[0] != req.Model {
upstreamModel = strings.TrimSpace(req.Model) models = append([]string{req.Model}, models...)
}
models := cliPreviewFallbackOrder(upstreamModel)
if len(models) == 0 || models[0] != upstreamModel {
models = append([]string{upstreamModel}, models...)
} }
httpClient := newHTTPClient(ctx, e.cfg, auth, 0) httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -440,19 +417,17 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
var lastStatus int var lastStatus int
var lastBody []byte var lastBody []byte
// NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be // The loop variable attemptModel is only used as the concrete model id sent to the upstream
// based on upstreamModel (resolved via oauth-model-mappings). The loop variable attemptModel // Gemini CLI endpoint when iterating fallback variants.
// is only used as the concrete model id sent to the upstream Gemini CLI endpoint when iterating
// fallback variants.
for _, attemptModel := range models { for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, upstreamModel) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, payload) payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings") payload = deleteJSONField(payload, "request.safetySettings")
payload = util.StripThinkingConfigIfUnsupported(upstreamModel, payload) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiCLIImageAspectRatio(upstreamModel, payload) payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
tok, errTok := tokenSource.Token() tok, errTok := tokenSource.Token()
if errTok != nil { if errTok != nil {

View File

@@ -77,26 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { if override := e.resolveUpstreamModel(model, auth); override != "" {
upstreamModel = modelOverride model = override
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
// Official Gemini API via API key or OAuth bearer // Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model) body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -105,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} }
} }
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action) url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -180,28 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { if override := e.resolveUpstreamModel(model, auth); override != "" {
upstreamModel = modelOverride model = override
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model) body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -301,29 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, bearer := geminiCreds(auth) apiKey, bearer := geminiCreds(auth)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { if override := e.resolveUpstreamModel(model, auth); override != "" {
upstreamModel = modelOverride model = override
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model) translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "countTokens") url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
requestBody := bytes.NewReader(translatedReq) requestBody := bytes.NewReader(translatedReq)

View File

@@ -120,27 +120,22 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
} }
body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body) body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(upstreamModel, body) body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(upstreamModel, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(upstreamModel, body) body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", req.Model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -149,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
} }
} }
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action) url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -223,27 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
} }
body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(upstreamModel, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(upstreamModel, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(upstreamModel, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -256,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action) url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -327,30 +322,25 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
} }
body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body) body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(upstreamModel, body) body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(upstreamModel, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(upstreamModel, body) body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", req.Model)
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -447,33 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
} }
body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(upstreamModel, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(upstreamModel, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(upstreamModel, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
// For API key auth, use simpler URL format without project/location // For API key auth, use simpler URL format without project/location
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -564,31 +554,26 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
// countTokensWithServiceAccount counts tokens using service account credentials. // countTokensWithServiceAccount counts tokens using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
} }
translatedReq = util.StripThinkingConfigIfUnsupported(upstreamModel, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(upstreamModel, translatedReq) translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens") url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil { if errNewReq != nil {
@@ -656,24 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
// countTokensWithAPIKey handles token counting using API key credentials. // countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
} }
translatedReq = util.StripThinkingConfigIfUnsupported(upstreamModel, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(upstreamModel, translatedReq) translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -683,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "countTokens") url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil { if errNewReq != nil {
@@ -826,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
} }
return tok.AccessToken, nil return tok.AccessToken, nil
} }
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
// It matches the requested model alias against configured models and returns the actual upstream name.
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveVertexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}

View File

@@ -54,25 +54,18 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if strings.TrimSpace(upstreamModel) == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
if upstreamModel != "" { body, _ = sjson.SetBytes(body, "model", req.Model)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body = NormalizeThinkingConfig(body, req.Model, false)
} if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyIFlowThinkingConfig(body) body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body) body = preserveReasoningContentInMessages(body)
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, req.Model, body)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -150,21 +143,14 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if strings.TrimSpace(upstreamModel) == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
if upstreamModel != "" { body, _ = sjson.SetBytes(body, "model", req.Model)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body = NormalizeThinkingConfig(body, req.Model, false)
} if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
body = applyIFlowThinkingConfig(body) body = applyIFlowThinkingConfig(body)
@@ -174,7 +160,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 { if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body) body = ensureToolsArray(body)
} }
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, req.Model, body)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -257,16 +243,11 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
} }
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if strings.TrimSpace(upstreamModel) == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
enc, err := tokenizerForModel(upstreamModel) enc, err := tokenizerForModel(req.Model)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
} }

View File

@@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth) allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if upstreamModel != "" && modelOverride == "" { if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
@@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth) allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if upstreamModel != "" && modelOverride == "" { if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }

View File

@@ -12,7 +12,6 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -48,23 +47,16 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if strings.TrimSpace(upstreamModel) == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
if upstreamModel != "" { body, _ = sjson.SetBytes(body, "model", req.Model)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body = NormalizeThinkingConfig(body, req.Model, false)
} if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, req.Model, body)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -131,21 +123,14 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if strings.TrimSpace(upstreamModel) == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
if upstreamModel != "" { body, _ = sjson.SetBytes(body, "model", req.Model)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body = NormalizeThinkingConfig(body, req.Model, false)
} if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
toolsResult := gjson.GetBytes(body, "tools") toolsResult := gjson.GetBytes(body, "tools")
@@ -155,7 +140,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`)) body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
} }
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true) body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
body = applyPayloadConfig(e.cfg, upstreamModel, body) body = applyPayloadConfig(e.cfg, req.Model, body)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -235,18 +220,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
} }
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if strings.TrimSpace(upstreamModel) == "" {
upstreamModel = req.Model
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
modelName := gjson.GetBytes(body, "model").String() modelName := gjson.GetBytes(body, "model").String()
if strings.TrimSpace(modelName) == "" { if strings.TrimSpace(modelName) == "" {
modelName = upstreamModel modelName = req.Model
} }
enc, err := tokenizerForModel(modelName) enc, err := tokenizerForModel(modelName)

View File

@@ -413,7 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts) resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil { if errExec != nil {
@@ -475,7 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil { if errExec != nil {
@@ -537,7 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil { if errStream != nil {
rerr := &Error{Message: errStream.Error()} rerr := &Error{Message: errStream.Error()}

View File

@@ -65,17 +65,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod
m.modelNameMappings.Store(table) m.modelNameMappings.Store(table)
} }
func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any { // applyOAuthModelMapping resolves the upstream model from OAuth model mappings
original := m.resolveOAuthUpstreamModel(auth, requestedModel) // and returns the resolved model along with updated metadata. If a mapping exists,
if original == "" { // the returned model is the upstream model and metadata contains the original
return metadata // requested model for response translation.
} func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
if metadata != nil { upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok { if upstreamModel == "" {
if s, okStr := v.(string); okStr && strings.EqualFold(s, original) { return requestedModel, metadata
return metadata
}
}
} }
out := make(map[string]any, 1) out := make(map[string]any, 1)
if len(metadata) > 0 { if len(metadata) > 0 {
@@ -84,8 +81,8 @@ func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel stri
out[k] = v out[k] = v
} }
} }
out[util.ModelMappingOriginalModelMetadataKey] = original out[util.ModelMappingOriginalModelMetadataKey] = upstreamModel
return out return upstreamModel, out
} }
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string { func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {