diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go
index 2f48871b..f49d0133 100644
--- a/internal/runtime/executor/gemini_cli_executor.go
+++ b/internal/runtime/executor/gemini_cli_executor.go
@@ -99,123 +99,89 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
 	var lastStatus int
 	var lastBody []byte
 
-	// Get max retry count from config, default to 3 if not set
-	maxRetries := e.cfg.RequestRetry
-	if maxRetries <= 0 {
-		maxRetries = 3
-	}
-
 	for idx, attemptModel := range models {
-		// Inner retry loop for 429 errors on the same model
-		for retryCount := 0; retryCount <= maxRetries; retryCount++ {
-			payload := append([]byte(nil), basePayload...)
-			if action == "countTokens" {
-				payload = deleteJSONField(payload, "project")
-				payload = deleteJSONField(payload, "model")
-			} else {
-				payload = setJSONField(payload, "project", projectID)
-				payload = setJSONField(payload, "model", attemptModel)
-			}
-
-			tok, errTok := tokenSource.Token()
-			if errTok != nil {
-				err = errTok
-				return resp, err
-			}
-			updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
-
-			url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, action)
-			if opts.Alt != "" && action != "countTokens" {
-				url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
-			}
-
-			reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
-			if errReq != nil {
-				err = errReq
-				return resp, err
-			}
-			reqHTTP.Header.Set("Content-Type", "application/json")
-			reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
-			applyGeminiCLIHeaders(reqHTTP)
-			reqHTTP.Header.Set("Accept", "application/json")
-			recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
-				URL:       url,
-				Method:    http.MethodPost,
-				Headers:   reqHTTP.Header.Clone(),
-				Body:      payload,
-				Provider:  e.Identifier(),
-				AuthID:    authID,
-				AuthLabel: authLabel,
-				AuthType:  authType,
-				AuthValue: authValue,
-			})
-
-			httpResp, errDo := httpClient.Do(reqHTTP)
-			if errDo != nil {
-				recordAPIResponseError(ctx, e.cfg, errDo)
-				err = errDo
-				return resp, err
-			}
-
-			data, errRead := io.ReadAll(httpResp.Body)
-			if errClose := httpResp.Body.Close(); errClose != nil {
-				log.Errorf("gemini cli executor: close response body error: %v", errClose)
-			}
-			recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
-			if errRead != nil {
-				recordAPIResponseError(ctx, e.cfg, errRead)
-				err = errRead
-				return resp, err
-			}
-			appendAPIResponseChunk(ctx, e.cfg, data)
-			if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
-				reporter.publish(ctx, parseGeminiCLIUsage(data))
-				var param any
-				out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, &param)
-				resp = cliproxyexecutor.Response{Payload: []byte(out)}
-				return resp, nil
-			}
-
-			lastStatus = httpResp.StatusCode
-			lastBody = append([]byte(nil), data...)
-			log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
-
-			// Handle 429 rate limit errors with retry
-			if httpResp.StatusCode == 429 {
-				if retryCount < maxRetries {
-					// Parse retry delay from Google's response
-					retryDelay := parseRetryDelay(data)
-					log.Infof("gemini cli executor: rate limited (429), retrying model %s in %v (attempt %d/%d)", attemptModel, retryDelay, retryCount+1, maxRetries)
-
-					// Wait for the specified delay
-					select {
-					case <-time.After(retryDelay):
-						// Continue to next retry iteration
-						continue
-					case <-ctx.Done():
-						// Context cancelled, return immediately
-						err = ctx.Err()
-						return resp, err
-					}
-				} else {
-					// Exhausted retries for this model, try next model if available
-					if idx+1 < len(models) {
-						log.Infof("gemini cli executor: rate limited, exhausted %d retries for model %s, trying fallback model: %s", maxRetries, attemptModel, models[idx+1])
-						break // Break inner loop to try next model
-					} else {
-						log.Infof("gemini cli executor: rate limited, exhausted %d retries for model %s, no additional fallback model", maxRetries, attemptModel)
-						// No more models to try, will return error below
-					}
-				}
-			} else {
-				// Non-429 error, don't retry this model
-				err = statusErr{code: httpResp.StatusCode, msg: string(data)}
-				return resp, err
-			}
-
-			// Break inner loop if we hit this point (no retry needed or exhausted retries)
-			break
+		payload := append([]byte(nil), basePayload...)
+		if action == "countTokens" {
+			payload = deleteJSONField(payload, "project")
+			payload = deleteJSONField(payload, "model")
+		} else {
+			payload = setJSONField(payload, "project", projectID)
+			payload = setJSONField(payload, "model", attemptModel)
 		}
