diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 9a2c2602..5e932fbd 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -99,89 +99,123 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth var lastStatus int var lastBody []byte + // Get max retry count from config, default to 3 if not set + maxRetries := e.cfg.RequestRetry + if maxRetries <= 0 { + maxRetries = 3 + } + for idx, attemptModel := range models { - payload := append([]byte(nil), basePayload...) - if action == "countTokens" { - payload = deleteJSONField(payload, "project") - payload = deleteJSONField(payload, "model") - } else { - payload = setJSONField(payload, "project", projectID) - payload = setJSONField(payload, "model", attemptModel) - } - - tok, errTok := tokenSource.Token() - if errTok != nil { - err = errTok - return resp, err - } - updateGeminiCLITokenMetadata(auth, baseTokenData, tok) - - url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, action) - if opts.Alt != "" && action != "countTokens" { - url = url + fmt.Sprintf("?$alt=%s", opts.Alt) - } - - reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) - if errReq != nil { - err = errReq - return resp, err - } - reqHTTP.Header.Set("Content-Type", "application/json") - reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(reqHTTP) - reqHTTP.Header.Set("Accept", "application/json") - recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: reqHTTP.Header.Clone(), - Body: payload, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpResp, errDo := httpClient.Do(reqHTTP) - if errDo != nil { - recordAPIResponseError(ctx, e.cfg, errDo) - err = errDo - return resp, err - } - - data, errRead := io.ReadAll(httpResp.Body) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("gemini cli executor: close response body error: %v", errClose) - } - recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if errRead != nil { - recordAPIResponseError(ctx, e.cfg, errRead) - err = errRead - return resp, err - } - appendAPIResponseChunk(ctx, e.cfg, data) - if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 { - reporter.publish(ctx, parseGeminiCLIUsage(data)) - var param any - out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} - return resp, nil - } - - lastStatus = httpResp.StatusCode - lastBody = append([]byte(nil), data...) - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - if httpResp.StatusCode == 429 { - if idx+1 < len(models) { - log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) + // Inner retry loop for 429 errors on the same model + for retryCount := 0; retryCount <= maxRetries; retryCount++ { + payload := append([]byte(nil), basePayload...) + if action == "countTokens" { + payload = deleteJSONField(payload, "project") + payload = deleteJSONField(payload, "model") } else { - log.Debug("gemini cli executor: rate limited, no additional fallback model") + payload = setJSONField(payload, "project", projectID) + payload = setJSONField(payload, "model", attemptModel) } - continue - } - err = statusErr{code: httpResp.StatusCode, msg: string(data)} - return resp, err + tok, errTok := tokenSource.Token() + if errTok != nil { + err = errTok + return resp, err + } + updateGeminiCLITokenMetadata(auth, baseTokenData, tok) + + url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, action) + if opts.Alt != "" && action != "countTokens" { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + + reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) + if errReq != nil { + err = errReq + return resp, err + } + reqHTTP.Header.Set("Content-Type", "application/json") + reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) + applyGeminiCLIHeaders(reqHTTP) + reqHTTP.Header.Set("Accept", "application/json") + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: reqHTTP.Header.Clone(), + Body: payload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpResp, errDo := httpClient.Do(reqHTTP) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + err = errDo + return resp, err + } + + data, errRead := io.ReadAll(httpResp.Body) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("gemini cli executor: close response body error: %v", errClose) + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + err = errRead + return resp, err + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 { + reporter.publish(ctx, parseGeminiCLIUsage(data)) + var param any + out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil + } + + lastStatus = httpResp.StatusCode + lastBody = append([]byte(nil), data...) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + + // Handle 429 rate limit errors with retry + if httpResp.StatusCode == 429 { + if retryCount < maxRetries { + // Parse retry delay from Google's response + retryDelay := parseRetryDelay(data) + log.Infof("gemini cli executor: rate limited (429), retrying model %s in %v (attempt %d/%d)", attemptModel, retryDelay, retryCount+1, maxRetries) + + // Wait for the specified delay + select { + case <-time.After(retryDelay): + // Continue to next retry iteration + continue + case <-ctx.Done(): + // Context cancelled, return immediately + err = ctx.Err() + return resp, err + } + } else { + // Exhausted retries for this model, try next model if available + if idx+1 < len(models) { + log.Infof("gemini cli executor: rate limited, exhausted %d retries for model %s, trying fallback model: %s", maxRetries, attemptModel, models[idx+1]) + break // Break inner loop to try next model + } else { + log.Infof("gemini cli executor: rate limited, exhausted %d retries for model %s, no additional fallback model", maxRetries, attemptModel) + // No more models to try, will return error below + } + } + } else { + // Non-429 error, don't retry this model + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return resp, err + } + + // Break inner loop if we hit this point (no retry needed or exhausted retries) + break + } } if len(lastBody) > 0 { @@ -235,77 +269,120 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut var lastStatus int var lastBody []byte + // Get max retry count from config, default to 3 if not set + maxRetries := e.cfg.RequestRetry + if maxRetries <= 0 { + maxRetries = 3 + } + for idx, attemptModel := range models { - payload := append([]byte(nil), basePayload...) - payload = setJSONField(payload, "project", projectID) - payload = setJSONField(payload, "model", attemptModel) + var httpResp *http.Response + var payload []byte + var errDo error - tok, errTok := tokenSource.Token() - if errTok != nil { - err = errTok - return nil, err - } - updateGeminiCLITokenMetadata(auth, baseTokenData, tok) + // Inner retry loop for 429 errors on the same model + for retryCount := 0; retryCount <= maxRetries; retryCount++ { + payload = append([]byte(nil), basePayload...) + payload = setJSONField(payload, "project", projectID) + payload = setJSONField(payload, "model", attemptModel) - url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "streamGenerateContent") - if opts.Alt == "" { - url = url + "?alt=sse" - } else { - url = url + fmt.Sprintf("?$alt=%s", opts.Alt) - } - - reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) - if errReq != nil { - err = errReq - return nil, err - } - reqHTTP.Header.Set("Content-Type", "application/json") - reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(reqHTTP) - reqHTTP.Header.Set("Accept", "text/event-stream") - recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: reqHTTP.Header.Clone(), - Body: payload, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpResp, errDo := httpClient.Do(reqHTTP) - if errDo != nil { - recordAPIResponseError(ctx, e.cfg, errDo) - err = errDo - return nil, err - } - recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - data, errRead := io.ReadAll(httpResp.Body) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("gemini cli executor: close response body error: %v", errClose) - } - if errRead != nil { - recordAPIResponseError(ctx, e.cfg, errRead) - err = errRead + tok, errTok := tokenSource.Token() + if errTok != nil { + err = errTok return nil, err } - appendAPIResponseChunk(ctx, e.cfg, data) - lastStatus = httpResp.StatusCode - lastBody = append([]byte(nil), data...) - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - if httpResp.StatusCode == 429 { - if idx+1 < len(models) { - log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) - } else { - log.Debug("gemini cli executor: rate limited, no additional fallback model") - } - continue + updateGeminiCLITokenMetadata(auth, baseTokenData, tok) + + url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "streamGenerateContent") + if opts.Alt == "" { + url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } - err = statusErr{code: httpResp.StatusCode, msg: string(data)} - return nil, err + + reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) + if errReq != nil { + err = errReq + return nil, err + } + reqHTTP.Header.Set("Content-Type", "application/json") + reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) + applyGeminiCLIHeaders(reqHTTP) + reqHTTP.Header.Set("Accept", "text/event-stream") + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: reqHTTP.Header.Clone(), + Body: payload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpResp, errDo = httpClient.Do(reqHTTP) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + err = errDo + return nil, err + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + data, errRead := io.ReadAll(httpResp.Body) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("gemini cli executor: close response body error: %v", errClose) + } + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + err = errRead + return nil, err + } + appendAPIResponseChunk(ctx, e.cfg, data) + lastStatus = httpResp.StatusCode + lastBody = append([]byte(nil), data...) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + + // Handle 429 rate limit errors with retry + if httpResp.StatusCode == 429 { + if retryCount < maxRetries { + // Parse retry delay from Google's response + retryDelay := parseRetryDelay(data) + log.Infof("gemini cli executor: rate limited (429), retrying stream model %s in %v (attempt %d/%d)", attemptModel, retryDelay, retryCount+1, maxRetries) + + // Wait for the specified delay + select { + case <-time.After(retryDelay): + // Continue to next retry iteration + continue + case <-ctx.Done(): + // Context cancelled, return immediately + err = ctx.Err() + return nil, err + } + } else { + // Exhausted retries for this model, try next model if available + if idx+1 < len(models) { + log.Infof("gemini cli executor: rate limited, exhausted %d retries for stream model %s, trying fallback model: %s", maxRetries, attemptModel, models[idx+1]) + break // Break inner loop to try next model + } else { + log.Infof("gemini cli executor: rate limited, exhausted %d retries for stream model %s, no additional fallback model", maxRetries, attemptModel) + // No more models to try, will return error below + } + } + } else { + // Non-429 error, don't retry this model + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return nil, err + } + + // Break inner loop if we hit this point (no retry needed or exhausted retries) + break + } + + // Success - httpResp.StatusCode is 2xx, break out of retry loop + // and proceed to streaming logic below + break } out := make(chan cliproxyexecutor.StreamChunk) @@ -769,3 +846,48 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte { } return rawJSON } + +// parseRetryDelay extracts the retry delay from a Google API 429 error response. +// The error response contains a RetryInfo.retryDelay field in the format "0.847655010s". +// Returns the duration to wait, or a default duration if parsing fails. +func parseRetryDelay(errorBody []byte) time.Duration { + const defaultDelay = 1 * time.Second + const maxDelay = 60 * time.Second + + // Try to parse the retryDelay from the error response + // Format: error.details[].retryDelay where @type == "type.googleapis.com/google.rpc.RetryInfo" + details := gjson.GetBytes(errorBody, "error.details") + if !details.Exists() || !details.IsArray() { + log.Debugf("parseRetryDelay: no error.details found, using default delay %v", defaultDelay) + return defaultDelay + } + + for _, detail := range details.Array() { + typeVal := detail.Get("@type").String() + if typeVal == "type.googleapis.com/google.rpc.RetryInfo" { + retryDelay := detail.Get("retryDelay").String() + if retryDelay != "" { + // Parse duration string like "0.847655010s" + duration, err := time.ParseDuration(retryDelay) + if err != nil { + log.Debugf("parseRetryDelay: failed to parse duration %q: %v, using default", retryDelay, err) + return defaultDelay + } + // Cap at maxDelay to prevent excessive waits + if duration > maxDelay { + log.Debugf("parseRetryDelay: capping delay from %v to %v", duration, maxDelay) + return maxDelay + } + if duration < 0 { + log.Debugf("parseRetryDelay: negative delay %v, using default", duration) + return defaultDelay + } + log.Debugf("parseRetryDelay: using delay %v from API response", duration) + return duration + } + } + } + + log.Debugf("parseRetryDelay: no RetryInfo found, using default delay %v", defaultDelay) + return defaultDelay +}