feat: improve error handling with added status codes and headers

- Updated Execute methods to include enhanced error handling via `StatusCode` and `Headers` extraction.
- Introduced structured JSON error responses (HTTP 429 with a `Retry-After` header) for model cool-down scenarios, providing additional metadata and a suggested retry delay.
- Refined quota management, allowing for differentiation between cool-down, disabled, and other block reasons.
- Improved model filtering logic based on client availability and suspension criteria.
This commit is contained in:
Luis Pater
2025-10-22 09:01:11 +08:00
parent 9678be7aa4
commit d225558dae
3 changed files with 211 additions and 32 deletions

View File

@@ -352,14 +352,14 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if model == nil { if model == nil {
return nil return nil
} }
copy := *model copyModel := *model
if len(model.SupportedGenerationMethods) > 0 { if len(model.SupportedGenerationMethods) > 0 {
copy.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...) copyModel.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
} }
if len(model.SupportedParameters) > 0 { if len(model.SupportedParameters) > 0 {
copy.SupportedParameters = append([]string(nil), model.SupportedParameters...) copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
} }
return &copy return &copyModel
} }
// UnregisterClient removes a client and decrements counts for its models // UnregisterClient removes a client and decrements counts for its models
@@ -532,17 +532,25 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
} }
} }
suspendedClients := 0 cooldownSuspended := 0
otherSuspended := 0
if registration.SuspendedClients != nil { if registration.SuspendedClients != nil {
suspendedClients = len(registration.SuspendedClients) for _, reason := range registration.SuspendedClients {
if strings.EqualFold(reason, "quota") {
cooldownSuspended++
continue
}
otherSuspended++
}
} }
effectiveClients := availableClients - expiredClients - suspendedClients
effectiveClients := availableClients - expiredClients - otherSuspended
if effectiveClients < 0 { if effectiveClients < 0 {
effectiveClients = 0 effectiveClients = 0
} }
// Only include models that have available clients // Include models that have available clients, or those solely cooling down.
if effectiveClients > 0 { if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
model := r.convertModelToMap(registration.Info, handlerType) model := r.convertModelToMap(registration.Info, handlerType)
if model != nil { if model != nil {
models = append(models, model) models = append(models, model)

View File

@@ -156,7 +156,19 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
} }
resp, err := h.AuthManager.Execute(ctx, providers, req, opts) resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil { if err != nil {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
} }
return cloneBytes(resp.Payload), nil return cloneBytes(resp.Payload), nil
} }
@@ -187,7 +199,19 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
} }
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil { if err != nil {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
} }
return cloneBytes(resp.Payload), nil return cloneBytes(resp.Payload), nil
} }
@@ -222,7 +246,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil { if err != nil {
errChan := make(chan *interfaces.ErrorMessage, 1) errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
close(errChan) close(errChan)
return nil, errChan return nil, errChan
} }
@@ -233,7 +269,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
defer close(errChan) defer close(errChan)
for chunk := range chunks { for chunk := range chunks {
if chunk.Err != nil { if chunk.Err != nil {
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: chunk.Err} status := http.StatusInternalServerError
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
return return
} }
if len(chunk.Payload) > 0 { if len(chunk.Payload) > 0 {
@@ -287,6 +335,17 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
if msg != nil && msg.StatusCode > 0 { if msg != nil && msg.StatusCode > 0 {
status = msg.StatusCode status = msg.StatusCode
} }
if msg != nil && msg.Addon != nil {
for key, values := range msg.Addon {
if len(values) == 0 {
continue
}
c.Writer.Header().Del(key)
for _, value := range values {
c.Writer.Header().Add(key, value)
}
}
}
c.Status(status) c.Status(status)
if msg != nil && msg.Error != nil { if msg != nil && msg.Error != nil {
_, _ = c.Writer.Write([]byte(msg.Error.Error())) _, _ = c.Writer.Write([]byte(msg.Error.Error()))

View File

@@ -2,7 +2,12 @@ package auth
import ( import (
"context" "context"
"encoding/json"
"fmt"
"math"
"net/http"
"sort" "sort"
"strconv"
"sync" "sync"
"time" "time"
@@ -15,6 +20,84 @@ type RoundRobinSelector struct {
cursors map[string]int cursors map[string]int
} }
// blockReason classifies why an auth cannot currently serve a model, as
// reported by isAuthBlockedForModel.
type blockReason int
const (
// blockReasonNone means the auth is not blocked.
blockReasonNone blockReason = iota
// blockReasonCooldown means the auth is unavailable until a future retry
// time and its quota is marked as exceeded.
blockReasonCooldown
// blockReasonDisabled means the auth, or its state for the requested model,
// is disabled.
blockReasonDisabled
// blockReasonOther covers every remaining block, such as a nil auth or an
// unavailable window whose quota is not marked as exceeded.
blockReasonOther
)
// modelCooldownError reports that every credential for a model is temporarily
// cooling down. Its Error text is a JSON document, and it exposes the HTTP
// status code and headers to use when the error is relayed to a client.
type modelCooldownError struct {
	model    string        // requested model name; may be empty
	resetIn  time.Duration // time until the earliest credential recovers
	provider string        // provider name; may be empty
}

// newModelCooldownError builds a modelCooldownError for the given model and
// provider. A negative resetIn is clamped to zero.
func newModelCooldownError(model, provider string, resetIn time.Duration) *modelCooldownError {
	if resetIn < 0 {
		resetIn = 0
	}
	return &modelCooldownError{
		model:    model,
		provider: provider,
		resetIn:  resetIn,
	}
}

// retryAfterSeconds returns the remaining cooldown rounded up to whole
// seconds, never negative. Rounding up ensures a client that waits this long
// does not retry before the cooldown has actually elapsed.
func (e *modelCooldownError) retryAfterSeconds() int {
	seconds := int(math.Ceil(e.resetIn.Seconds()))
	if seconds < 0 {
		return 0
	}
	return seconds
}

// Error renders the cooldown as a JSON object of the form
// {"error": {"code", "message", "model", "reset_time", "reset_seconds"}},
// adding "provider" when one is set.
func (e *modelCooldownError) Error() string {
	modelName := e.model
	if modelName == "" {
		modelName = "requested model"
	}
	message := fmt.Sprintf("All credentials for model %s are cooling down", modelName)
	if e.provider != "" {
		message = fmt.Sprintf("%s via provider %s", message, e.provider)
	}

	// Derive the human-readable duration from the same rounded-up value as
	// reset_seconds and the Retry-After header, so the three never disagree
	// (rounding to nearest would report "1s" next to reset_seconds 2 for 1.4s).
	resetSeconds := e.retryAfterSeconds()
	displayDuration := time.Duration(resetSeconds) * time.Second

	errorBody := map[string]any{
		"code":          "model_cooldown",
		"message":       message,
		"model":         e.model,
		"reset_time":    displayDuration.String(),
		"reset_seconds": resetSeconds,
	}
	if e.provider != "" {
		errorBody["provider"] = e.provider
	}
	payload := map[string]any{"error": errorBody}
	data, err := json.Marshal(payload)
	if err != nil {
		// Fall back to a minimal hand-built body if marshaling fails.
		return fmt.Sprintf(`{"error":{"code":"model_cooldown","message":"%s"}}`, message)
	}
	return string(data)
}

// StatusCode returns the HTTP status to use for this error: 429 Too Many Requests.
func (e *modelCooldownError) StatusCode() int {
	return http.StatusTooManyRequests
}

// Headers returns a fresh header set for the error response: a JSON
// Content-Type and a Retry-After value in whole seconds.
func (e *modelCooldownError) Headers() http.Header {
	headers := make(http.Header)
	headers.Set("Content-Type", "application/json")
	headers.Set("Retry-After", strconv.Itoa(e.retryAfterSeconds()))
	return headers
}
// Pick selects the next available auth for the provider in a round-robin manner. // Pick selects the next available auth for the provider in a round-robin manner.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx _ = ctx
@@ -27,14 +110,30 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
} }
available := make([]*Auth, 0, len(auths)) available := make([]*Auth, 0, len(auths))
now := time.Now() now := time.Now()
cooldownCount := 0
var earliest time.Time
for i := 0; i < len(auths); i++ { for i := 0; i < len(auths); i++ {
candidate := auths[i] candidate := auths[i]
if isAuthBlockedForModel(candidate, model, now) { blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
if !blocked {
available = append(available, candidate)
continue continue
} }
available = append(available, candidate) if reason == blockReasonCooldown {
cooldownCount++
if !next.IsZero() && (earliest.IsZero() || next.Before(earliest)) {
earliest = next
}
}
} }
if len(available) == 0 { if len(available) == 0 {
if cooldownCount == len(auths) && !earliest.IsZero() {
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return nil, newModelCooldownError(model, provider, resetIn)
}
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"} return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
} }
// Make round-robin deterministic even if caller's candidate order is unstable. // Make round-robin deterministic even if caller's candidate order is unstable.
@@ -55,41 +154,54 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
return available[index%len(available)], nil return available[index%len(available)], nil
} }
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) bool { func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) {
if auth == nil { if auth == nil {
return true return true, blockReasonOther, time.Time{}
} }
if auth.Disabled || auth.Status == StatusDisabled { if auth.Disabled || auth.Status == StatusDisabled {
return true return true, blockReasonDisabled, time.Time{}
} }
// If a specific model is requested, prefer its per-model state over any aggregated
// auth-level unavailable flag. This prevents a failure on one model (e.g., 429 quota)
// from blocking other models of the same provider that have no errors.
if model != "" { if model != "" {
if len(auth.ModelStates) > 0 { if len(auth.ModelStates) > 0 {
if state, ok := auth.ModelStates[model]; ok && state != nil { if state, ok := auth.ModelStates[model]; ok && state != nil {
if state.Status == StatusDisabled { if state.Status == StatusDisabled {
return true return true, blockReasonDisabled, time.Time{}
} }
if state.Unavailable { if state.Unavailable {
if state.NextRetryAfter.IsZero() { if state.NextRetryAfter.IsZero() {
return false return false, blockReasonNone, time.Time{}
} }
if state.NextRetryAfter.After(now) { if state.NextRetryAfter.After(now) {
return true next := state.NextRetryAfter
if !state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.After(now) {
next = state.Quota.NextRecoverAt
}
if next.Before(now) {
next = now
}
if state.Quota.Exceeded {
return true, blockReasonCooldown, next
}
return true, blockReasonOther, next
} }
} }
// Explicit state exists and is not blocking. return false, blockReasonNone, time.Time{}
return false
} }
} }
// No explicit state for this model; do not block based on aggregated return false, blockReasonNone, time.Time{}
// auth-level unavailable status. Allow trying this model.
return false
} }
// No specific model context: fall back to auth-level unavailable window.
if auth.Unavailable && auth.NextRetryAfter.After(now) { if auth.Unavailable && auth.NextRetryAfter.After(now) {
return true next := auth.NextRetryAfter
if !auth.Quota.NextRecoverAt.IsZero() && auth.Quota.NextRecoverAt.After(now) {
next = auth.Quota.NextRecoverAt
}
if next.Before(now) {
next = now
}
if auth.Quota.Exceeded {
return true, blockReasonCooldown, next
}
return true, blockReasonOther, next
} }
return false return false, blockReasonNone, time.Time{}
} }