feat: improve error handling with added status codes and headers

- Updated Execute methods to include enhanced error handling via `StatusCode` and `Headers` extraction.
- Introduced structured error responses for cooling down scenarios, providing additional metadata and retry suggestions.
- Refined quota management, allowing for differentiation between cool-down, disabled, and other block reasons.
- Improved model filtering logic based on client availability and suspension criteria.
This commit is contained in:
Luis Pater
2025-10-22 09:01:11 +08:00
parent 9678be7aa4
commit d225558dae
3 changed files with 211 additions and 32 deletions

View File

@@ -156,7 +156,19 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
}
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
}
return cloneBytes(resp.Payload), nil
}
@@ -187,7 +199,19 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
}
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
}
return cloneBytes(resp.Payload), nil
}
@@ -222,7 +246,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
close(errChan)
return nil, errChan
}
@@ -233,7 +269,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
defer close(errChan)
for chunk := range chunks {
if chunk.Err != nil {
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: chunk.Err}
status := http.StatusInternalServerError
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
return
}
if len(chunk.Payload) > 0 {
@@ -287,6 +335,17 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
if msg != nil && msg.StatusCode > 0 {
status = msg.StatusCode
}
if msg != nil && msg.Addon != nil {
for key, values := range msg.Addon {
if len(values) == 0 {
continue
}
c.Writer.Header().Del(key)
for _, value := range values {
c.Writer.Header().Add(key, value)
}
}
}
c.Status(status)
if msg != nil && msg.Error != nil {
_, _ = c.Writer.Write([]byte(msg.Error.Error()))

View File

@@ -2,7 +2,12 @@ package auth
import (
"context"
"encoding/json"
"fmt"
"math"
"net/http"
"sort"
"strconv"
"sync"
"time"
@@ -15,6 +20,84 @@ type RoundRobinSelector struct {
cursors map[string]int
}
type blockReason int
const (
blockReasonNone blockReason = iota
blockReasonCooldown
blockReasonDisabled
blockReasonOther
)
type modelCooldownError struct {
model string
resetIn time.Duration
provider string
}
func newModelCooldownError(model, provider string, resetIn time.Duration) *modelCooldownError {
if resetIn < 0 {
resetIn = 0
}
return &modelCooldownError{
model: model,
provider: provider,
resetIn: resetIn,
}
}
func (e *modelCooldownError) Error() string {
modelName := e.model
if modelName == "" {
modelName = "requested model"
}
message := fmt.Sprintf("All credentials for model %s are cooling down", modelName)
if e.provider != "" {
message = fmt.Sprintf("%s via provider %s", message, e.provider)
}
resetSeconds := int(math.Ceil(e.resetIn.Seconds()))
if resetSeconds < 0 {
resetSeconds = 0
}
displayDuration := e.resetIn
if displayDuration > 0 && displayDuration < time.Second {
displayDuration = time.Second
} else {
displayDuration = displayDuration.Round(time.Second)
}
errorBody := map[string]any{
"code": "model_cooldown",
"message": message,
"model": e.model,
"reset_time": displayDuration.String(),
"reset_seconds": resetSeconds,
}
if e.provider != "" {
errorBody["provider"] = e.provider
}
payload := map[string]any{"error": errorBody}
data, err := json.Marshal(payload)
if err != nil {
return fmt.Sprintf(`{"error":{"code":"model_cooldown","message":"%s"}}`, message)
}
return string(data)
}
func (e *modelCooldownError) StatusCode() int {
return http.StatusTooManyRequests
}
func (e *modelCooldownError) Headers() http.Header {
headers := make(http.Header)
headers.Set("Content-Type", "application/json")
resetSeconds := int(math.Ceil(e.resetIn.Seconds()))
if resetSeconds < 0 {
resetSeconds = 0
}
headers.Set("Retry-After", strconv.Itoa(resetSeconds))
return headers
}
// Pick selects the next available auth for the provider in a round-robin manner.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx
@@ -27,14 +110,30 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
}
available := make([]*Auth, 0, len(auths))
now := time.Now()
cooldownCount := 0
var earliest time.Time
for i := 0; i < len(auths); i++ {
candidate := auths[i]
if isAuthBlockedForModel(candidate, model, now) {
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
if !blocked {
available = append(available, candidate)
continue
}
available = append(available, candidate)
if reason == blockReasonCooldown {
cooldownCount++
if !next.IsZero() && (earliest.IsZero() || next.Before(earliest)) {
earliest = next
}
}
}
if len(available) == 0 {
if cooldownCount == len(auths) && !earliest.IsZero() {
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return nil, newModelCooldownError(model, provider, resetIn)
}
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
}
// Make round-robin deterministic even if caller's candidate order is unstable.
@@ -55,41 +154,54 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
return available[index%len(available)], nil
}
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) bool {
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) {
if auth == nil {
return true
return true, blockReasonOther, time.Time{}
}
if auth.Disabled || auth.Status == StatusDisabled {
return true
return true, blockReasonDisabled, time.Time{}
}
// If a specific model is requested, prefer its per-model state over any aggregated
// auth-level unavailable flag. This prevents a failure on one model (e.g., 429 quota)
// from blocking other models of the same provider that have no errors.
if model != "" {
if len(auth.ModelStates) > 0 {
if state, ok := auth.ModelStates[model]; ok && state != nil {
if state.Status == StatusDisabled {
return true
return true, blockReasonDisabled, time.Time{}
}
if state.Unavailable {
if state.NextRetryAfter.IsZero() {
return false
return false, blockReasonNone, time.Time{}
}
if state.NextRetryAfter.After(now) {
return true
next := state.NextRetryAfter
if !state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.After(now) {
next = state.Quota.NextRecoverAt
}
if next.Before(now) {
next = now
}
if state.Quota.Exceeded {
return true, blockReasonCooldown, next
}
return true, blockReasonOther, next
}
}
// Explicit state exists and is not blocking.
return false
return false, blockReasonNone, time.Time{}
}
}
// No explicit state for this model; do not block based on aggregated
// auth-level unavailable status. Allow trying this model.
return false
return false, blockReasonNone, time.Time{}
}
// No specific model context: fall back to auth-level unavailable window.
if auth.Unavailable && auth.NextRetryAfter.After(now) {
return true
next := auth.NextRetryAfter
if !auth.Quota.NextRecoverAt.IsZero() && auth.Quota.NextRecoverAt.After(now) {
next = auth.Quota.NextRecoverAt
}
if next.Before(now) {
next = now
}
if auth.Quota.Exceeded {
return true, blockReasonCooldown, next
}
return true, blockReasonOther, next
}
return false
return false, blockReasonNone, time.Time{}
}