mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Fix critical bug where ExecuteStream would create a streaming channel from a failed (non-2xx) response after exhausting all retries with no fallback models available. When retries were exhausted on the last model, the code would break from the inner loop but fall through to streaming channel creation (line 401), immediately returning at line 461. This made the error handling code at lines 464-471 unreachable, causing clients to receive an empty/closed stream instead of a proper error response. Solution: Check if httpResp is non-2xx before creating the streaming channel. If failed, continue the outer loop to reach error handling. Identified by: codex-bot review Ref: https://github.com/router-for-me/CLIProxyAPI/pull/280#pullrequestreview-3484560423
907 lines
30 KiB
Go
907 lines
30 KiB
Go
package executor
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
)
|
|
|
|
// Cloud Code Assist endpoint coordinates used to build request URLs of the
// form "<endpoint>/<version>:<action>".
const (
	// codeAssistEndpoint is the scheme and host of the Cloud Code Assist API.
	codeAssistEndpoint = "https://cloudcode-pa.googleapis.com"
	// codeAssistVersion is the API version segment placed before the action name.
	codeAssistVersion = "v1internal"
)

// OAuth client identity handed to oauth2.Config when building the token source.
//
// NOTE(review): the client secret is embedded in source. It is presumably the
// public installed-application credential of the Gemini CLI; confirm before
// treating it as confidential or relocating it.
const (
	geminiOauthClientID     = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
	geminiOauthClientSecret = "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"
)
|
|
|
|
// geminiOauthScopes lists the OAuth scopes placed in the oauth2.Config used to
// build the Gemini CLI token source: Google Cloud Platform access plus the
// user's email and profile information.
var geminiOauthScopes = []string{
	"https://www.googleapis.com/auth/cloud-platform",
	"https://www.googleapis.com/auth/userinfo.email",
	"https://www.googleapis.com/auth/userinfo.profile",
}
|
|
|
|
// GeminiCLIExecutor talks to the Cloud Code Assist endpoint using OAuth credentials from auth metadata.
type GeminiCLIExecutor struct {
	// cfg is the application configuration. The methods read cfg.RequestRetry
	// from it and pass it to the HTTP client constructor, the payload-config
	// helper and the request/response recording helpers.
	cfg *config.Config
}
|
|
|
|
func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor {
|
|
return &GeminiCLIExecutor{cfg: cfg}
|
|
}
|
|
|
|
// Identifier returns the provider identifier "gemini-cli". Within this file it
// is used as the Provider value of recorded upstream requests and is passed to
// the usage reporter.
func (e *GeminiCLIExecutor) Identifier() string {
	const providerID = "gemini-cli"
	return providerID
}
|
|
|
|
// PrepareRequest is a no-op for this executor: both arguments are ignored and
// nil is always returned.
func (e *GeminiCLIExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
	return nil
}
|
|
|
|
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
|
if err != nil {
|
|
return resp, err
|
|
}
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini-cli")
|
|
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
|
|
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
if hasOverride && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride)
|
|
}
|
|
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
|
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
|
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
|
|
|
action := "generateContent"
|
|
if req.Metadata != nil {
|
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
|
action = "countTokens"
|
|
}
|
|
}
|
|
|
|
projectID := resolveGeminiProjectID(auth)
|
|
models := cliPreviewFallbackOrder(req.Model)
|
|
if len(models) == 0 || models[0] != req.Model {
|
|
models = append([]string{req.Model}, models...)
|
|
}
|
|
|
|
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
|
|
var lastStatus int
|
|
var lastBody []byte
|
|
|
|
// Get max retry count from config, default to 3 if not set
|
|
maxRetries := e.cfg.RequestRetry
|
|
if maxRetries <= 0 {
|
|
maxRetries = 3
|
|
}
|
|
|
|
for idx, attemptModel := range models {
|
|
// Inner retry loop for 429 errors on the same model
|
|
for retryCount := 0; retryCount <= maxRetries; retryCount++ {
|
|
payload := append([]byte(nil), basePayload...)
|
|
if action == "countTokens" {
|
|
payload = deleteJSONField(payload, "project")
|
|
payload = deleteJSONField(payload, "model")
|
|
} else {
|
|
payload = setJSONField(payload, "project", projectID)
|
|
payload = setJSONField(payload, "model", attemptModel)
|
|
}
|
|
|
|
tok, errTok := tokenSource.Token()
|
|
if errTok != nil {
|
|
err = errTok
|
|
return resp, err
|
|
}
|
|
updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
|
|
|
|
url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, action)
|
|
if opts.Alt != "" && action != "countTokens" {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
|
|
reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
|
if errReq != nil {
|
|
err = errReq
|
|
return resp, err
|
|
}
|
|
reqHTTP.Header.Set("Content-Type", "application/json")
|
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
|
applyGeminiCLIHeaders(reqHTTP)
|
|
reqHTTP.Header.Set("Accept", "application/json")
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: reqHTTP.Header.Clone(),
|
|
Body: payload,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpResp, errDo := httpClient.Do(reqHTTP)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
err = errDo
|
|
return resp, err
|
|
}
|
|
|
|
data, errRead := io.ReadAll(httpResp.Body)
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
|
}
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
err = errRead
|
|
return resp, err
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
|
|
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
|
var param any
|
|
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, ¶m)
|
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
|
return resp, nil
|
|
}
|
|
|
|
lastStatus = httpResp.StatusCode
|
|
lastBody = append([]byte(nil), data...)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
|
|
// Handle 429 rate limit errors with retry
|
|
if httpResp.StatusCode == 429 {
|
|
if retryCount < maxRetries {
|
|
// Parse retry delay from Google's response
|
|
retryDelay := parseRetryDelay(data)
|
|
log.Infof("gemini cli executor: rate limited (429), retrying model %s in %v (attempt %d/%d)", attemptModel, retryDelay, retryCount+1, maxRetries)
|
|
|
|
// Wait for the specified delay
|
|
select {
|
|
case <-time.After(retryDelay):
|
|
// Continue to next retry iteration
|
|
continue
|
|
case <-ctx.Done():
|
|
// Context cancelled, return immediately
|
|
err = ctx.Err()
|
|
return resp, err
|
|
}
|
|
} else {
|
|
// Exhausted retries for this model, try next model if available
|
|
if idx+1 < len(models) {
|
|
log.Infof("gemini cli executor: rate limited, exhausted %d retries for model %s, trying fallback model: %s", maxRetries, attemptModel, models[idx+1])
|
|
break // Break inner loop to try next model
|
|
} else {
|
|
log.Infof("gemini cli executor: rate limited, exhausted %d retries for model %s, no additional fallback model", maxRetries, attemptModel)
|
|
// No more models to try, will return error below
|
|
}
|
|
}
|
|
} else {
|
|
// Non-429 error, don't retry this model
|
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
return resp, err
|
|
}
|
|
|
|
// Break inner loop if we hit this point (no retry needed or exhausted retries)
|
|
break
|
|
}
|
|
}
|
|
|
|
if len(lastBody) > 0 {
|
|
appendAPIResponseChunk(ctx, e.cfg, lastBody)
|
|
}
|
|
if lastStatus == 0 {
|
|
lastStatus = 429
|
|
}
|
|
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
|
return resp, err
|
|
}
|
|
|
|
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini-cli")
|
|
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
|
|
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
|
if hasOverride && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride)
|
|
}
|
|
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
|
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
|
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
|
|
|
projectID := resolveGeminiProjectID(auth)
|
|
|
|
models := cliPreviewFallbackOrder(req.Model)
|
|
if len(models) == 0 || models[0] != req.Model {
|
|
models = append([]string{req.Model}, models...)
|
|
}
|
|
|
|
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
|
|
var lastStatus int
|
|
var lastBody []byte
|
|
|
|
// Get max retry count from config, default to 3 if not set
|
|
maxRetries := e.cfg.RequestRetry
|
|
if maxRetries <= 0 {
|
|
maxRetries = 3
|
|
}
|
|
|
|
for idx, attemptModel := range models {
|
|
var httpResp *http.Response
|
|
var payload []byte
|
|
var errDo error
|
|
shouldContinueToNextModel := false
|
|
|
|
// Inner retry loop for 429 errors on the same model
|
|
for retryCount := 0; retryCount <= maxRetries; retryCount++ {
|
|
payload = append([]byte(nil), basePayload...)
|
|
payload = setJSONField(payload, "project", projectID)
|
|
payload = setJSONField(payload, "model", attemptModel)
|
|
|
|
tok, errTok := tokenSource.Token()
|
|
if errTok != nil {
|
|
err = errTok
|
|
return nil, err
|
|
}
|
|
updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
|
|
|
|
url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "streamGenerateContent")
|
|
if opts.Alt == "" {
|
|
url = url + "?alt=sse"
|
|
} else {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
|
|
reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
|
if errReq != nil {
|
|
err = errReq
|
|
return nil, err
|
|
}
|
|
reqHTTP.Header.Set("Content-Type", "application/json")
|
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
|
applyGeminiCLIHeaders(reqHTTP)
|
|
reqHTTP.Header.Set("Accept", "text/event-stream")
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: reqHTTP.Header.Clone(),
|
|
Body: payload,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpResp, errDo = httpClient.Do(reqHTTP)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
err = errDo
|
|
return nil, err
|
|
}
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
data, errRead := io.ReadAll(httpResp.Body)
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
|
}
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
err = errRead
|
|
return nil, err
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
lastStatus = httpResp.StatusCode
|
|
lastBody = append([]byte(nil), data...)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
|
|
// Handle 429 rate limit errors with retry
|
|
if httpResp.StatusCode == 429 {
|
|
if retryCount < maxRetries {
|
|
// Parse retry delay from Google's response
|
|
retryDelay := parseRetryDelay(data)
|
|
log.Infof("gemini cli executor: rate limited (429), retrying stream model %s in %v (attempt %d/%d)", attemptModel, retryDelay, retryCount+1, maxRetries)
|
|
|
|
// Wait for the specified delay
|
|
select {
|
|
case <-time.After(retryDelay):
|
|
// Continue to next retry iteration
|
|
continue
|
|
case <-ctx.Done():
|
|
// Context cancelled, return immediately
|
|
err = ctx.Err()
|
|
return nil, err
|
|
}
|
|
} else {
|
|
// Exhausted retries for this model, try next model if available
|
|
if idx+1 < len(models) {
|
|
log.Infof("gemini cli executor: rate limited, exhausted %d retries for stream model %s, trying fallback model: %s", maxRetries, attemptModel, models[idx+1])
|
|
shouldContinueToNextModel = true
|
|
break // Break inner loop to try next model
|
|
} else {
|
|
log.Infof("gemini cli executor: rate limited, exhausted %d retries for stream model %s, no additional fallback model", maxRetries, attemptModel)
|
|
// No more models to try, will return error below
|
|
}
|
|
}
|
|
} else {
|
|
// Non-429 error, don't retry this model
|
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
return nil, err
|
|
}
|
|
|
|
// Break inner loop if we hit this point (no retry needed or exhausted retries)
|
|
break
|
|
}
|
|
|
|
// Success - httpResp.StatusCode is 2xx, break out of retry loop
|
|
// and proceed to streaming logic below
|
|
break
|
|
}
|
|
|
|
// If we need to try the next fallback model, skip streaming logic
|
|
if shouldContinueToNextModel {
|
|
continue
|
|
}
|
|
|
|
// If we have a failed response (non-2xx), don't attempt streaming
|
|
// Continue outer loop to try next model or return error
|
|
if httpResp == nil || httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
continue
|
|
}
|
|
|
|
out := make(chan cliproxyexecutor.StreamChunk)
|
|
stream = out
|
|
go func(resp *http.Response, reqBody []byte, attempt string) {
|
|
defer close(out)
|
|
defer func() {
|
|
if errClose := resp.Body.Close(); errClose != nil {
|
|
log.Errorf("gemini cli executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
if opts.Alt == "" {
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(nil, 20_971_520)
|
|
var param any
|
|
for scanner.Scan() {
|
|
line := scanner.Bytes()
|
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
|
if detail, ok := parseGeminiCLIStreamUsage(line); ok {
|
|
reporter.publish(ctx, detail)
|
|
}
|
|
if bytes.HasPrefix(line, dataTag) {
|
|
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
|
for i := range segments {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
|
}
|
|
}
|
|
}
|
|
|
|
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
|
for i := range segments {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
|
}
|
|
if errScan := scanner.Err(); errScan != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
|
reporter.publishFailure(ctx)
|
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
|
}
|
|
return
|
|
}
|
|
|
|
data, errRead := io.ReadAll(resp.Body)
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
reporter.publishFailure(ctx)
|
|
out <- cliproxyexecutor.StreamChunk{Err: errRead}
|
|
return
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
|
var param any
|
|
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
|
for i := range segments {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
|
}
|
|
|
|
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
|
for i := range segments {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
|
}
|
|
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
|
|
|
return stream, nil
|
|
}
|
|
|
|
if len(lastBody) > 0 {
|
|
appendAPIResponseChunk(ctx, e.cfg, lastBody)
|
|
}
|
|
if lastStatus == 0 {
|
|
lastStatus = 429
|
|
}
|
|
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
|
return nil, err
|
|
}
|
|
|
|
func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
|
if err != nil {
|
|
return cliproxyexecutor.Response{}, err
|
|
}
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini-cli")
|
|
|
|
models := cliPreviewFallbackOrder(req.Model)
|
|
if len(models) == 0 || models[0] != req.Model {
|
|
models = append([]string{req.Model}, models...)
|
|
}
|
|
|
|
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
|
|
var lastStatus int
|
|
var lastBody []byte
|
|
|
|
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
|
|
for _, attemptModel := range models {
|
|
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
|
if hasOverride && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
payload = util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
|
}
|
|
payload = deleteJSONField(payload, "project")
|
|
payload = deleteJSONField(payload, "model")
|
|
payload = deleteJSONField(payload, "request.safetySettings")
|
|
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
|
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload)
|
|
|
|
tok, errTok := tokenSource.Token()
|
|
if errTok != nil {
|
|
return cliproxyexecutor.Response{}, errTok
|
|
}
|
|
updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
|
|
|
|
url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "countTokens")
|
|
if opts.Alt != "" {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
|
|
reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
|
if errReq != nil {
|
|
return cliproxyexecutor.Response{}, errReq
|
|
}
|
|
reqHTTP.Header.Set("Content-Type", "application/json")
|
|
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
|
applyGeminiCLIHeaders(reqHTTP)
|
|
reqHTTP.Header.Set("Accept", "application/json")
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: reqHTTP.Header.Clone(),
|
|
Body: payload,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
resp, errDo := httpClient.Do(reqHTTP)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return cliproxyexecutor.Response{}, errDo
|
|
}
|
|
data, errRead := io.ReadAll(resp.Body)
|
|
_ = resp.Body.Close()
|
|
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
return cliproxyexecutor.Response{}, errRead
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
|
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
|
}
|
|
lastStatus = resp.StatusCode
|
|
lastBody = append([]byte(nil), data...)
|
|
if resp.StatusCode == 429 {
|
|
log.Debugf("gemini cli executor: rate limited, retrying with next model")
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
|
|
if lastStatus == 0 {
|
|
lastStatus = 429
|
|
}
|
|
return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
|
|
}
|
|
|
|
func (e *GeminiCLIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
|
log.Debugf("gemini cli executor: refresh called")
|
|
_ = ctx
|
|
return auth, nil
|
|
}
|
|
|
|
func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth) (oauth2.TokenSource, map[string]any, error) {
|
|
metadata := geminiOAuthMetadata(auth)
|
|
if auth == nil || metadata == nil {
|
|
return nil, nil, fmt.Errorf("gemini-cli auth metadata missing")
|
|
}
|
|
|
|
var base map[string]any
|
|
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
|
|
base = cloneMap(tokenRaw)
|
|
} else {
|
|
base = make(map[string]any)
|
|
}
|
|
|
|
var token oauth2.Token
|
|
if len(base) > 0 {
|
|
if raw, err := json.Marshal(base); err == nil {
|
|
_ = json.Unmarshal(raw, &token)
|
|
}
|
|
}
|
|
|
|
if token.AccessToken == "" {
|
|
token.AccessToken = stringValue(metadata, "access_token")
|
|
}
|
|
if token.RefreshToken == "" {
|
|
token.RefreshToken = stringValue(metadata, "refresh_token")
|
|
}
|
|
if token.TokenType == "" {
|
|
token.TokenType = stringValue(metadata, "token_type")
|
|
}
|
|
if token.Expiry.IsZero() {
|
|
if expiry := stringValue(metadata, "expiry"); expiry != "" {
|
|
if ts, err := time.Parse(time.RFC3339, expiry); err == nil {
|
|
token.Expiry = ts
|
|
}
|
|
}
|
|
}
|
|
|
|
conf := &oauth2.Config{
|
|
ClientID: geminiOauthClientID,
|
|
ClientSecret: geminiOauthClientSecret,
|
|
Scopes: geminiOauthScopes,
|
|
Endpoint: google.Endpoint,
|
|
}
|
|
|
|
ctxToken := ctx
|
|
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
|
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
|
|
}
|
|
|
|
src := conf.TokenSource(ctxToken, &token)
|
|
currentToken, err := src.Token()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
updateGeminiCLITokenMetadata(auth, base, currentToken)
|
|
return oauth2.ReuseTokenSource(currentToken, src), base, nil
|
|
}
|
|
|
|
func updateGeminiCLITokenMetadata(auth *cliproxyauth.Auth, base map[string]any, tok *oauth2.Token) {
|
|
if auth == nil || tok == nil {
|
|
return
|
|
}
|
|
merged := buildGeminiTokenMap(base, tok)
|
|
fields := buildGeminiTokenFields(tok, merged)
|
|
shared := geminicli.ResolveSharedCredential(auth.Runtime)
|
|
if shared != nil {
|
|
snapshot := shared.MergeMetadata(fields)
|
|
if !geminicli.IsVirtual(auth.Runtime) {
|
|
auth.Metadata = snapshot
|
|
}
|
|
return
|
|
}
|
|
if auth.Metadata == nil {
|
|
auth.Metadata = make(map[string]any)
|
|
}
|
|
for k, v := range fields {
|
|
auth.Metadata[k] = v
|
|
}
|
|
}
|
|
|
|
func buildGeminiTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
|
|
merged := cloneMap(base)
|
|
if merged == nil {
|
|
merged = make(map[string]any)
|
|
}
|
|
if raw, err := json.Marshal(tok); err == nil {
|
|
var tokenMap map[string]any
|
|
if err = json.Unmarshal(raw, &tokenMap); err == nil {
|
|
for k, v := range tokenMap {
|
|
merged[k] = v
|
|
}
|
|
}
|
|
}
|
|
return merged
|
|
}
|
|
|
|
func buildGeminiTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
|
|
fields := make(map[string]any, 5)
|
|
if tok.AccessToken != "" {
|
|
fields["access_token"] = tok.AccessToken
|
|
}
|
|
if tok.TokenType != "" {
|
|
fields["token_type"] = tok.TokenType
|
|
}
|
|
if tok.RefreshToken != "" {
|
|
fields["refresh_token"] = tok.RefreshToken
|
|
}
|
|
if !tok.Expiry.IsZero() {
|
|
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
|
|
}
|
|
if len(merged) > 0 {
|
|
fields["token"] = cloneMap(merged)
|
|
}
|
|
return fields
|
|
}
|
|
|
|
func resolveGeminiProjectID(auth *cliproxyauth.Auth) string {
|
|
if auth == nil {
|
|
return ""
|
|
}
|
|
if runtime := auth.Runtime; runtime != nil {
|
|
if virtual, ok := runtime.(*geminicli.VirtualCredential); ok && virtual != nil {
|
|
return strings.TrimSpace(virtual.ProjectID)
|
|
}
|
|
}
|
|
return strings.TrimSpace(stringValue(auth.Metadata, "project_id"))
|
|
}
|
|
|
|
func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any {
|
|
if auth == nil {
|
|
return nil
|
|
}
|
|
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
|
if snapshot := shared.MetadataSnapshot(); len(snapshot) > 0 {
|
|
return snapshot
|
|
}
|
|
}
|
|
return auth.Metadata
|
|
}
|
|
|
|
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
|
return newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
|
}
|
|
|
|
// cloneMap returns a shallow copy of in: a new map with the same keys, whose
// values are the same values (nested maps and slices are shared, not copied).
// A nil input yields nil; an empty non-nil input yields an empty non-nil map.
func cloneMap(in map[string]any) map[string]any {
	var dup map[string]any
	if in != nil {
		dup = make(map[string]any, len(in))
		for key := range in {
			dup[key] = in[key]
		}
	}
	return dup
}
|
|
|
|
// stringValue looks up key in m and returns it as a string. A string value is
// returned directly and a fmt.Stringer value through its String method. A nil
// map, a missing key, or a value of any other type yields "".
func stringValue(m map[string]any, key string) string {
	// Indexing a nil map is safe in Go and reports found == false.
	raw, found := m[key]
	if !found {
		return ""
	}
	if text, ok := raw.(string); ok {
		return text
	}
	if stringer, ok := raw.(fmt.Stringer); ok {
		return stringer.String()
	}
	return ""
}
|
|
|
|
// applyGeminiCLIHeaders sets required headers for the Gemini CLI upstream.
|
|
func applyGeminiCLIHeaders(r *http.Request) {
|
|
var ginHeaders http.Header
|
|
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
|
ginHeaders = ginCtx.Request.Header
|
|
}
|
|
|
|
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "google-api-nodejs-client/9.15.1")
|
|
misc.EnsureHeader(r.Header, ginHeaders, "X-Goog-Api-Client", "gl-node/22.17.0")
|
|
misc.EnsureHeader(r.Header, ginHeaders, "Client-Metadata", geminiCLIClientMetadata())
|
|
}
|
|
|
|
// geminiCLIClientMetadata returns the comma-separated key=value string used as
// the default Client-Metadata header value.
func geminiCLIClientMetadata() string {
	// Keep parity with CLI client defaults.
	pairs := []string{
		"ideType=IDE_UNSPECIFIED",
		"platform=PLATFORM_UNSPECIFIED",
		"pluginType=GEMINI",
	}
	return strings.Join(pairs, ",")
}
|
|
|
|
// cliPreviewFallbackOrder returns the preview models to try after the given
// base model, in order. Known base models with no active fallbacks return an
// empty, non-nil slice; unknown models return nil.
func cliPreviewFallbackOrder(model string) []string {
	var candidates []string
	switch model {
	case "gemini-2.5-pro":
		candidates = []string{
			// "gemini-2.5-pro-preview-05-06",
			"gemini-2.5-pro-preview-06-05",
		}
	case "gemini-2.5-flash":
		candidates = []string{
			// "gemini-2.5-flash-preview-04-17",
			// "gemini-2.5-flash-preview-05-20",
		}
	case "gemini-2.5-flash-lite":
		candidates = []string{
			// "gemini-2.5-flash-lite-preview-06-17",
		}
	}
	return candidates
}
|
|
|
|
// setJSONField sets a top-level JSON field on a byte slice payload via sjson.
|
|
func setJSONField(body []byte, key, value string) []byte {
|
|
if key == "" {
|
|
return body
|
|
}
|
|
updated, err := sjson.SetBytes(body, key, value)
|
|
if err != nil {
|
|
return body
|
|
}
|
|
return updated
|
|
}
|
|
|
|
// deleteJSONField removes a top-level key if present (best-effort) via sjson.
|
|
func deleteJSONField(body []byte, key string) []byte {
|
|
if key == "" || len(body) == 0 {
|
|
return body
|
|
}
|
|
updated, err := sjson.DeleteBytes(body, key)
|
|
if err != nil {
|
|
return body
|
|
}
|
|
return updated
|
|
}
|
|
|
|
func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
|
if modelName == "gemini-2.5-flash-image-preview" {
|
|
aspectRatioResult := gjson.GetBytes(rawJSON, "request.generationConfig.imageConfig.aspectRatio")
|
|
if aspectRatioResult.Exists() {
|
|
contents := gjson.GetBytes(rawJSON, "request.contents")
|
|
contentArray := contents.Array()
|
|
if len(contentArray) > 0 {
|
|
hasInlineData := false
|
|
loopContent:
|
|
for i := 0; i < len(contentArray); i++ {
|
|
parts := contentArray[i].Get("parts").Array()
|
|
for j := 0; j < len(parts); j++ {
|
|
if parts[j].Get("inlineData").Exists() {
|
|
hasInlineData = true
|
|
break loopContent
|
|
}
|
|
}
|
|
}
|
|
|
|
if !hasInlineData {
|
|
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
|
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
|
|
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
|
newPartsJson := `[]`
|
|
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
|
|
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
|
|
|
|
parts := contentArray[0].Get("parts").Array()
|
|
for j := 0; j < len(parts); j++ {
|
|
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
|
|
}
|
|
|
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
|
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
|
}
|
|
}
|
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "request.generationConfig.imageConfig")
|
|
}
|
|
}
|
|
return rawJSON
|
|
}
|
|
|
|
// parseRetryDelay extracts the retry delay from a Google API 429 error response.
|
|
// The error response contains a RetryInfo.retryDelay field in the format "0.847655010s".
|
|
// Returns the duration to wait, or a default duration if parsing fails.
|
|
func parseRetryDelay(errorBody []byte) time.Duration {
|
|
const defaultDelay = 1 * time.Second
|
|
const maxDelay = 60 * time.Second
|
|
|
|
// Try to parse the retryDelay from the error response
|
|
// Format: error.details[].retryDelay where @type == "type.googleapis.com/google.rpc.RetryInfo"
|
|
details := gjson.GetBytes(errorBody, "error.details")
|
|
if !details.Exists() || !details.IsArray() {
|
|
log.Debugf("parseRetryDelay: no error.details found, using default delay %v", defaultDelay)
|
|
return defaultDelay
|
|
}
|
|
|
|
for _, detail := range details.Array() {
|
|
typeVal := detail.Get("@type").String()
|
|
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
|
retryDelay := detail.Get("retryDelay").String()
|
|
if retryDelay != "" {
|
|
// Parse duration string like "0.847655010s"
|
|
duration, err := time.ParseDuration(retryDelay)
|
|
if err != nil {
|
|
log.Debugf("parseRetryDelay: failed to parse duration %q: %v, using default", retryDelay, err)
|
|
return defaultDelay
|
|
}
|
|
// Cap at maxDelay to prevent excessive waits
|
|
if duration > maxDelay {
|
|
log.Debugf("parseRetryDelay: capping delay from %v to %v", duration, maxDelay)
|
|
return maxDelay
|
|
}
|
|
if duration < 0 {
|
|
log.Debugf("parseRetryDelay: negative delay %v, using default", duration)
|
|
return defaultDelay
|
|
}
|
|
log.Debugf("parseRetryDelay: using delay %v from API response", duration)
|
|
return duration
|
|
}
|
|
}
|
|
}
|
|
|
|
log.Debugf("parseRetryDelay: no RetryInfo found, using default delay %v", defaultDelay)
|
|
return defaultDelay
|
|
}
|