Files
CLIProxyAPI/sdk/cliproxy/auth/selector.go

268 lines
7.0 KiB
Go

package auth
import (
"context"
"encoding/json"
"fmt"
"math"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// RoundRobinSelector provides a simple provider scoped round-robin selection strategy.
// The zero value is ready to use: Pick allocates the cursor map on first use.
type RoundRobinSelector struct {
// mu guards cursors.
mu sync.Mutex
// cursors maps a "provider:model" key to the rotation index the next Pick
// for that key will use.
cursors map[string]int
}
// FillFirstSelector selects the first available credential (deterministic ordering).
// This "burns" one account before moving to the next, which can help stagger
// rolling-window subscription caps (e.g. chat message limits).
// The type has no fields, so it carries no state between picks.
type FillFirstSelector struct{}
// blockReason classifies why isAuthBlockedForModel considers an auth unusable.
type blockReason int
const (
// blockReasonNone means the auth is not blocked.
blockReasonNone blockReason = iota
// blockReasonCooldown means the auth is unavailable with its quota marked
// exceeded and a retry time that is still in the future.
blockReasonCooldown
// blockReasonDisabled means the auth, or its state for the requested model,
// is disabled.
blockReasonDisabled
// blockReasonOther covers a nil auth and an auth that is temporarily
// unavailable without its quota being marked exceeded.
blockReasonOther
)
// modelCooldownError reports that no credential can serve a model until a
// cooldown elapses. Its methods expose it as an HTTP 429 with a JSON body.
type modelCooldownError struct {
// model is the requested model name; Error substitutes "requested model"
// in the human-readable message when it is empty.
model string
// resetIn is the time remaining until the cooldown ends;
// newModelCooldownError clamps it to be non-negative.
resetIn time.Duration
// provider is the provider name; it is left out of the error output when empty.
provider string
}
// newModelCooldownError builds a cooldown error for the given model and
// provider. A negative resetIn is stored as zero so callers never see a
// negative wait time.
func newModelCooldownError(model, provider string, resetIn time.Duration) *modelCooldownError {
	cooldown := &modelCooldownError{model: model, provider: provider}
	// Only a positive duration is kept; the zero value already covers the rest.
	if resetIn > 0 {
		cooldown.resetIn = resetIn
	}
	return cooldown
}
// Error renders the cooldown as a JSON document of the form
// {"error": {...}} so it can be sent verbatim as a response body.
//
// The inner object carries "code" ("model_cooldown"), a human-readable
// "message", the raw "model" name, "reset_time" (a Go duration string),
// "reset_seconds" (an integer) and, when set, "provider". Both reset fields are
// rounded up to whole seconds and always agree with each other and with the
// Retry-After value produced by Headers.
func (e *modelCooldownError) Error() string {
	modelName := e.model
	if modelName == "" {
		modelName = "requested model"
	}
	message := fmt.Sprintf("All credentials for model %s are cooling down", modelName)
	if e.provider != "" {
		message = fmt.Sprintf("%s via provider %s", message, e.provider)
	}

	// Round up: telling a client to retry even slightly too early only earns
	// it another cooldown response.
	resetSeconds := int(math.Ceil(e.resetIn.Seconds()))
	if resetSeconds < 0 {
		resetSeconds = 0
	}
	// Derive the display duration from the same rounded-up value. Rounding to
	// the nearest second here would let reset_time advertise a shorter wait
	// than reset_seconds (e.g. 1.4s -> "1s" next to 2).
	displayDuration := time.Duration(resetSeconds) * time.Second

	errorBody := map[string]any{
		"code":          "model_cooldown",
		"message":       message,
		"model":         e.model,
		"reset_time":    displayDuration.String(),
		"reset_seconds": resetSeconds,
	}
	if e.provider != "" {
		errorBody["provider"] = e.provider
	}

	data, err := json.Marshal(map[string]any{"error": errorBody})
	if err != nil {
		// Best-effort fallback. %q quotes the message so embedded quotes or
		// backslashes cannot break out of the hand-built JSON string.
		return fmt.Sprintf(`{"error":{"code":"model_cooldown","message":%q}}`, message)
	}
	return string(data)
}
// StatusCode reports the HTTP status to use for a cooldown error, which is
// always 429 Too Many Requests.
func (e *modelCooldownError) StatusCode() int {
	const status = http.StatusTooManyRequests
	return status
}
// Headers returns the response headers that accompany a cooldown error:
// a JSON Content-Type and a Retry-After holding the remaining cooldown in
// whole seconds, rounded up and never negative.
func (e *modelCooldownError) Headers() http.Header {
	seconds := int(math.Ceil(e.resetIn.Seconds()))
	if seconds < 0 {
		seconds = 0
	}
	// Both keys are already in canonical MIME header form.
	return http.Header{
		"Content-Type": {"application/json"},
		"Retry-After":  {strconv.Itoa(seconds)},
	}
}
// authPriority returns the integer stored in the auth's "priority" attribute.
// Surrounding whitespace is ignored. A nil auth, a missing or empty attribute,
// or a value that is not a valid integer all yield 0.
func authPriority(auth *Auth) int {
	if auth == nil {
		return 0
	}
	// Indexing a nil map yields "", and Atoi rejects "", so the missing and
	// empty cases fall through to the error branch below.
	value, err := strconv.Atoi(strings.TrimSpace(auth.Attributes["priority"]))
	if err != nil {
		return 0
	}
	return value
}
// collectAvailableByPriority partitions auths for the given model.
//
// Auths that isAuthBlockedForModel reports as usable are grouped by their
// authPriority value, keeping input order within each group. Auths blocked
// with blockReasonCooldown are counted in cooldownCount, and earliest is the
// soonest non-zero retry time among them (zero if there is none). Auths
// blocked for any other reason are skipped without being counted.
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
	available = make(map[int][]*Auth)
	for _, candidate := range auths {
		blocked, reason, retryAt := isAuthBlockedForModel(candidate, model, now)
		switch {
		case !blocked:
			level := authPriority(candidate)
			available[level] = append(available[level], candidate)
		case reason == blockReasonCooldown:
			cooldownCount++
			if retryAt.IsZero() {
				// Counted as cooling down, but it has no deadline to compare.
				continue
			}
			if earliest.IsZero() || retryAt.Before(earliest) {
				earliest = retryAt
			}
		}
	}
	return available, cooldownCount, earliest
}
// getAvailableAuths returns the usable auths of the highest priority for the
// given model, sorted by ID so the order is deterministic.
//
// Errors:
//   - *Error with code "auth_not_found" when auths is empty.
//   - *modelCooldownError when every candidate is cooling down and a retry
//     time is known; the provider name "mixed" is reported as empty.
//   - *Error with code "auth_unavailable" when nothing is usable for any
//     other reason.
func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]*Auth, error) {
	if len(auths) == 0 {
		return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
	}

	byPriority, coolingDown, soonest := collectAvailableByPriority(auths, model, now)
	if len(byPriority) == 0 {
		everyoneCooling := coolingDown == len(auths) && !soonest.IsZero()
		if !everyoneCooling {
			return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
		}
		if provider == "mixed" {
			provider = ""
		}
		wait := soonest.Sub(now)
		if wait < 0 {
			wait = 0
		}
		return nil, newModelCooldownError(model, provider, wait)
	}

	// Find the largest priority key. The flag lets the first key seed the
	// maximum, so groups made up only of negative priorities still win.
	top, unseeded := 0, true
	for priority := range byPriority {
		if unseeded || priority > top {
			top, unseeded = priority, false
		}
	}

	candidates := byPriority[top]
	if len(candidates) > 1 {
		sort.Slice(candidates, func(a, b int) bool {
			return candidates[a].ID < candidates[b].ID
		})
	}
	return candidates, nil
}
// Pick selects the next available auth for the provider in a round-robin manner.
//
// Candidates come from getAvailableAuths, whose error is returned unchanged.
// A separate cursor is kept for each "provider:model" key; every successful
// call uses the current cursor value modulo the number of candidates and then
// advances it. ctx and opts are accepted but not used.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	_, _ = ctx, opts

	candidates, err := getAvailableAuths(auths, provider, model, time.Now())
	if err != nil {
		return nil, err
	}

	// cursorWrap is the value at which a cursor restarts from zero, keeping
	// it well inside the 32-bit signed range.
	const cursorWrap = 2_147_483_640
	cursorKey := provider + ":" + model

	s.mu.Lock()
	if s.cursors == nil {
		s.cursors = make(map[string]int)
	}
	turn := s.cursors[cursorKey]
	if turn >= cursorWrap {
		turn = 0
	}
	s.cursors[cursorKey] = turn + 1
	s.mu.Unlock()

	// getAvailableAuths never returns an empty slice with a nil error, so the
	// modulo below cannot divide by zero.
	return candidates[turn%len(candidates)], nil
}
// Pick selects the first available auth for the provider in a deterministic manner.
//
// It returns the first entry of the list produced by getAvailableAuths, or
// that function's error unchanged. ctx and opts are accepted but not used.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	_, _ = ctx, opts

	candidates, err := getAvailableAuths(auths, provider, model, time.Now())
	if err != nil {
		return nil, err
	}
	// A nil error from getAvailableAuths implies at least one candidate.
	return candidates[0], nil
}
// isAuthBlockedForModel reports whether auth must be skipped for model at
// time now. It returns the verdict, the reason, and, for temporary blocks,
// the time at which the auth is expected to be usable again.
//
// Rules, in order:
//   - A nil auth is blocked with blockReasonOther.
//   - A disabled auth is blocked with blockReasonDisabled.
//   - When model is non-empty, only the per-model state is consulted: an auth
//     with no state recorded for that model is usable, whatever its
//     auth-level availability says. When model is empty, the auth-level
//     fields are used.
//   - The selected state blocks only while it is marked unavailable and its
//     retry time is still in the future. The reason is blockReasonCooldown
//     if the quota is marked exceeded, blockReasonOther otherwise. The
//     returned time is the quota recovery time when that is set and in the
//     future, else the retry time.
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) {
	if auth == nil {
		return true, blockReasonOther, time.Time{}
	}
	if auth.Disabled || auth.Status == StatusDisabled {
		return true, blockReasonDisabled, time.Time{}
	}

	// Start from the auth-level availability and switch to the per-model
	// state when a model was requested, so a single code path evaluates both.
	unavailable, retryAfter := auth.Unavailable, auth.NextRetryAfter
	recoverAt, quotaExceeded := auth.Quota.NextRecoverAt, auth.Quota.Exceeded
	if model != "" {
		// A nil or empty map, a missing key and a nil entry all yield nil.
		state := auth.ModelStates[model]
		if state == nil {
			return false, blockReasonNone, time.Time{}
		}
		if state.Status == StatusDisabled {
			return true, blockReasonDisabled, time.Time{}
		}
		unavailable, retryAfter = state.Unavailable, state.NextRetryAfter
		recoverAt, quotaExceeded = state.Quota.NextRecoverAt, state.Quota.Exceeded
	}

	// An unset (zero) retry time is never after now, so it does not block.
	if !unavailable || !retryAfter.After(now) {
		return false, blockReasonNone, time.Time{}
	}

	// Both candidates for next are strictly after now, so no clamping to now
	// is needed.
	next := retryAfter
	if !recoverAt.IsZero() && recoverAt.After(now) {
		next = recoverAt
	}
	if quotaExceeded {
		return true, blockReasonCooldown, next
	}
	return true, blockReasonOther, next
}