fix(executor): use upstream model for thinking config and payload translation

Author: hkfires
Date:   2025-12-30 17:49:44 +08:00
parent  857c880f99
commit  b055e00c1a

8 changed files with 255 additions and 162 deletions
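
Each of the three call sites touched below (Execute, ExecuteStream, CountTokens) repeats the same step before translating the request: prefer the upstream model recorded in the request metadata, and fall back to the trimmed client-facing model name when none is recorded. A minimal sketch of that pattern, assuming the executor package's existing imports (strings, util); the helper name resolveUpstreamModel is hypothetical and not part of this commit, and the metadata type is assumed to be map[string]any:

    // Hypothetical helper sketching the fallback the commit inlines at each
    // call site; the commit itself repeats this logic rather than extracting it.
    func resolveUpstreamModel(reqModel string, metadata map[string]any) string {
        // Prefer the original (upstream) model recorded in request metadata.
        upstream := strings.TrimSpace(util.ResolveOriginalModel(reqModel, metadata))
        if upstream == "" {
            // No original model recorded: fall back to the client-facing name.
            upstream = strings.TrimSpace(reqModel)
        }
        return upstream
    }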


@@ -55,11 +55,17 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
     reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
     defer reporter.trackFailure(ctx, &err)
-    translatedReq, body, err := e.translateRequest(req, opts, false)
+    upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
+    if upstreamModel == "" {
+        upstreamModel = strings.TrimSpace(req.Model)
+    }
+    translatedReq, body, err := e.translateRequest(req, opts, false, upstreamModel)
     if err != nil {
         return resp, err
     }
-    endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
+    endpoint := e.buildEndpoint(upstreamModel, body.action, opts.Alt)
     wsReq := &wsrelay.HTTPRequest{
         Method: http.MethodPost,
         URL:    endpoint,
@@ -109,11 +115,17 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
     reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
     defer reporter.trackFailure(ctx, &err)
-    translatedReq, body, err := e.translateRequest(req, opts, true)
+    upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
+    if upstreamModel == "" {
+        upstreamModel = strings.TrimSpace(req.Model)
+    }
+    translatedReq, body, err := e.translateRequest(req, opts, true, upstreamModel)
     if err != nil {
         return nil, err
     }
-    endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
+    endpoint := e.buildEndpoint(upstreamModel, body.action, opts.Alt)
     wsReq := &wsrelay.HTTPRequest{
         Method: http.MethodPost,
         URL:    endpoint,
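
The endpoint change matters because the model segment of the upstream URL has to name the real model, not a client-facing alias. buildEndpoint's implementation is not part of this diff; purely for illustration, assuming a Gemini-style path of models/{model}:{action} with an optional alt query parameter, the effect of passing upstreamModel looks roughly like this:

    // Illustration only: the real buildEndpoint is not shown in this commit,
    // and the base URL and query handling here are assumptions.
    func exampleEndpoint(model, action, alt string) string {
        u := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:%s", model, action)
        if alt != "" {
            u += "?alt=" + alt // e.g. alt=sse for streamed responses
        }
        return u
    }

    // exampleEndpoint("gemini-2.5-pro", "streamGenerateContent", "sse")
    //   -> .../models/gemini-2.5-pro:streamGenerateContent?alt=sse
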
@@ -254,7 +266,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
 // CountTokens counts tokens for the given request using the AI Studio API.
 func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
-    _, body, err := e.translateRequest(req, opts, false)
+    upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
+    if upstreamModel == "" {
+        upstreamModel = strings.TrimSpace(req.Model)
+    }
+    _, body, err := e.translateRequest(req, opts, false, upstreamModel)
     if err != nil {
         return cliproxyexecutor.Response{}, err
     }
@@ -263,7 +280,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
     body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
     body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings")
-    endpoint := e.buildEndpoint(req.Model, "countTokens", "")
+    endpoint := e.buildEndpoint(upstreamModel, "countTokens", "")
     wsReq := &wsrelay.HTTPRequest{
         Method: http.MethodPost,
         URL:    endpoint,
@@ -318,18 +335,23 @@ type translatedPayload struct {
     toFormat sdktranslator.Format
 }
 
-func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
+func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool, upstreamModel string) ([]byte, translatedPayload, error) {
+    model := strings.TrimSpace(upstreamModel)
+    if model == "" {
+        model = strings.TrimSpace(req.Model)
+    }
     from := opts.SourceFormat
     to := sdktranslator.FromString("gemini")
-    payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
-    payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
-    payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
-    payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
-    payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true)
-    payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
-    payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
-    payload = fixGeminiImageAspectRatio(req.Model, payload)
-    payload = applyPayloadConfig(e.cfg, req.Model, payload)
+    payload := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
+    payload = ApplyThinkingMetadata(payload, req.Metadata, model)
+    payload = util.ApplyGemini3ThinkingLevelFromMetadata(model, req.Metadata, payload)
+    payload = util.ApplyDefaultThinkingIfNeeded(model, payload)
+    payload = util.ConvertThinkingLevelToBudget(payload, model, true)
+    payload = util.NormalizeGeminiThinkingBudget(model, payload, true)
+    payload = util.StripThinkingConfigIfUnsupported(model, payload)
+    payload = fixGeminiImageAspectRatio(model, payload)
+    payload = applyPayloadConfig(e.cfg, model, payload)
     payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
     payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
     payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
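
With the new upstreamModel parameter, every thinking-config step above (metadata application, Gemini 3 thinking level, default thinking, level-to-budget conversion, budget normalization, unsupported-config stripping) plus the aspect-ratio fix and payload config are keyed off the resolved upstream model instead of the client-facing name. A hedged usage sketch of the new signature; the Request field names match the diff, while the alias, the "original_model" metadata key, the metadata type, and the payload below are invented for illustration:

    // Sketch only: "my-alias" and the "original_model" key are illustrative
    // assumptions, not values defined by the repository.
    req := cliproxyexecutor.Request{
        Model:    "my-alias",
        Metadata: map[string]any{"original_model": "gemini-2.5-pro"},
        Payload:  []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`),
    }
    upstreamModel := strings.TrimSpace(util.ResolveOriginalModel(req.Model, req.Metadata))
    if upstreamModel == "" {
        upstreamModel = strings.TrimSpace(req.Model)
    }
    // Thinking budgets, Gemini 3 thinking levels, and payload overrides are
    // now selected for "gemini-2.5-pro" rather than "my-alias".
    translatedReq, body, err := e.translateRequest(req, opts, false, upstreamModel)
    _, _, _ = translatedReq, body, err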