+
+		tok, errTok := tokenSource.Token()
+		if errTok != nil {
+			err = errTok
+			return resp, err
+		}
+		updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
+
+		url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, action)
+		if opts.Alt != "" && action != "countTokens" {
+			url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
+		}
+
+		reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
+		if errReq != nil {
+			err = errReq
+			return resp, err
+		}
+		reqHTTP.Header.Set("Content-Type", "application/json")
+		reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
+		applyGeminiCLIHeaders(reqHTTP)
+		reqHTTP.Header.Set("Accept", "application/json")
+		recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+			URL:       url,
+			Method:    http.MethodPost,
+			Headers:   reqHTTP.Header.Clone(),
+			Body:      payload,
+			Provider:  e.Identifier(),
+			AuthID:    authID,
+			AuthLabel: authLabel,
+			AuthType:  authType,
+			AuthValue: authValue,
+		})
+
+		httpResp, errDo := httpClient.Do(reqHTTP)
+		if errDo != nil {
+			recordAPIResponseError(ctx, e.cfg, errDo)
+			err = errDo
+			return resp, err
+		}
+
+		data, errRead := io.ReadAll(httpResp.Body)
+		if errClose := httpResp.Body.Close(); errClose != nil {
+			log.Errorf("gemini cli executor: close response body error: %v", errClose)
+		}
+		recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+		if errRead != nil {
+			recordAPIResponseError(ctx, e.cfg, errRead)
+			err = errRead
+			return resp, err
+		}
+		appendAPIResponseChunk(ctx, e.cfg, data)
+		if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
+			reporter.publish(ctx, parseGeminiCLIUsage(data))
+			var param any
+			out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, &param)
+			resp = cliproxyexecutor.Response{Payload: []byte(out)}
+			return resp, nil
+		}
+
+		lastStatus = httpResp.StatusCode
+		lastBody = append([]byte(nil), data...)
+		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
+		if httpResp.StatusCode == 429 {
+			if idx+1 < len(models) {
+				log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
+			} else {
+				log.Debug("gemini cli executor: rate limited, no additional fallback model")
+			}
+			continue
+		}
+
+		err = newGeminiStatusErr(httpResp.StatusCode, data)
+		return resp, err
 	}
 
 	if len(lastBody) > 0 {
@@ -224,7 +190,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
 	if lastStatus == 0 {
 		lastStatus = 429
 	}
-	err = statusErr{code: lastStatus, msg: string(lastBody)}
+	err = newGeminiStatusErr(lastStatus, lastBody)
 	return resp, err
 }
 
@@ -269,133 +235,77 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
 	var lastStatus int
 	var lastBody []byte
 
-	// Get max retry count from config, default to 3 if not set
-	maxRetries := e.cfg.RequestRetry
-	if maxRetries <= 0 {
-		maxRetries = 3
-	}
-
 	for idx, attemptModel := range models {
-		var httpResp *http.Response
-		var payload []byte
-		var errDo error
-		shouldContinueToNextModel := false
+		payload := append([]byte(nil), basePayload...)
+		payload = setJSONField(payload, "project", projectID)
+		payload = setJSONField(payload, "model", attemptModel)
 
-		// Inner retry loop for 429 errors on the same model
-		for retryCount := 0; retryCount <= maxRetries; retryCount++ {
-			payload = append([]byte(nil), basePayload...)
-			payload = setJSONField(payload, "project", projectID)
-			payload = setJSONField(payload, "model", attemptModel)
+		tok, errTok := tokenSource.Token()
+		if errTok != nil {
+			err = errTok
+			return nil, err
+		}
+		updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
 
-			tok, errTok := tokenSource.Token()
-			if errTok != nil {
-				err = errTok
+		url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "streamGenerateContent")
+		if opts.Alt == "" {
+			url = url + "?alt=sse"
+		} else {
+			url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
+		}
+
+		reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
+		if errReq != nil {
+			err = errReq
+			return nil, err
+		}
+		reqHTTP.Header.Set("Content-Type", "application/json")
+		reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
+		applyGeminiCLIHeaders(reqHTTP)
+		reqHTTP.Header.Set("Accept", "text/event-stream")
+		recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+			URL:       url,
+			Method:    http.MethodPost,
+			Headers:   reqHTTP.Header.Clone(),
+			Body:      payload,
+			Provider:  e.Identifier(),
+			AuthID:    authID,
+			AuthLabel: authLabel,
+			AuthType:  authType,
+			AuthValue: authValue,
+		})
+
+		httpResp, errDo := httpClient.Do(reqHTTP)
+		if errDo != nil {
+			recordAPIResponseError(ctx, e.cfg, errDo)
+			err = errDo
+			return nil, err
+		}
+		recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+		if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
+			data, errRead := io.ReadAll(httpResp.Body)
+			if errClose := httpResp.Body.Close(); errClose != nil {
+				log.Errorf("gemini cli executor: close response body error: %v", errClose)
+			}
+			if errRead != nil {
+				recordAPIResponseError(ctx, e.cfg, errRead)
+				err = errRead
 				return nil, err
 			}
