From d225558dae11bc87084864b932b96ed4c245e32b Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 22 Oct 2025 09:01:11 +0800 Subject: [PATCH] feat: improve error handling with added status codes and headers - Updated Execute methods to include enhanced error handling via `StatusCode` and `Headers` extraction. - Introduced structured error responses for cooling down scenarios, providing additional metadata and retry suggestions. - Refined quota management, allowing for differentiation between cool-down, disabled, and other block reasons. - Improved model filtering logic based on client availability and suspension criteria. --- internal/registry/model_registry.go | 26 +++-- sdk/api/handlers/handlers.go | 67 ++++++++++++- sdk/cliproxy/auth/selector.go | 150 ++++++++++++++++++++++++---- 3 files changed, 211 insertions(+), 32 deletions(-) diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 29607a5b..e1223978 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -352,14 +352,14 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo { if model == nil { return nil } - copy := *model + copyModel := *model if len(model.SupportedGenerationMethods) > 0 { - copy.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...) + copyModel.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...) } if len(model.SupportedParameters) > 0 { - copy.SupportedParameters = append([]string(nil), model.SupportedParameters...) + copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...) 
} - return &copy + return &copyModel } // UnregisterClient removes a client and decrements counts for its models @@ -532,17 +532,25 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any } } - suspendedClients := 0 + cooldownSuspended := 0 + otherSuspended := 0 if registration.SuspendedClients != nil { - suspendedClients = len(registration.SuspendedClients) + for _, reason := range registration.SuspendedClients { + if strings.EqualFold(reason, "quota") { + cooldownSuspended++ + continue + } + otherSuspended++ + } } - effectiveClients := availableClients - expiredClients - suspendedClients + + effectiveClients := availableClients - expiredClients - otherSuspended if effectiveClients < 0 { effectiveClients = 0 } - // Only include models that have available clients - if effectiveClients > 0 { + // Include models that have available clients, or those solely cooling down. + if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) { model := r.convertModelToMap(registration.Info, handlerType) if model != nil { models = append(models, model) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 0eb4588a..73c647f3 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -156,7 +156,19 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType } resp, err := h.AuthManager.Execute(ctx, providers, req, opts) if err != nil { - return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + var addon http.Header + if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + return nil, &interfaces.ErrorMessage{StatusCode: status, Error: 
err, Addon: addon} } return cloneBytes(resp.Payload), nil } @@ -187,7 +199,19 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle } resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) if err != nil { - return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + var addon http.Header + if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } return cloneBytes(resp.Payload), nil } @@ -222,7 +246,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { errChan := make(chan *interfaces.ErrorMessage, 1) - errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + var addon http.Header + if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} close(errChan) return nil, errChan } @@ -233,7 +269,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl defer close(errChan) for chunk := range chunks { if chunk.Err != nil { - errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: chunk.Err} + status := http.StatusInternalServerError + if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != 
nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + var addon http.Header + if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon} return } if len(chunk.Payload) > 0 { @@ -287,6 +335,17 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro if msg != nil && msg.StatusCode > 0 { status = msg.StatusCode } + if msg != nil && msg.Addon != nil { + for key, values := range msg.Addon { + if len(values) == 0 { + continue + } + c.Writer.Header().Del(key) + for _, value := range values { + c.Writer.Header().Add(key, value) + } + } + } c.Status(status) if msg != nil && msg.Error != nil { _, _ = c.Writer.Write([]byte(msg.Error.Error())) diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index 83d5d90d..d4edc8bd 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -2,7 +2,12 @@ package auth import ( "context" + "encoding/json" + "fmt" + "math" + "net/http" "sort" + "strconv" "sync" "time" @@ -15,6 +20,84 @@ type RoundRobinSelector struct { cursors map[string]int } +type blockReason int + +const ( + blockReasonNone blockReason = iota + blockReasonCooldown + blockReasonDisabled + blockReasonOther +) + +type modelCooldownError struct { + model string + resetIn time.Duration + provider string +} + +func newModelCooldownError(model, provider string, resetIn time.Duration) *modelCooldownError { + if resetIn < 0 { + resetIn = 0 + } + return &modelCooldownError{ + model: model, + provider: provider, + resetIn: resetIn, + } +} + +func (e *modelCooldownError) Error() string { + modelName := e.model + if modelName == "" { + modelName = "requested model" + } + message := fmt.Sprintf("All credentials for model %s are cooling down", modelName) + if e.provider != "" { + message = fmt.Sprintf("%s via provider 
%s", message, e.provider) + } + resetSeconds := int(math.Ceil(e.resetIn.Seconds())) + if resetSeconds < 0 { + resetSeconds = 0 + } + displayDuration := e.resetIn + if displayDuration > 0 && displayDuration < time.Second { + displayDuration = time.Second + } else { + displayDuration = displayDuration.Round(time.Second) + } + errorBody := map[string]any{ + "code": "model_cooldown", + "message": message, + "model": e.model, + "reset_time": displayDuration.String(), + "reset_seconds": resetSeconds, + } + if e.provider != "" { + errorBody["provider"] = e.provider + } + payload := map[string]any{"error": errorBody} + data, err := json.Marshal(payload) + if err != nil { + return fmt.Sprintf(`{"error":{"code":"model_cooldown","message":"%s"}}`, message) + } + return string(data) +} + +func (e *modelCooldownError) StatusCode() int { + return http.StatusTooManyRequests +} + +func (e *modelCooldownError) Headers() http.Header { + headers := make(http.Header) + headers.Set("Content-Type", "application/json") + resetSeconds := int(math.Ceil(e.resetIn.Seconds())) + if resetSeconds < 0 { + resetSeconds = 0 + } + headers.Set("Retry-After", strconv.Itoa(resetSeconds)) + return headers +} + // Pick selects the next available auth for the provider in a round-robin manner. 
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = ctx @@ -27,14 +110,30 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o } available := make([]*Auth, 0, len(auths)) now := time.Now() + cooldownCount := 0 + var earliest time.Time for i := 0; i < len(auths); i++ { candidate := auths[i] - if isAuthBlockedForModel(candidate, model, now) { + blocked, reason, next := isAuthBlockedForModel(candidate, model, now) + if !blocked { + available = append(available, candidate) continue } - available = append(available, candidate) + if reason == blockReasonCooldown { + cooldownCount++ + if !next.IsZero() && (earliest.IsZero() || next.Before(earliest)) { + earliest = next + } + } } if len(available) == 0 { + if cooldownCount == len(auths) && !earliest.IsZero() { + resetIn := earliest.Sub(now) + if resetIn < 0 { + resetIn = 0 + } + return nil, newModelCooldownError(model, provider, resetIn) + } return nil, &Error{Code: "auth_unavailable", Message: "no auth available"} } // Make round-robin deterministic even if caller's candidate order is unstable. @@ -55,41 +154,54 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o return available[index%len(available)], nil } -func isAuthBlockedForModel(auth *Auth, model string, now time.Time) bool { +func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) { if auth == nil { - return true + return true, blockReasonOther, time.Time{} } if auth.Disabled || auth.Status == StatusDisabled { - return true + return true, blockReasonDisabled, time.Time{} } - // If a specific model is requested, prefer its per-model state over any aggregated - // auth-level unavailable flag. This prevents a failure on one model (e.g., 429 quota) - // from blocking other models of the same provider that have no errors. 
if model != "" { if len(auth.ModelStates) > 0 { if state, ok := auth.ModelStates[model]; ok && state != nil { if state.Status == StatusDisabled { - return true + return true, blockReasonDisabled, time.Time{} } if state.Unavailable { if state.NextRetryAfter.IsZero() { - return false + return false, blockReasonNone, time.Time{} } if state.NextRetryAfter.After(now) { - return true + next := state.NextRetryAfter + if !state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.After(now) { + next = state.Quota.NextRecoverAt + } + if next.Before(now) { + next = now + } + if state.Quota.Exceeded { + return true, blockReasonCooldown, next + } + return true, blockReasonOther, next } } - // Explicit state exists and is not blocking. - return false + return false, blockReasonNone, time.Time{} } } - // No explicit state for this model; do not block based on aggregated - // auth-level unavailable status. Allow trying this model. - return false + return false, blockReasonNone, time.Time{} } - // No specific model context: fall back to auth-level unavailable window. if auth.Unavailable && auth.NextRetryAfter.After(now) { - return true + next := auth.NextRetryAfter + if !auth.Quota.NextRecoverAt.IsZero() && auth.Quota.NextRecoverAt.After(now) { + next = auth.Quota.NextRecoverAt + } + if next.Before(now) { + next = now + } + if auth.Quota.Exceeded { + return true, blockReasonCooldown, next + } + return true, blockReasonOther, next } - return false + return false, blockReasonNone, time.Time{} }