mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-20 21:30:50 +08:00
feat(auth): add support for request_retry and disable_cooling overrides
Implement `request_retry` and `disable_cooling` metadata overrides for authentication management. Update retry and cooling logic accordingly across `Manager`, Antigravity executor, and file synthesizer. Add tests to validate new behaviors.
This commit is contained in:
@@ -148,7 +148,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
attempts := antigravityRetryAttempts(e.cfg)
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
|
|
||||||
attemptLoop:
|
attemptLoop:
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
@@ -289,7 +289,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
attempts := antigravityRetryAttempts(e.cfg)
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
|
|
||||||
attemptLoop:
|
attemptLoop:
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
@@ -677,7 +677,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
attempts := antigravityRetryAttempts(e.cfg)
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
|
|
||||||
attemptLoop:
|
attemptLoop:
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
@@ -1447,11 +1447,16 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
|
|||||||
return defaultAntigravityAgent
|
return defaultAntigravityAgent
|
||||||
}
|
}
|
||||||
|
|
||||||
func antigravityRetryAttempts(cfg *config.Config) int {
|
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
|
||||||
if cfg == nil {
|
retry := 0
|
||||||
return 1
|
if cfg != nil {
|
||||||
|
retry = cfg.RequestRetry
|
||||||
|
}
|
||||||
|
if auth != nil {
|
||||||
|
if override, ok := auth.RequestRetryOverride(); ok {
|
||||||
|
retry = override
|
||||||
|
}
|
||||||
}
|
}
|
||||||
retry := cfg.RequestRetry
|
|
||||||
if retry < 0 {
|
if retry < 0 {
|
||||||
retry = 0
|
retry = 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -167,6 +167,16 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
|||||||
"virtual_parent_id": primary.ID,
|
"virtual_parent_id": primary.ID,
|
||||||
"type": metadata["type"],
|
"type": metadata["type"],
|
||||||
}
|
}
|
||||||
|
if v, ok := metadata["disable_cooling"]; ok {
|
||||||
|
metadataCopy["disable_cooling"] = v
|
||||||
|
} else if v, ok := metadata["disable-cooling"]; ok {
|
||||||
|
metadataCopy["disable_cooling"] = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["request_retry"]; ok {
|
||||||
|
metadataCopy["request_retry"] = v
|
||||||
|
} else if v, ok := metadata["request-retry"]; ok {
|
||||||
|
metadataCopy["request_retry"] = v
|
||||||
|
}
|
||||||
proxy := strings.TrimSpace(primary.ProxyURL)
|
proxy := strings.TrimSpace(primary.ProxyURL)
|
||||||
if proxy != "" {
|
if proxy != "" {
|
||||||
metadataCopy["proxy_url"] = proxy
|
metadataCopy["proxy_url"] = proxy
|
||||||
|
|||||||
@@ -69,10 +69,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
|||||||
|
|
||||||
// Create a valid auth file
|
// Create a valid auth file
|
||||||
authData := map[string]any{
|
authData := map[string]any{
|
||||||
"type": "claude",
|
"type": "claude",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"proxy_url": "http://proxy.local",
|
"proxy_url": "http://proxy.local",
|
||||||
"prefix": "test-prefix",
|
"prefix": "test-prefix",
|
||||||
|
"disable_cooling": true,
|
||||||
|
"request_retry": 2,
|
||||||
}
|
}
|
||||||
data, _ := json.Marshal(authData)
|
data, _ := json.Marshal(authData)
|
||||||
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
||||||
@@ -108,6 +110,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
|||||||
if auths[0].ProxyURL != "http://proxy.local" {
|
if auths[0].ProxyURL != "http://proxy.local" {
|
||||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||||
}
|
}
|
||||||
|
if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v {
|
||||||
|
t.Errorf("expected disable_cooling true, got %v", auths[0].Metadata["disable_cooling"])
|
||||||
|
}
|
||||||
|
if v, ok := auths[0].Metadata["request_retry"].(float64); !ok || int(v) != 2 {
|
||||||
|
t.Errorf("expected request_retry 2, got %v", auths[0].Metadata["request_retry"])
|
||||||
|
}
|
||||||
if auths[0].Status != coreauth.StatusActive {
|
if auths[0].Status != coreauth.StatusActive {
|
||||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||||
}
|
}
|
||||||
@@ -336,9 +344,11 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"project_id": "project-a, project-b, project-c",
|
"project_id": "project-a, project-b, project-c",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"type": "gemini",
|
"type": "gemini",
|
||||||
|
"request_retry": 2,
|
||||||
|
"disable_cooling": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||||
@@ -376,6 +386,12 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
|||||||
if v.ProxyURL != "http://proxy.local" {
|
if v.ProxyURL != "http://proxy.local" {
|
||||||
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
||||||
}
|
}
|
||||||
|
if vv, ok := v.Metadata["disable_cooling"].(bool); !ok || !vv {
|
||||||
|
t.Errorf("expected disable_cooling true, got %v", v.Metadata["disable_cooling"])
|
||||||
|
}
|
||||||
|
if vv, ok := v.Metadata["request_retry"].(int); !ok || vv != 2 {
|
||||||
|
t.Errorf("expected request_retry 2, got %v", v.Metadata["request_retry"])
|
||||||
|
}
|
||||||
if v.Attributes["runtime_only"] != "true" {
|
if v.Attributes["runtime_only"] != "true" {
|
||||||
t.Error("expected runtime_only=true")
|
t.Error("expected runtime_only=true")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,15 @@ func SetQuotaCooldownDisabled(disable bool) {
|
|||||||
quotaCooldownDisabled.Store(disable)
|
quotaCooldownDisabled.Store(disable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func quotaCooldownDisabledForAuth(auth *Auth) bool {
|
||||||
|
if auth != nil {
|
||||||
|
if override, ok := auth.DisableCoolingOverride(); ok {
|
||||||
|
return override
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return quotaCooldownDisabled.Load()
|
||||||
|
}
|
||||||
|
|
||||||
// Result captures execution outcome used to adjust auth state.
|
// Result captures execution outcome used to adjust auth state.
|
||||||
type Result struct {
|
type Result struct {
|
||||||
// AuthID references the auth that produced this result.
|
// AuthID references the auth that produced this result.
|
||||||
@@ -468,20 +477,16 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
|||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
_, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
|
||||||
if attempts < 1 {
|
|
||||||
attempts = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; ; attempt++ {
|
||||||
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
||||||
if errExec == nil {
|
if errExec == nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -503,20 +508,16 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
|||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
_, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
|
||||||
if attempts < 1 {
|
|
||||||
attempts = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; ; attempt++ {
|
||||||
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
||||||
if errExec == nil {
|
if errExec == nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -538,20 +539,16 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
_, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
|
||||||
if attempts < 1 {
|
|
||||||
attempts = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; ; attempt++ {
|
||||||
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||||
if errStream == nil {
|
if errStream == nil {
|
||||||
return chunks, nil
|
return chunks, nil
|
||||||
}
|
}
|
||||||
lastErr = errStream
|
lastErr = errStream
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -1034,11 +1031,15 @@ func (m *Manager) retrySettings() (int, time.Duration) {
|
|||||||
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
|
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) {
|
func (m *Manager) closestCooldownWait(providers []string, model string, attempt int) (time.Duration, bool) {
|
||||||
if m == nil || len(providers) == 0 {
|
if m == nil || len(providers) == 0 {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
defaultRetry := int(m.requestRetry.Load())
|
||||||
|
if defaultRetry < 0 {
|
||||||
|
defaultRetry = 0
|
||||||
|
}
|
||||||
providerSet := make(map[string]struct{}, len(providers))
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
for i := range providers {
|
for i := range providers {
|
||||||
key := strings.TrimSpace(strings.ToLower(providers[i]))
|
key := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||||
@@ -1061,6 +1062,16 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du
|
|||||||
if _, ok := providerSet[providerKey]; !ok {
|
if _, ok := providerSet[providerKey]; !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
effectiveRetry := defaultRetry
|
||||||
|
if override, ok := auth.RequestRetryOverride(); ok {
|
||||||
|
effectiveRetry = override
|
||||||
|
}
|
||||||
|
if effectiveRetry < 0 {
|
||||||
|
effectiveRetry = 0
|
||||||
|
}
|
||||||
|
if attempt >= effectiveRetry {
|
||||||
|
continue
|
||||||
|
}
|
||||||
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
||||||
if !blocked || next.IsZero() || reason == blockReasonDisabled {
|
if !blocked || next.IsZero() || reason == blockReasonDisabled {
|
||||||
continue
|
continue
|
||||||
@@ -1077,8 +1088,8 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du
|
|||||||
return minWait, found
|
return minWait, found
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
|
func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
|
||||||
if err == nil || attempt >= maxAttempts-1 {
|
if err == nil {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
if maxWait <= 0 {
|
if maxWait <= 0 {
|
||||||
@@ -1087,7 +1098,7 @@ func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, pro
|
|||||||
if status := statusCodeFromError(err); status == http.StatusOK {
|
if status := statusCodeFromError(err); status == http.StatusOK {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
wait, found := m.closestCooldownWait(providers, model)
|
wait, found := m.closestCooldownWait(providers, model, attempt)
|
||||||
if !found || wait > maxWait {
|
if !found || wait > maxWait {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
@@ -1176,7 +1187,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
|||||||
if result.RetryAfter != nil {
|
if result.RetryAfter != nil {
|
||||||
next = now.Add(*result.RetryAfter)
|
next = now.Add(*result.RetryAfter)
|
||||||
} else {
|
} else {
|
||||||
cooldown, nextLevel := nextQuotaCooldown(backoffLevel)
|
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||||
if cooldown > 0 {
|
if cooldown > 0 {
|
||||||
next = now.Add(cooldown)
|
next = now.Add(cooldown)
|
||||||
}
|
}
|
||||||
@@ -1193,7 +1204,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
|||||||
shouldSuspendModel = true
|
shouldSuspendModel = true
|
||||||
setModelQuota = true
|
setModelQuota = true
|
||||||
case 408, 500, 502, 503, 504:
|
case 408, 500, 502, 503, 504:
|
||||||
if quotaCooldownDisabled.Load() {
|
if quotaCooldownDisabledForAuth(auth) {
|
||||||
state.NextRetryAfter = time.Time{}
|
state.NextRetryAfter = time.Time{}
|
||||||
} else {
|
} else {
|
||||||
next := now.Add(1 * time.Minute)
|
next := now.Add(1 * time.Minute)
|
||||||
@@ -1439,7 +1450,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
|||||||
if retryAfter != nil {
|
if retryAfter != nil {
|
||||||
next = now.Add(*retryAfter)
|
next = now.Add(*retryAfter)
|
||||||
} else {
|
} else {
|
||||||
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel)
|
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||||
if cooldown > 0 {
|
if cooldown > 0 {
|
||||||
next = now.Add(cooldown)
|
next = now.Add(cooldown)
|
||||||
}
|
}
|
||||||
@@ -1449,7 +1460,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
|||||||
auth.NextRetryAfter = next
|
auth.NextRetryAfter = next
|
||||||
case 408, 500, 502, 503, 504:
|
case 408, 500, 502, 503, 504:
|
||||||
auth.StatusMessage = "transient upstream error"
|
auth.StatusMessage = "transient upstream error"
|
||||||
if quotaCooldownDisabled.Load() {
|
if quotaCooldownDisabledForAuth(auth) {
|
||||||
auth.NextRetryAfter = time.Time{}
|
auth.NextRetryAfter = time.Time{}
|
||||||
} else {
|
} else {
|
||||||
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
||||||
@@ -1462,11 +1473,11 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
|||||||
}
|
}
|
||||||
|
|
||||||
// nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors.
|
// nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors.
|
||||||
func nextQuotaCooldown(prevLevel int) (time.Duration, int) {
|
func nextQuotaCooldown(prevLevel int, disableCooling bool) (time.Duration, int) {
|
||||||
if prevLevel < 0 {
|
if prevLevel < 0 {
|
||||||
prevLevel = 0
|
prevLevel = 0
|
||||||
}
|
}
|
||||||
if quotaCooldownDisabled.Load() {
|
if disableCooling {
|
||||||
return 0, prevLevel
|
return 0, prevLevel
|
||||||
}
|
}
|
||||||
cooldown := quotaBackoffBase * time.Duration(1<<prevLevel)
|
cooldown := quotaBackoffBase * time.Duration(1<<prevLevel)
|
||||||
|
|||||||
97
sdk/cliproxy/auth/conductor_overrides_test.go
Normal file
97
sdk/cliproxy/auth/conductor_overrides_test.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil, nil)
|
||||||
|
m.SetRetryConfig(3, 30*time.Second)
|
||||||
|
|
||||||
|
model := "test-model"
|
||||||
|
next := time.Now().Add(5 * time.Second)
|
||||||
|
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "auth-1",
|
||||||
|
Provider: "claude",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"request_retry": float64(0),
|
||||||
|
},
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Unavailable: true,
|
||||||
|
Status: StatusError,
|
||||||
|
NextRetryAfter: next,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
|
||||||
|
t.Fatalf("register auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, maxWait := m.retrySettings()
|
||||||
|
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
|
||||||
|
if shouldRetry {
|
||||||
|
t.Fatalf("expected shouldRetry=false for request_retry=0, got true (wait=%v)", wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.Metadata["request_retry"] = float64(1)
|
||||||
|
if _, errUpdate := m.Update(context.Background(), auth); errUpdate != nil {
|
||||||
|
t.Fatalf("update auth: %v", errUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
wait, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
|
||||||
|
if !shouldRetry {
|
||||||
|
t.Fatalf("expected shouldRetry=true for request_retry=1, got false")
|
||||||
|
}
|
||||||
|
if wait <= 0 {
|
||||||
|
t.Fatalf("expected wait > 0, got %v", wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 1, []string{"claude"}, model, maxWait)
|
||||||
|
if shouldRetry {
|
||||||
|
t.Fatalf("expected shouldRetry=false on attempt=1 for request_retry=1, got true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
|
||||||
|
prev := quotaCooldownDisabled.Load()
|
||||||
|
quotaCooldownDisabled.Store(false)
|
||||||
|
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
|
||||||
|
|
||||||
|
m := NewManager(nil, nil, nil)
|
||||||
|
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "auth-1",
|
||||||
|
Provider: "claude",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"disable_cooling": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
|
||||||
|
t.Fatalf("register auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
model := "test-model"
|
||||||
|
m.MarkResult(context.Background(), Result{
|
||||||
|
AuthID: "auth-1",
|
||||||
|
Provider: "claude",
|
||||||
|
Model: model,
|
||||||
|
Success: false,
|
||||||
|
Error: &Error{HTTPStatus: 500, Message: "boom"},
|
||||||
|
})
|
||||||
|
|
||||||
|
updated, ok := m.GetByID("auth-1")
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth to be present")
|
||||||
|
}
|
||||||
|
state := updated.ModelStates[model]
|
||||||
|
if state == nil {
|
||||||
|
t.Fatalf("expected model state to be present")
|
||||||
|
}
|
||||||
|
if !state.NextRetryAfter.IsZero() {
|
||||||
|
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -194,6 +194,108 @@ func (a *Auth) ProxyInfo() string {
|
|||||||
return "via proxy"
|
return "via proxy"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DisableCoolingOverride returns the auth-file scoped disable_cooling override when present.
|
||||||
|
// The value is read from metadata key "disable_cooling" (or legacy "disable-cooling").
|
||||||
|
func (a *Auth) DisableCoolingOverride() (bool, bool) {
|
||||||
|
if a == nil || a.Metadata == nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
if val, ok := a.Metadata["disable_cooling"]; ok {
|
||||||
|
if parsed, okParse := parseBoolAny(val); okParse {
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val, ok := a.Metadata["disable-cooling"]; ok {
|
||||||
|
if parsed, okParse := parseBoolAny(val); okParse {
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
|
||||||
|
// The value is read from metadata key "request_retry" (or legacy "request-retry").
|
||||||
|
func (a *Auth) RequestRetryOverride() (int, bool) {
|
||||||
|
if a == nil || a.Metadata == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if val, ok := a.Metadata["request_retry"]; ok {
|
||||||
|
if parsed, okParse := parseIntAny(val); okParse {
|
||||||
|
if parsed < 0 {
|
||||||
|
parsed = 0
|
||||||
|
}
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val, ok := a.Metadata["request-retry"]; ok {
|
||||||
|
if parsed, okParse := parseIntAny(val); okParse {
|
||||||
|
if parsed < 0 {
|
||||||
|
parsed = 0
|
||||||
|
}
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBoolAny(val any) (bool, bool) {
|
||||||
|
switch typed := val.(type) {
|
||||||
|
case bool:
|
||||||
|
return typed, true
|
||||||
|
case string:
|
||||||
|
trimmed := strings.TrimSpace(typed)
|
||||||
|
if trimmed == "" {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
parsed, err := strconv.ParseBool(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return parsed, true
|
||||||
|
case float64:
|
||||||
|
return typed != 0, true
|
||||||
|
case json.Number:
|
||||||
|
parsed, err := typed.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return parsed != 0, true
|
||||||
|
default:
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIntAny(val any) (int, bool) {
|
||||||
|
switch typed := val.(type) {
|
||||||
|
case int:
|
||||||
|
return typed, true
|
||||||
|
case int32:
|
||||||
|
return int(typed), true
|
||||||
|
case int64:
|
||||||
|
return int(typed), true
|
||||||
|
case float64:
|
||||||
|
return int(typed), true
|
||||||
|
case json.Number:
|
||||||
|
parsed, err := typed.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(parsed), true
|
||||||
|
case string:
|
||||||
|
trimmed := strings.TrimSpace(typed)
|
||||||
|
if trimmed == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
parsed, err := strconv.Atoi(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return parsed, true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Auth) AccountInfo() (string, string) {
|
func (a *Auth) AccountInfo() (string, string) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
|
|||||||
Reference in New Issue
Block a user