-			updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
-
-			url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "streamGenerateContent")
-			if opts.Alt == "" {
-				url = url + "?alt=sse"
-			} else {
-				url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
-			}
-
-			reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
-			if errReq != nil {
-				err = errReq
-				return nil, err
-			}
-			reqHTTP.Header.Set("Content-Type", "application/json")
-			reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
-			applyGeminiCLIHeaders(reqHTTP)
-			reqHTTP.Header.Set("Accept", "text/event-stream")
-			recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
-				URL:       url,
-				Method:    http.MethodPost,
-				Headers:   reqHTTP.Header.Clone(),
-				Body:      payload,
-				Provider:  e.Identifier(),
-				AuthID:    authID,
-				AuthLabel: authLabel,
-				AuthType:  authType,
-				AuthValue: authValue,
-			})
-
-			httpResp, errDo = httpClient.Do(reqHTTP)
-			if errDo != nil {
-				recordAPIResponseError(ctx, e.cfg, errDo)
-				err = errDo
-				return nil, err
-			}
-			recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
-			if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
-				data, errRead := io.ReadAll(httpResp.Body)
-				if errClose := httpResp.Body.Close(); errClose != nil {
-					log.Errorf("gemini cli executor: close response body error: %v", errClose)
-				}
-				if errRead != nil {
-					recordAPIResponseError(ctx, e.cfg, errRead)
-					err = errRead
-					return nil, err
-				}
-				appendAPIResponseChunk(ctx, e.cfg, data)
-				lastStatus = httpResp.StatusCode
-				lastBody = append([]byte(nil), data...)
-				log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
-
-				// Handle 429 rate limit errors with retry
-				if httpResp.StatusCode == 429 {
-					if retryCount < maxRetries {
-						// Parse retry delay from Google's response
-						retryDelay := parseRetryDelay(data)
-						log.Infof("gemini cli executor: rate limited (429), retrying stream model %s in %v (attempt %d/%d)", attemptModel, retryDelay, retryCount+1, maxRetries)
-
-						// Wait for the specified delay
-						select {
-						case <-time.After(retryDelay):
-							// Continue to next retry iteration
-							continue
-						case <-ctx.Done():
-							// Context cancelled, return immediately
-							err = ctx.Err()
-							return nil, err
-						}
-					} else {
-						// Exhausted retries for this model, try next model if available
-						if idx+1 < len(models) {
-							log.Infof("gemini cli executor: rate limited, exhausted %d retries for stream model %s, trying fallback model: %s", maxRetries, attemptModel, models[idx+1])
-							shouldContinueToNextModel = true
-							break // Break inner loop to try next model
-						} else {
-							log.Infof("gemini cli executor: rate limited, exhausted %d retries for stream model %s, no additional fallback model", maxRetries, attemptModel)
-							// No more models to try, will return error below
-						}
-					}
+			appendAPIResponseChunk(ctx, e.cfg, data)
+			lastStatus = httpResp.StatusCode
+			lastBody = append([]byte(nil), data...)
+			log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
+			if httpResp.StatusCode == 429 {
+				if idx+1 < len(models) {
+					log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
 				} else {
-					// Non-429 error, don't retry this model
-					err = statusErr{code: httpResp.StatusCode, msg: string(data)}
-					return nil, err
+					log.Debug("gemini cli executor: rate limited, no additional fallback model")
 				}
-
-				// Break inner loop if we hit this point (no retry needed or exhausted retries)
-				break
+				continue
 			}
-
-			// Success - httpResp.StatusCode is 2xx, break out of retry loop
-			// and proceed to streaming logic below
-			break
-		}
-
-		// If we need to try the next fallback model, skip streaming logic
-		if shouldContinueToNextModel {
-			continue
-		}
-
-		// If we have a failed response (non-2xx), don't attempt streaming
-		// Continue outer loop to try next model or return error
-		if httpResp == nil || httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
-			continue
+			err = newGeminiStatusErr(httpResp.StatusCode, data)
+			return nil, err
 		}
 
 		out := make(chan cliproxyexecutor.StreamChunk)
