diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index 45602773..16881df4 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -485,85 +485,90 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { if result.AuthID == "" { return } - // Update in-memory auth status based on result. + shouldResumeModel := false shouldSuspendModel := false suspendReason := "" + clearModelQuota := false + setModelQuota := false + m.mu.Lock() if auth, ok := m.auths[result.AuthID]; ok && auth != nil { now := time.Now() + if result.Success { - // Clear transient error/quota flags on success. - auth.Unavailable = false - auth.Status = StatusActive - auth.StatusMessage = "" - auth.Quota.Exceeded = false - auth.Quota.Reason = "" - auth.Quota.NextRecoverAt = time.Time{} - auth.LastError = nil - auth.UpdatedAt = now if result.Model != "" { - registry.GetGlobalRegistry().ClearModelQuotaExceeded(auth.ID, result.Model) + state := ensureModelState(auth, result.Model) + resetModelState(state, now) + updateAggregatedAvailability(auth, now) + if !hasModelError(auth, now) { + auth.LastError = nil + auth.StatusMessage = "" + auth.Status = StatusActive + } + auth.UpdatedAt = now shouldResumeModel = true + clearModelQuota = true + } else { + clearAuthStateOnSuccess(auth, now) } } else { - // Default transient error state. - auth.Unavailable = true - auth.Status = StatusError - auth.UpdatedAt = now - if result.Error != nil { - auth.LastError = &Error{Code: result.Error.Code, Message: result.Error.Message, Retryable: result.Error.Retryable} - } - // If the error carries a status code, adjust backoff/quota accordingly. - // 401 -> auth issue; 402 -> billing; 403 -> forbidden; 429 -> quota; 5xx -> transient. - var statusCode int - if se, isOk := any(result.Error).(interface{ StatusCode() int }); isOk && se != nil { - statusCode = se.StatusCode() - } - switch statusCode { - case 401: - auth.StatusMessage = "unauthorized" - auth.NextRetryAfter = now.Add(30 * time.Minute) - if result.Model != "" { - shouldSuspendModel = true + if result.Model != "" { + state := ensureModelState(auth, result.Model) + state.Unavailable = true + state.Status = StatusError + state.UpdatedAt = now + if result.Error != nil { + state.LastError = cloneError(result.Error) + state.StatusMessage = result.Error.Message + auth.LastError = cloneError(result.Error) + auth.StatusMessage = result.Error.Message + } + + statusCode := statusCodeFromResult(result.Error) + switch statusCode { + case 401: + next := now.Add(30 * time.Minute) + state.NextRetryAfter = next suspendReason = "unauthorized" - } - case 402, 403: - auth.StatusMessage = "payment_required" - auth.NextRetryAfter = now.Add(30 * time.Minute) - if result.Model != "" { shouldSuspendModel = true + case 402, 403: + next := now.Add(30 * time.Minute) + state.NextRetryAfter = next suspendReason = "payment_required" - } - case 429: - auth.StatusMessage = "quota exhausted" - auth.Quota.Exceeded = true - auth.Quota.Reason = "quota" - auth.Quota.NextRecoverAt = now.Add(30 * time.Minute) - auth.NextRetryAfter = auth.Quota.NextRecoverAt - if result.Model != "" { shouldSuspendModel = true - registry.GetGlobalRegistry().SetModelQuotaExceeded(auth.ID, result.Model) - } - case 408, 500, 502, 503, 504: - auth.StatusMessage = "transient upstream error" - auth.NextRetryAfter = now.Add(1 * time.Minute) - if result.Model != "" { - shouldSuspendModel = false - suspendReason = "forbidden" - } - default: - // keep generic - if auth.StatusMessage == "" { - auth.StatusMessage = "request failed" + case 429: + next := now.Add(30 * time.Minute) + state.NextRetryAfter = next + state.Quota = QuotaState{Exceeded: true, Reason: "quota", NextRecoverAt: next} + suspendReason = "quota" + shouldSuspendModel = true + setModelQuota = true + case 408, 500, 502, 503, 504: + next := now.Add(1 * time.Minute) + state.NextRetryAfter = next + default: + state.NextRetryAfter = time.Time{} } + + auth.Status = StatusError + auth.UpdatedAt = now + updateAggregatedAvailability(auth, now) + } else { + applyAuthFailureState(auth, result.Error, now) } } - // Persist best-effort (only metadata is stored for file store). + _ = m.persist(ctx, auth) } m.mu.Unlock() + if clearModelQuota && result.Model != "" { + registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model) + } + if setModelQuota && result.Model != "" { + registry.GetGlobalRegistry().SetModelQuotaExceeded(result.AuthID, result.Model) + } if shouldResumeModel { registry.GetGlobalRegistry().ResumeClientModel(result.AuthID, result.Model) } else if shouldSuspendModel { @@ -573,6 +578,180 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { m.hook.OnResult(ctx, result) } +func ensureModelState(auth *Auth, model string) *ModelState { + if auth == nil || model == "" { + return nil + } + if auth.ModelStates == nil { + auth.ModelStates = make(map[string]*ModelState) + } + if state, ok := auth.ModelStates[model]; ok && state != nil { + return state + } + state := &ModelState{Status: StatusActive} + auth.ModelStates[model] = state + return state +} + +func resetModelState(state *ModelState, now time.Time) { + if state == nil { + return + } + state.Unavailable = false + state.Status = StatusActive + state.StatusMessage = "" + state.NextRetryAfter = time.Time{} + state.LastError = nil + state.Quota = QuotaState{} + state.UpdatedAt = now +} + +func updateAggregatedAvailability(auth *Auth, now time.Time) { + if auth == nil || len(auth.ModelStates) == 0 { + return + } + allUnavailable := true + earliestRetry := time.Time{} + quotaExceeded := false + quotaRecover := time.Time{} + for _, state := range auth.ModelStates { + if state == nil { + continue + } + stateUnavailable := false + if state.Status == StatusDisabled { + stateUnavailable = true + } else if state.Unavailable { + if state.NextRetryAfter.IsZero() { + stateUnavailable = true + } else if state.NextRetryAfter.After(now) { + stateUnavailable = true + if earliestRetry.IsZero() || state.NextRetryAfter.Before(earliestRetry) { + earliestRetry = state.NextRetryAfter + } + } else { + state.Unavailable = false + state.NextRetryAfter = time.Time{} + } + } + if !stateUnavailable { + allUnavailable = false + } + if state.Quota.Exceeded { + quotaExceeded = true + if quotaRecover.IsZero() || (!state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.Before(quotaRecover)) { + quotaRecover = state.Quota.NextRecoverAt + } + } + } + auth.Unavailable = allUnavailable + if allUnavailable { + auth.NextRetryAfter = earliestRetry + } else { + auth.NextRetryAfter = time.Time{} + } + if quotaExceeded { + auth.Quota.Exceeded = true + auth.Quota.Reason = "quota" + auth.Quota.NextRecoverAt = quotaRecover + } else { + auth.Quota.Exceeded = false + auth.Quota.Reason = "" + auth.Quota.NextRecoverAt = time.Time{} + } +} + +func hasModelError(auth *Auth, now time.Time) bool { + if auth == nil || len(auth.ModelStates) == 0 { + return false + } + for _, state := range auth.ModelStates { + if state == nil { + continue + } + if state.LastError != nil { + return true + } + if state.Status == StatusError { + if state.Unavailable && (state.NextRetryAfter.IsZero() || state.NextRetryAfter.After(now)) { + return true + } + } + } + return false +} + +func clearAuthStateOnSuccess(auth *Auth, now time.Time) { + if auth == nil { + return + } + auth.Unavailable = false + auth.Status = StatusActive + auth.StatusMessage = "" + auth.Quota.Exceeded = false + auth.Quota.Reason = "" + auth.Quota.NextRecoverAt = time.Time{} + auth.LastError = nil + auth.NextRetryAfter = time.Time{} + auth.UpdatedAt = now +} + +func cloneError(err *Error) *Error { + if err == nil { + return nil + } + return &Error{ + Code: err.Code, + Message: err.Message, + Retryable: err.Retryable, + HTTPStatus: err.HTTPStatus, + } +} + +func statusCodeFromResult(err *Error) int { + if err == nil { + return 0 + } + return err.StatusCode() +} + +func applyAuthFailureState(auth *Auth, resultErr *Error, now time.Time) { + if auth == nil { + return + } + auth.Unavailable = true + auth.Status = StatusError + auth.UpdatedAt = now + if resultErr != nil { + auth.LastError = cloneError(resultErr) + if resultErr.Message != "" { + auth.StatusMessage = resultErr.Message + } + } + statusCode := statusCodeFromResult(resultErr) + switch statusCode { + case 401: + auth.StatusMessage = "unauthorized" + auth.NextRetryAfter = now.Add(30 * time.Minute) + case 402, 403: + auth.StatusMessage = "payment_required" + auth.NextRetryAfter = now.Add(30 * time.Minute) + case 429: + auth.StatusMessage = "quota exhausted" + auth.Quota.Exceeded = true + auth.Quota.Reason = "quota" + auth.Quota.NextRecoverAt = now.Add(30 * time.Minute) + auth.NextRetryAfter = auth.Quota.NextRecoverAt + case 408, 500, 502, 503, 504: + auth.StatusMessage = "transient upstream error" + auth.NextRetryAfter = now.Add(1 * time.Minute) + default: + if auth.StatusMessage == "" { + auth.StatusMessage = "request failed" + } + } +} + // List returns all auth entries currently known by the manager. func (m *Manager) List() []*Auth { m.mu.RLock() diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index 291c590c..f356cce9 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -28,10 +28,7 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o now := time.Now() for i := 0; i < len(auths); i++ { candidate := auths[i] - if candidate.Unavailable && candidate.NextRetryAfter.After(now) { - continue - } - if candidate.Status == StatusDisabled || candidate.Disabled { + if isAuthBlockedForModel(candidate, model, now) { continue } available = append(available, candidate) @@ -52,3 +49,31 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o // log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available)) return available[index%len(available)], nil } + +func isAuthBlockedForModel(auth *Auth, model string, now time.Time) bool { + if auth == nil { + return true + } + if auth.Disabled || auth.Status == StatusDisabled { + return true + } + if model != "" && len(auth.ModelStates) > 0 { + if state, ok := auth.ModelStates[model]; ok && state != nil { + if state.Status == StatusDisabled { + return true + } + if state.Unavailable { + if state.NextRetryAfter.IsZero() { + return false + } + if state.NextRetryAfter.After(now) { + return true + } + } + } + } + if auth.Unavailable && auth.NextRetryAfter.After(now) { + return true + } + return false +} diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 27b5314d..98b5f288 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -45,6 +45,8 @@ type Auth struct { NextRefreshAfter time.Time `json:"next_refresh_after"` // NextRetryAfter is the earliest time a retry should retrigger. NextRetryAfter time.Time `json:"next_retry_after"` + // ModelStates tracks per-model runtime availability data. + ModelStates map[string]*ModelState `json:"model_states,omitempty"` // Runtime carries non-serialisable data used during execution (in-memory only). Runtime any `json:"-"` @@ -60,6 +62,24 @@ type QuotaState struct { NextRecoverAt time.Time `json:"next_recover_at"` } +// ModelState captures the execution state for a specific model under an auth entry. +type ModelState struct { + // Status reflects the lifecycle status for this model. + Status Status `json:"status"` + // StatusMessage provides an optional short description of the status. + StatusMessage string `json:"status_message,omitempty"` + // Unavailable mirrors whether the model is temporarily blocked for retries. + Unavailable bool `json:"unavailable"` + // NextRetryAfter defines the per-model retry time. + NextRetryAfter time.Time `json:"next_retry_after"` + // LastError records the latest error observed for this model. + LastError *Error `json:"last_error,omitempty"` + // Quota retains quota information if this model hit rate limits. + Quota QuotaState `json:"quota"` + // UpdatedAt tracks the last update timestamp for this model state. + UpdatedAt time.Time `json:"updated_at"` +} + // Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation. func (a *Auth) Clone() *Auth { if a == nil { @@ -78,10 +98,33 @@ func (a *Auth) Clone() *Auth { copyAuth.Metadata[key] = value } } + if len(a.ModelStates) > 0 { + copyAuth.ModelStates = make(map[string]*ModelState, len(a.ModelStates)) + for key, state := range a.ModelStates { + copyAuth.ModelStates[key] = state.Clone() + } + } copyAuth.Runtime = a.Runtime return ©Auth } +// Clone duplicates a model state including nested error details. +func (m *ModelState) Clone() *ModelState { + if m == nil { + return nil + } + copyState := *m + if m.LastError != nil { + copyState.LastError = &Error{ + Code: m.LastError.Code, + Message: m.LastError.Message, + Retryable: m.LastError.Retryable, + HTTPStatus: m.LastError.HTTPStatus, + } + } + return ©State +} + func (a *Auth) AccountInfo() (string, string) { if a == nil { return "", ""