Mirror of https://github.com/router-for-me/CLIProxyAPI.git, synced 2026-02-19 04:40:52 +08:00
refactor(executor): resolve upstream model at conductor level before execution
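The refactor moves model-name resolution out of the individual executors: each executor now receives a req.Model that the conductor has already resolved, instead of calling util.ResolveOriginalModel itself. A minimal sketch of that conductor-side step, assuming a hook that runs before dispatch (the function name and pointer receiver below are illustrative and not part of this diff; only util.ResolveOriginalModel and the request fields come from the hunks):

	// Hypothetical pre-dispatch hook: resolve the upstream model once so the
	// executors can use req.Model directly. util.ResolveOriginalModel is the
	// helper that the hunks below remove from each executor.
	func resolveModelBeforeDispatch(req *cliproxyexecutor.Request) {
		if upstream := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata)); upstream != "" {
			req.Model = upstream
		}
	}
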
@@ -55,17 +55,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
-	if upstreamModel == "" {
-		upstreamModel = strings.TrimSpace(req.Model)
-	}
-
-	translatedReq, body, err := e.translateRequest(req, opts, false, upstreamModel)
+	translatedReq, body, err := e.translateRequest(req, opts, false, req.Model)
 	if err != nil {
 		return resp, err
 	}

-	endpoint := e.buildEndpoint(upstreamModel, body.action, opts.Alt)
+	endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
 	wsReq := &wsrelay.HTTPRequest{
 		Method: http.MethodPost,
 		URL:    endpoint,
@@ -115,17 +110,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
-	if upstreamModel == "" {
-		upstreamModel = strings.TrimSpace(req.Model)
-	}
-
-	translatedReq, body, err := e.translateRequest(req, opts, true, upstreamModel)
+	translatedReq, body, err := e.translateRequest(req, opts, true, req.Model)
 	if err != nil {
 		return nil, err
 	}

-	endpoint := e.buildEndpoint(upstreamModel, body.action, opts.Alt)
+	endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
 	wsReq := &wsrelay.HTTPRequest{
 		Method: http.MethodPost,
 		URL:    endpoint,
@@ -266,12 +256,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth

 // CountTokens counts tokens for the given request using the AI Studio API.
 func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
-	upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
-	if upstreamModel == "" {
-		upstreamModel = strings.TrimSpace(req.Model)
-	}
-
-	_, body, err := e.translateRequest(req, opts, false, upstreamModel)
+	_, body, err := e.translateRequest(req, opts, false, req.Model)
 	if err != nil {
 		return cliproxyexecutor.Response{}, err
 	}
@@ -280,7 +265,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
 	body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
 	body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings")

-	endpoint := e.buildEndpoint(upstreamModel, "countTokens", "")
+	endpoint := e.buildEndpoint(req.Model, "countTokens", "")
 	wsReq := &wsrelay.HTTPRequest{
 		Method: http.MethodPost,
 		URL:    endpoint,

@@ -76,11 +76,7 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au

 // Execute performs a non-streaming request to the Antigravity API.
 func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
-	}
-	isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
+	isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
 	if isClaude {
 		return e.executeClaudeNonStream(ctx, auth, req, opts)
 	}
@@ -98,13 +94,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au

 	from := opts.SourceFormat
 	to := sdktranslator.FromString("antigravity")
-	translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
+	translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)

-	translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel)
-	translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated)
-	translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated)
-	translated = normalizeAntigravityThinking(upstreamModel, translated, isClaude)
-	translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated)
+	translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
+	translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
+	translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
+	translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
+	translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)

 	baseURLs := antigravityBaseURLFallbackOrder(auth)
 	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -114,7 +110,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
 	var lastErr error

 	for idx, baseURL := range baseURLs {
-		httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL)
+		httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL)
 		if errReq != nil {
 			err = errReq
 			return resp, err
@@ -191,20 +187,15 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
-	}
-
 	from := opts.SourceFormat
 	to := sdktranslator.FromString("antigravity")
-	translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
+	translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)

-	translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel)
-	translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated)
-	translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated)
-	translated = normalizeAntigravityThinking(upstreamModel, translated, true)
-	translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated)
+	translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
+	translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
+	translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
+	translated = normalizeAntigravityThinking(req.Model, translated, true)
+	translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)

 	baseURLs := antigravityBaseURLFallbackOrder(auth)
 	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -214,7 +205,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
 	var lastErr error

 	for idx, baseURL := range baseURLs {
-		httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL)
+		httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
 		if errReq != nil {
 			err = errReq
 			return resp, err
@@ -530,21 +521,17 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
-	}
-	isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
+	isClaude := strings.Contains(strings.ToLower(req.Model), "claude")

 	from := opts.SourceFormat
 	to := sdktranslator.FromString("antigravity")
-	translated := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
+	translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)

-	translated = applyThinkingMetadataCLI(translated, req.Metadata, upstreamModel)
-	translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, translated)
-	translated = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, translated)
-	translated = normalizeAntigravityThinking(upstreamModel, translated, isClaude)
-	translated = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "antigravity", "request", translated)
+	translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
+	translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
+	translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
+	translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
+	translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)

 	baseURLs := antigravityBaseURLFallbackOrder(auth)
 	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -554,7 +541,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
 	var lastErr error

 	for idx, baseURL := range baseURLs {
-		httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL)
+		httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
 		if errReq != nil {
 			err = errReq
 			return nil, err
@@ -692,11 +679,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
 	to := sdktranslator.FromString("antigravity")
 	respCtx := context.WithValue(ctx, "alt", opts.Alt)

-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
-	}
-	isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
+	isClaude := strings.Contains(strings.ToLower(req.Model), "claude")

 	baseURLs := antigravityBaseURLFallbackOrder(auth)
 	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -713,10 +696,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
 	var lastErr error

 	for idx, baseURL := range baseURLs {
-		payload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
-		payload = applyThinkingMetadataCLI(payload, req.Metadata, upstreamModel)
-		payload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, payload)
-		payload = normalizeAntigravityThinking(upstreamModel, payload, isClaude)
+		payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
+		payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
+		payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
+		payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
 		payload = deleteJSONField(payload, "project")
 		payload = deleteJSONField(payload, "model")
 		payload = deleteJSONField(payload, "request.safetySettings")

@@ -49,36 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
 	}
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)
-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
-	}
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
+	model := req.Model
+	if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
+		model = override
 	}
 	from := opts.SourceFormat
 	to := sdktranslator.FromString("claude")
 	// Use streaming translation to preserve function calling, except for claude.
 	stream := from != to
-	body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), stream)
-	body, _ = sjson.SetBytes(body, "model", upstreamModel)
+	body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
+	body, _ = sjson.SetBytes(body, "model", model)
 	// Inject thinking config based on model metadata for thinking variants
-	body = e.injectThinkingConfig(upstreamModel, req.Metadata, body)
+	body = e.injectThinkingConfig(model, req.Metadata, body)

-	if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
+	if !strings.HasPrefix(model, "claude-3-5-haiku") {
 		body = checkSystemInstructions(body)
 	}
-	body = applyPayloadConfig(e.cfg, upstreamModel, body)
+	body = applyPayloadConfig(e.cfg, model, body)

 	// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
 	body = disableThinkingIfToolChoiceForced(body)

 	// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
-	body = ensureMaxTokensForThinking(upstreamModel, body)
+	body = ensureMaxTokensForThinking(model, body)

 	// Extract betas from body and convert to header
 	var extraBetas []string
@@ -170,29 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
 	defer reporter.trackFailure(ctx, &err)
 	from := opts.SourceFormat
 	to := sdktranslator.FromString("claude")
-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
+	model := req.Model
+	if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
+		model = override
 	}
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
-	}
-	body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
-	body, _ = sjson.SetBytes(body, "model", upstreamModel)
+	body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
+	body, _ = sjson.SetBytes(body, "model", model)
 	// Inject thinking config based on model metadata for thinking variants
-	body = e.injectThinkingConfig(upstreamModel, req.Metadata, body)
+	body = e.injectThinkingConfig(model, req.Metadata, body)
 	body = checkSystemInstructions(body)
-	body = applyPayloadConfig(e.cfg, upstreamModel, body)
+	body = applyPayloadConfig(e.cfg, model, body)

 	// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
 	body = disableThinkingIfToolChoiceForced(body)

 	// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
-	body = ensureMaxTokensForThinking(upstreamModel, body)
+	body = ensureMaxTokensForThinking(model, body)

 	// Extract betas from body and convert to header
 	var extraBetas []string
@@ -316,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
 	to := sdktranslator.FromString("claude")
 	// Use streaming translation to preserve function calling, except for claude.
 	stream := from != to
-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
+	model := req.Model
+	if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
+		model = override
 	}
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
-	}
-	body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), stream)
-	body, _ = sjson.SetBytes(body, "model", upstreamModel)
+	body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
+	body, _ = sjson.SetBytes(body, "model", model)

-	if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
+	if !strings.HasPrefix(model, "claude-3-5-haiku") {
 		body = checkSystemInstructions(body)
 	}

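
Note: in the Claude and Codex executors the removed resolution logic collapses to the same per-executor override pattern, condensed here from the surrounding hunks purely for reference:

	model := req.Model
	if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
		model = override
	}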

@@ -49,28 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
-	}
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
+	model := req.Model
+	if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
+		model = override
 	}

 	from := opts.SourceFormat
 	to := sdktranslator.FromString("codex")
-	body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
-	body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false)
-	body = NormalizeThinkingConfig(body, upstreamModel, false)
-	if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
+	body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
+	body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
+	body = NormalizeThinkingConfig(body, model, false)
+	if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
 		return resp, errValidate
 	}
-	body = applyPayloadConfig(e.cfg, upstreamModel, body)
-	body, _ = sjson.SetBytes(body, "model", upstreamModel)
+	body = applyPayloadConfig(e.cfg, model, body)
+	body, _ = sjson.SetBytes(body, "model", model)
 	body, _ = sjson.SetBytes(body, "stream", true)
 	body, _ = sjson.DeleteBytes(body, "previous_response_id")

@@ -156,30 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
-	}
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
+	model := req.Model
+	if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
+		model = override
 	}

 	from := opts.SourceFormat
 	to := sdktranslator.FromString("codex")
-	body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
+	body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)

-	body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false)
-	body = NormalizeThinkingConfig(body, upstreamModel, false)
-	if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
+	body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
+	body = NormalizeThinkingConfig(body, model, false)
+	if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
 		return nil, errValidate
 	}
-	body = applyPayloadConfig(e.cfg, upstreamModel, body)
+	body = applyPayloadConfig(e.cfg, model, body)
 	body, _ = sjson.DeleteBytes(body, "previous_response_id")
-	body, _ = sjson.SetBytes(body, "model", upstreamModel)
+	body, _ = sjson.SetBytes(body, "model", model)

 	url := strings.TrimSuffix(baseURL, "/") + "/responses"
 	httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -266,30 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
 }

 func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if upstreamModel == "" {
-		upstreamModel = req.Model
-	}
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
+	model := req.Model
+	if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
+		model = override
 	}

 	from := opts.SourceFormat
 	to := sdktranslator.FromString("codex")
-	body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
+	body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)

-	modelForCounting := upstreamModel
-
-	body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning.effort", false)
-	body, _ = sjson.SetBytes(body, "model", upstreamModel)
+	body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
+	body, _ = sjson.SetBytes(body, "model", model)
 	body, _ = sjson.DeleteBytes(body, "previous_response_id")
 	body, _ = sjson.SetBytes(body, "stream", false)

-	enc, err := tokenizerForCodexModel(modelForCounting)
+	enc, err := tokenizerForCodexModel(model)
 	if err != nil {
 		return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
 	}

@@ -75,21 +75,16 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
-	if upstreamModel == "" {
-		upstreamModel = strings.TrimSpace(req.Model)
-	}
-
 	from := opts.SourceFormat
 	to := sdktranslator.FromString("gemini-cli")
-	basePayload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
-	basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, upstreamModel)
-	basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, basePayload)
-	basePayload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, basePayload)
-	basePayload = util.NormalizeGeminiCLIThinkingBudget(upstreamModel, basePayload)
-	basePayload = util.StripThinkingConfigIfUnsupported(upstreamModel, basePayload)
-	basePayload = fixGeminiCLIImageAspectRatio(upstreamModel, basePayload)
-	basePayload = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "gemini", "request", basePayload)
+	basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
+	basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
+	basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
+	basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
+	basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
+	basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
+	basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
+	basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)

 	action := "generateContent"
 	if req.Metadata != nil {
@@ -99,9 +94,9 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
 	}

 	projectID := resolveGeminiProjectID(auth)
-	models := cliPreviewFallbackOrder(upstreamModel)
-	if len(models) == 0 || models[0] != upstreamModel {
-		models = append([]string{upstreamModel}, models...)
+	models := cliPreviewFallbackOrder(req.Model)
+	if len(models) == 0 || models[0] != req.Model {
+		models = append([]string{req.Model}, models...)
 	}

 	httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -115,10 +110,6 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
 	var lastStatus int
 	var lastBody []byte

-	// NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be
-	// based on upstreamModel (resolved via oauth-model-mappings). The loop variable attemptModel
-	// is only used as the concrete model id sent to the upstream Gemini CLI endpoint (and the
-	// model label passed into response translation) when iterating fallback variants.
 	for idx, attemptModel := range models {
 		payload := append([]byte(nil), basePayload...)
 		if action == "countTokens" {
@@ -223,27 +214,22 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
-	if upstreamModel == "" {
-		upstreamModel = strings.TrimSpace(req.Model)
-	}
-
 	from := opts.SourceFormat
 	to := sdktranslator.FromString("gemini-cli")
-	basePayload := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
-	basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, upstreamModel)
-	basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, basePayload)
-	basePayload = util.ApplyDefaultThinkingIfNeededCLI(upstreamModel, basePayload)
-	basePayload = util.NormalizeGeminiCLIThinkingBudget(upstreamModel, basePayload)
-	basePayload = util.StripThinkingConfigIfUnsupported(upstreamModel, basePayload)
-	basePayload = fixGeminiCLIImageAspectRatio(upstreamModel, basePayload)
-	basePayload = applyPayloadConfigWithRoot(e.cfg, upstreamModel, "gemini", "request", basePayload)
+	basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+	basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
+	basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
+	basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
+	basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
+	basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
+	basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
+	basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)

 	projectID := resolveGeminiProjectID(auth)

-	models := cliPreviewFallbackOrder(upstreamModel)
-	if len(models) == 0 || models[0] != upstreamModel {
-		models = append([]string{upstreamModel}, models...)
+	models := cliPreviewFallbackOrder(req.Model)
+	if len(models) == 0 || models[0] != req.Model {
+		models = append([]string{req.Model}, models...)
 	}

 	httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -257,10 +243,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
 	var lastStatus int
 	var lastBody []byte

-	// NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be
-	// based on upstreamModel (resolved via oauth-model-mappings). The loop variable attemptModel
-	// is only used as the concrete model id sent to the upstream Gemini CLI endpoint (and the
-	// model label passed into response translation) when iterating fallback variants.
 	for idx, attemptModel := range models {
 		payload := append([]byte(nil), basePayload...)
 		payload = setJSONField(payload, "project", projectID)
@@ -417,14 +399,9 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
 	from := opts.SourceFormat
 	to := sdktranslator.FromString("gemini-cli")

-	upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
-	if upstreamModel == "" {
-		upstreamModel = strings.TrimSpace(req.Model)
-	}
-
-	models := cliPreviewFallbackOrder(upstreamModel)
-	if len(models) == 0 || models[0] != upstreamModel {
-		models = append([]string{upstreamModel}, models...)
+	models := cliPreviewFallbackOrder(req.Model)
+	if len(models) == 0 || models[0] != req.Model {
+		models = append([]string{req.Model}, models...)
 	}

 	httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -440,19 +417,17 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
 	var lastStatus int
 	var lastBody []byte

-	// NOTE: Model capability checks (thinking config, payload rules, image fixes, etc.) must be
-	// based on upstreamModel (resolved via oauth-model-mappings). The loop variable attemptModel
-	// is only used as the concrete model id sent to the upstream Gemini CLI endpoint when iterating
-	// fallback variants.
+	// The loop variable attemptModel is only used as the concrete model id sent to the upstream
+	// Gemini CLI endpoint when iterating fallback variants.
 	for _, attemptModel := range models {
 		payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
-		payload = applyThinkingMetadataCLI(payload, req.Metadata, upstreamModel)
-		payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(upstreamModel, req.Metadata, payload)
+		payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
+		payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
 		payload = deleteJSONField(payload, "project")
 		payload = deleteJSONField(payload, "model")
 		payload = deleteJSONField(payload, "request.safetySettings")
-		payload = util.StripThinkingConfigIfUnsupported(upstreamModel, payload)
-		payload = fixGeminiCLIImageAspectRatio(upstreamModel, payload)
+		payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
+		payload = fixGeminiCLIImageAspectRatio(req.Model, payload)

 		tok, errTok := tokenSource.Token()
 		if errTok != nil {

@@ -77,26 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
+	model := req.Model
+	if override := e.resolveUpstreamModel(model, auth); override != "" {
+		model = override
 	}

 	// Official Gemini API via API key or OAuth bearer
 	from := opts.SourceFormat
 	to := sdktranslator.FromString("gemini")
-	body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
-	body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
-	body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
-	body = util.NormalizeGeminiThinkingBudget(req.Model, body)
-	body = util.StripThinkingConfigIfUnsupported(req.Model, body)
-	body = fixGeminiImageAspectRatio(req.Model, body)
-	body = applyPayloadConfig(e.cfg, req.Model, body)
-	body, _ = sjson.SetBytes(body, "model", upstreamModel)
+	body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
+	body = ApplyThinkingMetadata(body, req.Metadata, model)
+	body = util.ApplyDefaultThinkingIfNeeded(model, body)
+	body = util.NormalizeGeminiThinkingBudget(model, body)
+	body = util.StripThinkingConfigIfUnsupported(model, body)
+	body = fixGeminiImageAspectRatio(model, body)
+	body = applyPayloadConfig(e.cfg, model, body)
+	body, _ = sjson.SetBytes(body, "model", model)

 	action := "generateContent"
 	if req.Metadata != nil {
@@ -105,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
 		}
 	}
 	baseURL := resolveGeminiBaseURL(auth)
-	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action)
+	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
 	if opts.Alt != "" && action != "countTokens" {
 		url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
 	}
@@ -180,28 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
 	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 	defer reporter.trackFailure(ctx, &err)

-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
+	model := req.Model
+	if override := e.resolveUpstreamModel(model, auth); override != "" {
+		model = override
 	}

 	from := opts.SourceFormat
 	to := sdktranslator.FromString("gemini")
-	body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
-	body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
-	body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
-	body = util.NormalizeGeminiThinkingBudget(req.Model, body)
-	body = util.StripThinkingConfigIfUnsupported(req.Model, body)
-	body = fixGeminiImageAspectRatio(req.Model, body)
-	body = applyPayloadConfig(e.cfg, req.Model, body)
-	body, _ = sjson.SetBytes(body, "model", upstreamModel)
+	body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
+	body = ApplyThinkingMetadata(body, req.Metadata, model)
+	body = util.ApplyDefaultThinkingIfNeeded(model, body)
+	body = util.NormalizeGeminiThinkingBudget(model, body)
+	body = util.StripThinkingConfigIfUnsupported(model, body)
+	body = fixGeminiImageAspectRatio(model, body)
+	body = applyPayloadConfig(e.cfg, model, body)
+	body, _ = sjson.SetBytes(body, "model", model)

 	baseURL := resolveGeminiBaseURL(auth)
-	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent")
+	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
 	if opts.Alt == "" {
 		url = url + "?alt=sse"
 	} else {
@@ -301,29 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
 func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
 	apiKey, bearer := geminiCreds(auth)

-	upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-	if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
-		upstreamModel = modelOverride
-	} else if !strings.EqualFold(upstreamModel, req.Model) {
-		if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
-			upstreamModel = modelOverride
-		}
+	model := req.Model
+	if override := e.resolveUpstreamModel(model, auth); override != "" {
+		model = override
 	}

 	from := opts.SourceFormat
 	to := sdktranslator.FromString("gemini")
-	translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
-	translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
-	translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
-	translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
+	translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
+	translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
+	translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
+	translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
 	respCtx := context.WithValue(ctx, "alt", opts.Alt)
 	translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
 	translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
 	translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
-	translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
+	translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)

 	baseURL := resolveGeminiBaseURL(auth)
-	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "countTokens")
+	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")

 	requestBody := bytes.NewReader(translatedReq)

|
|||||||
@@ -120,27 +120,22 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
||||||
if upstreamModel == "" {
|
|
||||||
upstreamModel = req.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) {
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||||
if budgetOverride != nil {
|
if budgetOverride != nil {
|
||||||
norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride)
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||||
budgetOverride = &norm
|
budgetOverride = &norm
|
||||||
}
|
}
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||||
}
|
}
|
||||||
body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body)
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||||
body = util.NormalizeGeminiThinkingBudget(upstreamModel, body)
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||||
body = util.StripThinkingConfigIfUnsupported(upstreamModel, body)
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
body = fixGeminiImageAspectRatio(upstreamModel, body)
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
body = applyPayloadConfig(e.cfg, upstreamModel, body)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
@@ -149,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
baseURL := vertexBaseURL(location)
|
baseURL := vertexBaseURL(location)
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action)
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
|
||||||
if opts.Alt != "" && action != "countTokens" {
|
if opts.Alt != "" && action != "countTokens" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
}
|
}
|
||||||
@@ -223,27 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
model := req.Model
|
||||||
if upstreamModel == "" {
|
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||||
upstreamModel = req.Model
|
model = override
|
||||||
}
|
}
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) {
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||||
if budgetOverride != nil {
|
if budgetOverride != nil {
|
||||||
norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride)
|
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||||
budgetOverride = &norm
|
budgetOverride = &norm
|
||||||
}
|
}
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||||
}
|
}
|
||||||
body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body)
|
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||||
body = util.NormalizeGeminiThinkingBudget(upstreamModel, body)
|
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||||
body = util.StripThinkingConfigIfUnsupported(upstreamModel, body)
|
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||||
body = fixGeminiImageAspectRatio(upstreamModel, body)
|
body = fixGeminiImageAspectRatio(model, body)
|
||||||
body = applyPayloadConfig(e.cfg, upstreamModel, body)
|
body = applyPayloadConfig(e.cfg, model, body)
|
||||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
body, _ = sjson.SetBytes(body, "model", model)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
@@ -256,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action)
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
|
||||||
if opts.Alt != "" && action != "countTokens" {
|
if opts.Alt != "" && action != "countTokens" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
}
|
}
|
||||||
@@ -327,30 +322,25 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
 reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 defer reporter.trackFailure(ctx, &err)

-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if upstreamModel == "" {
-upstreamModel = req.Model
-}
-
 from := opts.SourceFormat
 to := sdktranslator.FromString("gemini")
-body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
-if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) {
+body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
 if budgetOverride != nil {
-norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride)
+norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
 budgetOverride = &norm
 }
 body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
 }
-body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body)
-body = util.NormalizeGeminiThinkingBudget(upstreamModel, body)
-body = util.StripThinkingConfigIfUnsupported(upstreamModel, body)
-body = fixGeminiImageAspectRatio(upstreamModel, body)
-body = applyPayloadConfig(e.cfg, upstreamModel, body)
-body, _ = sjson.SetBytes(body, "model", upstreamModel)
+body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
+body = util.NormalizeGeminiThinkingBudget(req.Model, body)
+body = util.StripThinkingConfigIfUnsupported(req.Model, body)
+body = fixGeminiImageAspectRatio(req.Model, body)
+body = applyPayloadConfig(e.cfg, req.Model, body)
+body, _ = sjson.SetBytes(body, "model", req.Model)

 baseURL := vertexBaseURL(location)
-url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent")
+url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
 if opts.Alt == "" {
 url = url + "?alt=sse"
 } else {
@@ -447,33 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
 reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 defer reporter.trackFailure(ctx, &err)

-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if upstreamModel == "" {
-upstreamModel = req.Model
+model := req.Model
+if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
+model = override
 }

 from := opts.SourceFormat
 to := sdktranslator.FromString("gemini")
-body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
-if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) {
+body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
+if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
 if budgetOverride != nil {
-norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride)
+norm := util.NormalizeThinkingBudget(model, *budgetOverride)
 budgetOverride = &norm
 }
 body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
 }
-body = util.ApplyDefaultThinkingIfNeeded(upstreamModel, body)
-body = util.NormalizeGeminiThinkingBudget(upstreamModel, body)
-body = util.StripThinkingConfigIfUnsupported(upstreamModel, body)
-body = fixGeminiImageAspectRatio(upstreamModel, body)
-body = applyPayloadConfig(e.cfg, upstreamModel, body)
-body, _ = sjson.SetBytes(body, "model", upstreamModel)
+body = util.ApplyDefaultThinkingIfNeeded(model, body)
+body = util.NormalizeGeminiThinkingBudget(model, body)
+body = util.StripThinkingConfigIfUnsupported(model, body)
+body = fixGeminiImageAspectRatio(model, body)
+body = applyPayloadConfig(e.cfg, model, body)
+body, _ = sjson.SetBytes(body, "model", model)

 // For API key auth, use simpler URL format without project/location
 if baseURL == "" {
 baseURL = "https://generativelanguage.googleapis.com"
 }
-url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent")
+url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
 if opts.Alt == "" {
 url = url + "?alt=sse"
 } else {
@@ -564,31 +554,26 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth

 // countTokensWithServiceAccount counts tokens using service account credentials.
 func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if upstreamModel == "" {
-upstreamModel = req.Model
-}
-
 from := opts.SourceFormat
 to := sdktranslator.FromString("gemini")
-translatedReq := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
-if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) {
+translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
+if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
 if budgetOverride != nil {
-norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride)
+norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
 budgetOverride = &norm
 }
 translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
 }
-translatedReq = util.StripThinkingConfigIfUnsupported(upstreamModel, translatedReq)
-translatedReq = fixGeminiImageAspectRatio(upstreamModel, translatedReq)
-translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
+translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
+translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
+translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
 respCtx := context.WithValue(ctx, "alt", opts.Alt)
 translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
 translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
 translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")

 baseURL := vertexBaseURL(location)
-url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
+url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")

 httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
 if errNewReq != nil {
@@ -656,24 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context

 // countTokensWithAPIKey handles token counting using API key credentials.
 func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if upstreamModel == "" {
-upstreamModel = req.Model
+model := req.Model
+if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
+model = override
 }

 from := opts.SourceFormat
 to := sdktranslator.FromString("gemini")
-translatedReq := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
-if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(upstreamModel, req.Metadata); ok && util.ModelSupportsThinking(upstreamModel) {
+translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
+if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
 if budgetOverride != nil {
-norm := util.NormalizeThinkingBudget(upstreamModel, *budgetOverride)
+norm := util.NormalizeThinkingBudget(model, *budgetOverride)
 budgetOverride = &norm
 }
 translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
 }
-translatedReq = util.StripThinkingConfigIfUnsupported(upstreamModel, translatedReq)
-translatedReq = fixGeminiImageAspectRatio(upstreamModel, translatedReq)
-translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
+translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
+translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
+translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
 respCtx := context.WithValue(ctx, "alt", opts.Alt)
 translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
 translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -683,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
 if baseURL == "" {
 baseURL = "https://generativelanguage.googleapis.com"
 }
-url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "countTokens")
+url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")

 httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
 if errNewReq != nil {
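Both countTokens paths strip fields the token-count endpoint does not need before sending the request. A tiny sketch of those sjson.DeleteBytes calls on an illustrative payload (the JSON values are made up; only the deletion calls mirror the code above):

package main

import (
	"fmt"

	"github.com/tidwall/sjson"
)

func main() {
	// Illustrative translated request body.
	translated := []byte(`{"model":"gemini-2.5-flash","contents":[],"tools":[{"functionDeclarations":[]}],"generationConfig":{"temperature":0.2},"safetySettings":[]}`)

	// Drop fields that countTokens does not accept.
	translated, _ = sjson.DeleteBytes(translated, "tools")
	translated, _ = sjson.DeleteBytes(translated, "generationConfig")
	translated, _ = sjson.DeleteBytes(translated, "safetySettings")

	fmt.Println(string(translated)) // {"model":"gemini-2.5-flash","contents":[]}
}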
@@ -826,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
 }
 return tok.AccessToken, nil
 }
+
+// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
+// It matches the requested model alias against configured models and returns the actual upstream name.
+func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
+trimmed := strings.TrimSpace(alias)
+if trimmed == "" {
+return ""
+}
+
+entry := e.resolveVertexConfig(auth)
+if entry == nil {
+return ""
+}
+
+normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
+
+// Candidate names to match against configured aliases/names.
+candidates := []string{strings.TrimSpace(normalizedModel)}
+if !strings.EqualFold(normalizedModel, trimmed) {
+candidates = append(candidates, trimmed)
+}
+if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
+candidates = append(candidates, original)
+}
+
+for i := range entry.Models {
+model := entry.Models[i]
+name := strings.TrimSpace(model.Name)
+modelAlias := strings.TrimSpace(model.Alias)
+
+for _, candidate := range candidates {
+if candidate == "" {
+continue
+}
+if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
+if name != "" {
+return name
+}
+return candidate
+}
+if name != "" && strings.EqualFold(name, candidate) {
+return name
+}
+}
+}
+return ""
+}
+
+// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
+func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
+if auth == nil || e.cfg == nil {
+return nil
+}
+var attrKey, attrBase string
+if auth.Attributes != nil {
+attrKey = strings.TrimSpace(auth.Attributes["api_key"])
+attrBase = strings.TrimSpace(auth.Attributes["base_url"])
+}
+for i := range e.cfg.VertexCompatAPIKey {
+entry := &e.cfg.VertexCompatAPIKey[i]
+cfgKey := strings.TrimSpace(entry.APIKey)
+cfgBase := strings.TrimSpace(entry.BaseURL)
+if attrKey != "" && attrBase != "" {
+if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
+return entry
+}
+continue
+}
+if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
+if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
+return entry
+}
+}
+if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
+return entry
+}
+}
+if attrKey != "" {
+for i := range e.cfg.VertexCompatAPIKey {
+entry := &e.cfg.VertexCompatAPIKey[i]
+if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
+return entry
+}
+}
+}
+return nil
+}
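The alias-matching rule added above can be exercised in isolation. A minimal sketch with simplified stand-in types and a hypothetical vertex-api-key entry (the real code also builds extra candidates via util.NormalizeThinkingModel and util.ResolveOriginalModel, which are omitted here):

package main

import (
	"fmt"
	"strings"
)

// Simplified stand-ins for the configuration types referenced above.
type vertexModel struct{ Name, Alias string }
type vertexCompatKey struct{ Models []vertexModel }

// resolve mirrors the matching loop: an alias match returns the configured
// upstream name, a direct name match returns the name unchanged, and no match
// returns "" so the caller keeps req.Model.
func resolve(requested string, entry vertexCompatKey) string {
	candidate := strings.TrimSpace(requested)
	if candidate == "" {
		return ""
	}
	for _, m := range entry.Models {
		name, alias := strings.TrimSpace(m.Name), strings.TrimSpace(m.Alias)
		if alias != "" && strings.EqualFold(alias, candidate) {
			if name != "" {
				return name
			}
			return candidate
		}
		if name != "" && strings.EqualFold(name, candidate) {
			return name
		}
	}
	return ""
}

func main() {
	entry := vertexCompatKey{Models: []vertexModel{
		{Name: "gemini-2.5-flash", Alias: "my-flash"}, // hypothetical mapping
	}}
	fmt.Println(resolve("MY-FLASH", entry))         // gemini-2.5-flash (case-insensitive alias match)
	fmt.Println(resolve("gemini-2.5-flash", entry)) // gemini-2.5-flash (direct name match)
	fmt.Println(resolve("unknown-model", entry))    // empty: caller falls back to req.Model
}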
@@ -54,25 +54,18 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
 reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 defer reporter.trackFailure(ctx, &err)

-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if strings.TrimSpace(upstreamModel) == "" {
-upstreamModel = req.Model
-}
-
 from := opts.SourceFormat
 to := sdktranslator.FromString("openai")
-body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
-body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false)
-if upstreamModel != "" {
-body, _ = sjson.SetBytes(body, "model", upstreamModel)
-}
-body = NormalizeThinkingConfig(body, upstreamModel, false)
-if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
+body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
+body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
+body, _ = sjson.SetBytes(body, "model", req.Model)
+body = NormalizeThinkingConfig(body, req.Model, false)
+if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
 return resp, errValidate
 }
 body = applyIFlowThinkingConfig(body)
 body = preserveReasoningContentInMessages(body)
-body = applyPayloadConfig(e.cfg, upstreamModel, body)
+body = applyPayloadConfig(e.cfg, req.Model, body)

 endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint

@@ -150,21 +143,14 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
 reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 defer reporter.trackFailure(ctx, &err)

-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if strings.TrimSpace(upstreamModel) == "" {
-upstreamModel = req.Model
-}
-
 from := opts.SourceFormat
 to := sdktranslator.FromString("openai")
-body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
+body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)

-body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false)
-if upstreamModel != "" {
-body, _ = sjson.SetBytes(body, "model", upstreamModel)
-}
-body = NormalizeThinkingConfig(body, upstreamModel, false)
-if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
+body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
+body, _ = sjson.SetBytes(body, "model", req.Model)
+body = NormalizeThinkingConfig(body, req.Model, false)
+if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
 return nil, errValidate
 }
 body = applyIFlowThinkingConfig(body)
@@ -174,7 +160,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
 if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
 body = ensureToolsArray(body)
 }
-body = applyPayloadConfig(e.cfg, upstreamModel, body)
+body = applyPayloadConfig(e.cfg, req.Model, body)

 endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint

@@ -257,16 +243,11 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
 }

 func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if strings.TrimSpace(upstreamModel) == "" {
-upstreamModel = req.Model
-}
-
 from := opts.SourceFormat
 to := sdktranslator.FromString("openai")
-body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
+body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)

-enc, err := tokenizerForModel(upstreamModel)
+enc, err := tokenizerForModel(req.Model)
 if err != nil {
 return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
 }
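The simplified executors above stamp req.Model straight into the translated payload rather than re-deriving an upstream name from metadata. A small sketch of that sjson.SetBytes step, with a made-up payload and model name:

package main

import (
	"fmt"

	"github.com/tidwall/sjson"
)

func main() {
	// Translated OpenAI-style body before the model field is set.
	body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)

	// By the time an executor runs, req.Model already holds the upstream model,
	// because the conductor resolves aliases and mappings before dispatch.
	reqModel := "qwen3-coder-plus" // hypothetical model name
	body, _ = sjson.SetBytes(body, "model", reqModel)

	fmt.Println(string(body)) // {"messages":[{"role":"user","content":"hi"}],"model":"qwen3-coder-plus"}
}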
@@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
 translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
 allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
 translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if upstreamModel != "" && modelOverride == "" {
-translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
-}
-translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
-if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
+translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
+if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
 return resp, errValidate
 }

@@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
 translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
 allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
 translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if upstreamModel != "" && modelOverride == "" {
-translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
-}
-translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
-if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
+translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
+if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
 return nil, errValidate
 }

@@ -12,7 +12,6 @@ import (

 qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
 "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
-"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
 cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
 cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
 sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -48,23 +47,16 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
 reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 defer reporter.trackFailure(ctx, &err)

-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if strings.TrimSpace(upstreamModel) == "" {
-upstreamModel = req.Model
-}
-
 from := opts.SourceFormat
 to := sdktranslator.FromString("openai")
-body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
-body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false)
-if upstreamModel != "" {
-body, _ = sjson.SetBytes(body, "model", upstreamModel)
-}
-body = NormalizeThinkingConfig(body, upstreamModel, false)
-if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
+body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
+body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
+body, _ = sjson.SetBytes(body, "model", req.Model)
+body = NormalizeThinkingConfig(body, req.Model, false)
+if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
 return resp, errValidate
 }
-body = applyPayloadConfig(e.cfg, upstreamModel, body)
+body = applyPayloadConfig(e.cfg, req.Model, body)

 url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
 httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -131,21 +123,14 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
 reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
 defer reporter.trackFailure(ctx, &err)

-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if strings.TrimSpace(upstreamModel) == "" {
-upstreamModel = req.Model
-}
-
 from := opts.SourceFormat
 to := sdktranslator.FromString("openai")
-body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), true)
+body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)

-body = ApplyReasoningEffortMetadata(body, req.Metadata, upstreamModel, "reasoning_effort", false)
-if upstreamModel != "" {
-body, _ = sjson.SetBytes(body, "model", upstreamModel)
-}
-body = NormalizeThinkingConfig(body, upstreamModel, false)
-if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
+body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
+body, _ = sjson.SetBytes(body, "model", req.Model)
+body = NormalizeThinkingConfig(body, req.Model, false)
+if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
 return nil, errValidate
 }
 toolsResult := gjson.GetBytes(body, "tools")
@@ -155,7 +140,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
 body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
 }
 body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
-body = applyPayloadConfig(e.cfg, upstreamModel, body)
+body = applyPayloadConfig(e.cfg, req.Model, body)

 url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
 httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -235,18 +220,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
 }

 func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
-upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
-if strings.TrimSpace(upstreamModel) == "" {
-upstreamModel = req.Model
-}
-
 from := opts.SourceFormat
 to := sdktranslator.FromString("openai")
-body := sdktranslator.TranslateRequest(from, to, upstreamModel, bytes.Clone(req.Payload), false)
+body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)

 modelName := gjson.GetBytes(body, "model").String()
 if strings.TrimSpace(modelName) == "" {
-modelName = upstreamModel
+modelName = req.Model
 }

 enc, err := tokenizerForModel(modelName)
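Qwen's CountTokens keeps a small fallback: it prefers the model field already present in the translated body and only falls back to req.Model when that field is blank. A sketch of that check with illustrative values:

package main

import (
	"fmt"
	"strings"

	"github.com/tidwall/gjson"
)

func main() {
	body := []byte(`{"model":"  "}`) // translated body whose model field is blank
	reqModel := "qwen3-coder-plus"   // hypothetical req.Model

	modelName := gjson.GetBytes(body, "model").String()
	if strings.TrimSpace(modelName) == "" {
		modelName = reqModel
	}
	fmt.Println(modelName) // qwen3-coder-plus
}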
@@ -413,7 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
 }
 execReq := req
 execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
-execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
+execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
 resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
 result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
 if errExec != nil {
@@ -475,7 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
 }
 execReq := req
 execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
-execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
+execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
 resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
 result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
 if errExec != nil {
@@ -537,7 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
 }
 execReq := req
 execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
-execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
+execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
 chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
 if errStream != nil {
 rerr := &Error{Message: errStream.Error()}
@@ -65,17 +65,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod
 m.modelNameMappings.Store(table)
 }

-func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any {
-original := m.resolveOAuthUpstreamModel(auth, requestedModel)
-if original == "" {
-return metadata
-}
-if metadata != nil {
-if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok {
-if s, okStr := v.(string); okStr && strings.EqualFold(s, original) {
-return metadata
-}
-}
-}
+// applyOAuthModelMapping resolves the upstream model from OAuth model mappings
+// and returns the resolved model along with updated metadata. If a mapping exists,
+// the returned model is the upstream model and metadata contains the original
+// requested model for response translation.
+func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
+upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
+if upstreamModel == "" {
+return requestedModel, metadata
 }
 out := make(map[string]any, 1)
 if len(metadata) > 0 {
@@ -84,8 +81,8 @@ func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel stri
 out[k] = v
 }
 }
-out[util.ModelMappingOriginalModelMetadataKey] = original
-return out
+out[util.ModelMappingOriginalModelMetadataKey] = upstreamModel
+return upstreamModel, out
 }

 func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
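Taken together, the conductor-level contract in executeWithProvider/executeStreamWithProvider boils down to: resolve the mapping once, hand the executor a request whose Model field is already the upstream name, and pass unmapped models through unchanged. A stubbed sketch of that flow (the mapping table and helper below are hypothetical, not the Manager's real API):

package main

import "fmt"

type execRequest struct {
	Model    string
	Metadata map[string]any
}

// applyMapping is a stand-in for Manager.applyOAuthModelMapping: it returns the
// upstream model for a mapped alias and passes unmapped models through unchanged.
func applyMapping(requested string, metadata map[string]any) (string, map[string]any) {
	mappings := map[string]string{"my-sonnet": "claude-sonnet-4-5"} // hypothetical OAuth model mapping
	if upstream, ok := mappings[requested]; ok {
		return upstream, metadata
	}
	return requested, metadata
}

func main() {
	execReq := execRequest{Model: "my-sonnet"}
	execReq.Model, execReq.Metadata = applyMapping(execReq.Model, execReq.Metadata)
	// Executors now see the upstream model directly in execReq.Model and no longer
	// re-resolve it from metadata.
	fmt.Println(execReq.Model) // claude-sonnet-4-5

	execReq = execRequest{Model: "gemini-2.5-pro"}
	execReq.Model, execReq.Metadata = applyMapping(execReq.Model, execReq.Metadata)
	fmt.Println(execReq.Model) // gemini-2.5-pro
}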