@@ -467,7 +377,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
 	if lastStatus == 0 {
 		lastStatus = 429
 	}
-	err = statusErr{code: lastStatus, msg: string(lastBody)}
+	err = newGeminiStatusErr(lastStatus, lastBody)
 	return nil, err
 }
 
@@ -575,7 +485,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
 	if lastStatus == 0 {
 		lastStatus = 429
 	}
-	return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
+	return cliproxyexecutor.Response{}, newGeminiStatusErr(lastStatus, lastBody)
 }
 
 func (e *GeminiCLIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
@@ -860,19 +770,25 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
 	return rawJSON
 }
 
+func newGeminiStatusErr(statusCode int, body []byte) statusErr {
+	err := statusErr{code: statusCode, msg: string(body)}
+	if statusCode == http.StatusTooManyRequests {
+		if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil {
+			err.retryAfter = retryAfter
+		}
+	}
+	return err
+}
+
 // parseRetryDelay extracts the retry delay from a Google API 429 error response.
 // The error response contains a RetryInfo.retryDelay field in the format "0.847655010s".
-// Returns the duration to wait, or a default duration if parsing fails.
-func parseRetryDelay(errorBody []byte) time.Duration {
-	const defaultDelay = 1 * time.Second
-	const maxDelay = 60 * time.Second
-
+// Returns the parsed duration or an error if it cannot be determined.
+func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
 	// Try to parse the retryDelay from the error response
 	// Format: error.details[].retryDelay where @type == "type.googleapis.com/google.rpc.RetryInfo"
 	details := gjson.GetBytes(errorBody, "error.details")
 	if !details.Exists() || !details.IsArray() {
-		log.Debugf("parseRetryDelay: no error.details found, using default delay %v", defaultDelay)
-		return defaultDelay
+		return nil, fmt.Errorf("no error.details found")
 	}
 
 	for _, detail := range details.Array() {
@@ -883,24 +799,12 @@ func parseRetryDelay(errorBody []byte) time.Duration {
 				// Parse duration string like "0.847655010s"
 				duration, err := time.ParseDuration(retryDelay)
 				if err != nil {
-					log.Debugf("parseRetryDelay: failed to parse duration %q: %v, using default", retryDelay, err)
-					return defaultDelay
+					return nil, fmt.Errorf("failed to parse duration")
 				}
-				// Cap at maxDelay to prevent excessive waits
-				if duration > maxDelay {
-					log.Debugf("parseRetryDelay: capping delay from %v to %v", duration, maxDelay)
-					return maxDelay
-				}
-				if duration < 0 {
-					log.Debugf("parseRetryDelay: negative delay %v, using default", duration)
-					return defaultDelay
-				}
-				log.Debugf("parseRetryDelay: using delay %v from API response", duration)
-				return duration
+				return &duration, nil
 			}
 		}
 	}
 
-	log.Debugf("parseRetryDelay: no RetryInfo found, using default delay %v", defaultDelay)
-	return defaultDelay
+	return nil, fmt.Errorf("no RetryInfo found")
 }
diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go
index 42af9240..55ec6dc9 100644
--- a/internal/runtime/executor/openai_compat_executor.go
+++ b/internal/runtime/executor/openai_compat_executor.go
@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"time"
 
 	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
 	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -340,8 +341,9 @@ func (e *OpenAICompatExecutor) overrideModel(payload []byte, model string) []byt
 }
 
 type statusErr struct {
-	code int
-	msg  string
+	code       int
+	msg        string
+	retryAfter *time.Duration
 }
 
 func (e statusErr) Error() string {
@@ -350,4 +352,5 @@ func (e statusErr) Error() string {
 	}
 	return fmt.Sprintf("status %d", e.code)
 }
-func (e statusErr) StatusCode() int { return e.code }
+func (e statusErr) StatusCode() int { return e.code }
+func (e statusErr) RetryAfter() *time.Duration { return e.retryAfter }
diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go
index f195c0e9..a056ae0b 100644
--- a/sdk/cliproxy/auth/manager.go
+++ b/sdk/cliproxy/auth/manager.go
@@ -62,6 +62,8 @@ type Result struct {
 	Model string
 	// Success marks whether the execution succeeded.
 	Success bool
+	// RetryAfter carries a provider supplied retry hint (e.g. 429 retryDelay).
+	RetryAfter *time.Duration
 	// Error describes the failure when Success is false.
 	Error *Error
 }
@@ -325,6 +327,9 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
 			if errors.As(errExec, &se) && se != nil {
 				result.Error.HTTPStatus = se.StatusCode()
 			}
+			if ra := retryAfterFromError(errExec); ra != nil {
+				result.RetryAfter = ra
+			}
 			m.MarkResult(execCtx, result)
 			lastErr = errExec
 			continue
@@ -370,6 +375,9 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
 			if errors.As(errExec, &se) && se != nil {
 				result.Error.HTTPStatus = se.StatusCode()
 			}
+			if ra := retryAfterFromError(errExec); ra != nil {
+				result.RetryAfter = ra
+			}
 			m.MarkResult(execCtx, result)
 			lastErr = errExec
 			continue
@@ -415,6 +423,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
 				rerr.HTTPStatus = se.StatusCode()
 			}
 			result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: rerr}
+			result.RetryAfter = retryAfterFromError(errStream)
 			m.MarkResult(execCtx, result)
 			lastErr = errStream
 			continue
@@ -556,17 +565,23 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
 			suspendReason = "payment_required"
 			shouldSuspendModel = true
 		case 429:
-			cooldown, nextLevel := nextQuotaCooldown(state.Quota.BackoffLevel)
 			var next time.Time
-			if cooldown > 0 {
-				next = now.Add(cooldown)
+			backoffLevel := state.Quota.BackoffLevel
+			if result.RetryAfter != nil {
+				next = now.Add(*result.RetryAfter)
+			} else {
+				cooldown, nextLevel := nextQuotaCooldown(backoffLevel)
+				if cooldown > 0 {
+					next = now.Add(cooldown)
+				}
+				backoffLevel = nextLevel
 			}
 			state.NextRetryAfter = next
 			state.Quota = QuotaState{
 				Exceeded:      true,
 				Reason:        "quota",
 				NextRecoverAt: next,
-				BackoffLevel:  nextLevel,
+				BackoffLevel:  backoffLevel,
 			}
 			suspendReason = "quota"
 			shouldSuspendModel = true
@@ -582,7 +597,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
 		auth.UpdatedAt = now
 		updateAggregatedAvailability(auth, now)
 	} else {
-		applyAuthFailureState(auth, result.Error, now)
+		applyAuthFailureState(auth, result.Error, result.RetryAfter, now)
 	}
 }
 
@@ -742,6 +757,25 @@ func cloneError(err *Error) *Error {
 	}
 }
 
+func retryAfterFromError(err error) *time.Duration {
+	if err == nil {
+		return nil
+	}
+	type retryAfterProvider interface {
+		RetryAfter() *time.Duration
+	}
+	rap, ok := err.(retryAfterProvider)
+	if !ok || rap == nil {
+		return nil
+	}
+	retryAfter := rap.RetryAfter()
+	if retryAfter == nil {
+		return nil
+	}
+	val := *retryAfter
+	return &val
+}
+
 func statusCodeFromResult(err *Error) int {
 	if err == nil {
 		return 0
@@ -749,7 +783,7 @@ func statusCodeFromResult(err *Error) int {
 	return err.StatusCode()
 }
 
-func applyAuthFailureState(auth *Auth, resultErr *Error, now time.Time) {
+func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
 	if auth == nil {
 		return
 	}
@@ -774,13 +808,17 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, now time.Time) {
 		auth.StatusMessage = "quota exhausted"
 		auth.Quota.Exceeded = true
 		auth.Quota.Reason = "quota"
-		cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel)
 		var next time.Time
-		if cooldown > 0 {
-			next = now.Add(cooldown)
+		if retryAfter != nil {
+			next = now.Add(*retryAfter)
+		} else {
+			cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel)
+			if cooldown > 0 {
+				next = now.Add(cooldown)
+			}
+			auth.Quota.BackoffLevel = nextLevel
 		}
 		auth.Quota.NextRecoverAt = next
-		auth.Quota.BackoffLevel = nextLevel
 		auth.NextRetryAfter = next
 	case 408, 500, 502, 503, 504:
 		auth.StatusMessage = "transient upstream error"