mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
feat(auth): enable model suspension and resumption logic in AuthManager
- Added model suspension with reason tracking for 401 (unauthorized) and 402/403 (payment-related) errors. - Implemented resumption logic upon model quota recovery or auth state changes. - Enhanced registry to manage suspended clients, including counts and observability data. - Updated availability computation to exclude suspended clients, ensuring accurate client model tracking.
This commit is contained in:
@@ -58,6 +58,8 @@ type ModelRegistration struct {
|
||||
QuotaExceededClients map[string]*time.Time
|
||||
// Providers tracks available clients grouped by provider identifier
|
||||
Providers map[string]int
|
||||
// SuspendedClients tracks temporarily disabled clients keyed by client ID
|
||||
SuspendedClients map[string]string
|
||||
}
|
||||
|
||||
// ModelRegistry manages the global registry of available models
|
||||
@@ -112,6 +114,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
// Model already exists, increment count
|
||||
existing.Count++
|
||||
existing.LastUpdated = now
|
||||
if existing.SuspendedClients == nil {
|
||||
existing.SuspendedClients = make(map[string]string)
|
||||
}
|
||||
if provider != "" {
|
||||
if existing.Providers == nil {
|
||||
existing.Providers = make(map[string]int)
|
||||
@@ -126,6 +131,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
Count: 1,
|
||||
LastUpdated: now,
|
||||
QuotaExceededClients: make(map[string]*time.Time),
|
||||
SuspendedClients: make(map[string]string),
|
||||
}
|
||||
if provider != "" {
|
||||
registration.Providers = map[string]int{provider: 1}
|
||||
@@ -172,6 +178,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
|
||||
// Remove quota tracking for this client
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
if registration.SuspendedClients != nil {
|
||||
delete(registration.SuspendedClients, clientID)
|
||||
}
|
||||
|
||||
if hasProvider && registration.Providers != nil {
|
||||
if count, ok := registration.Providers[provider]; ok {
|
||||
@@ -229,6 +238,60 @@ func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// SuspendClientModel marks a client's model as temporarily unavailable until explicitly resumed.
|
||||
// Parameters:
|
||||
// - clientID: The client to suspend
|
||||
// - modelID: The model affected by the suspension
|
||||
// - reason: Optional description for observability
|
||||
func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
||||
if clientID == "" || modelID == "" {
|
||||
return
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil {
|
||||
return
|
||||
}
|
||||
if registration.SuspendedClients == nil {
|
||||
registration.SuspendedClients = make(map[string]string)
|
||||
}
|
||||
if _, already := registration.SuspendedClients[clientID]; already {
|
||||
return
|
||||
}
|
||||
registration.SuspendedClients[clientID] = reason
|
||||
registration.LastUpdated = time.Now()
|
||||
if reason != "" {
|
||||
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
||||
} else {
|
||||
log.Debugf("Suspended client %s for model %s", clientID, modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// ResumeClientModel clears a previous suspension so the client counts toward availability again.
|
||||
// Parameters:
|
||||
// - clientID: The client to resume
|
||||
// - modelID: The model being resumed
|
||||
func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
||||
if clientID == "" || modelID == "" {
|
||||
return
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil || registration.SuspendedClients == nil {
|
||||
return
|
||||
}
|
||||
if _, ok := registration.SuspendedClients[clientID]; !ok {
|
||||
return
|
||||
}
|
||||
delete(registration.SuspendedClients, clientID)
|
||||
registration.LastUpdated = time.Now()
|
||||
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
||||
}
|
||||
|
||||
// GetAvailableModels returns all models that have at least one available client
|
||||
// Parameters:
|
||||
// - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
|
||||
@@ -255,7 +318,14 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
}
|
||||
}
|
||||
|
||||
effectiveClients := availableClients - expiredClients
|
||||
suspendedClients := 0
|
||||
if registration.SuspendedClients != nil {
|
||||
suspendedClients = len(registration.SuspendedClients)
|
||||
}
|
||||
effectiveClients := availableClients - expiredClients - suspendedClients
|
||||
if effectiveClients < 0 {
|
||||
effectiveClients = 0
|
||||
}
|
||||
|
||||
// Only include models that have available clients
|
||||
if effectiveClients > 0 {
|
||||
@@ -290,8 +360,15 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
|
||||
return registration.Count - expiredClients
|
||||
suspendedClients := 0
|
||||
if registration.SuspendedClients != nil {
|
||||
suspendedClients = len(registration.SuspendedClients)
|
||||
}
|
||||
result := registration.Count - expiredClients - suspendedClients
|
||||
if result < 0 {
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -316,11 +393,23 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
|
||||
count int
|
||||
}
|
||||
providers := make([]providerCount, 0, len(registration.Providers))
|
||||
suspendedByProvider := make(map[string]int)
|
||||
if registration.SuspendedClients != nil {
|
||||
for clientID := range registration.SuspendedClients {
|
||||
if provider, ok := r.clientProviders[clientID]; ok && provider != "" {
|
||||
suspendedByProvider[provider]++
|
||||
}
|
||||
}
|
||||
}
|
||||
for name, count := range registration.Providers {
|
||||
if count <= 0 {
|
||||
continue
|
||||
}
|
||||
providers = append(providers, providerCount{name: name, count: count})
|
||||
adjusted := count - suspendedByProvider[name]
|
||||
if adjusted <= 0 {
|
||||
continue
|
||||
}
|
||||
providers = append(providers, providerCount{name: name, count: adjusted})
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -486,6 +486,9 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
return
|
||||
}
|
||||
// Update in-memory auth status based on result.
|
||||
shouldResumeModel := false
|
||||
shouldSuspendModel := false
|
||||
suspendReason := ""
|
||||
m.mu.Lock()
|
||||
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
|
||||
now := time.Now()
|
||||
@@ -501,6 +504,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
auth.UpdatedAt = now
|
||||
if result.Model != "" {
|
||||
registry.GetGlobalRegistry().ClearModelQuotaExceeded(auth.ID, result.Model)
|
||||
shouldResumeModel = true
|
||||
}
|
||||
} else {
|
||||
// Default transient error state.
|
||||
@@ -511,7 +515,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
auth.LastError = &Error{Code: result.Error.Code, Message: result.Error.Message, Retryable: result.Error.Retryable}
|
||||
}
|
||||
// If the error carries a status code, adjust backoff/quota accordingly.
|
||||
// 401 -> auth issue; 402/429 -> quota; 5xx -> transient.
|
||||
// 401 -> auth issue; 402 -> billing; 403 -> forbidden; 429 -> quota; 5xx -> transient.
|
||||
var statusCode int
|
||||
if se, isOk := any(result.Error).(interface{ StatusCode() int }); isOk && se != nil {
|
||||
statusCode = se.StatusCode()
|
||||
@@ -519,19 +523,35 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
auth.StatusMessage = "unauthorized"
|
||||
auth.NextRefreshAfter = now.Add(5 * time.Minute)
|
||||
case 402, 429:
|
||||
auth.NextRefreshAfter = now.Add(30 * time.Minute)
|
||||
if result.Model != "" {
|
||||
shouldSuspendModel = true
|
||||
suspendReason = "unauthorized"
|
||||
}
|
||||
case 402, 403:
|
||||
auth.StatusMessage = "payment_required"
|
||||
auth.NextRefreshAfter = now.Add(30 * time.Minute)
|
||||
if result.Model != "" {
|
||||
shouldSuspendModel = true
|
||||
suspendReason = "payment_required"
|
||||
}
|
||||
case 429:
|
||||
auth.StatusMessage = "quota exhausted"
|
||||
auth.Quota.Exceeded = true
|
||||
auth.Quota.Reason = "quota"
|
||||
auth.Quota.NextRecoverAt = now.Add(10 * time.Minute)
|
||||
auth.Quota.NextRecoverAt = now.Add(30 * time.Minute)
|
||||
auth.NextRefreshAfter = auth.Quota.NextRecoverAt
|
||||
if result.Model != "" {
|
||||
shouldSuspendModel = true
|
||||
registry.GetGlobalRegistry().SetModelQuotaExceeded(auth.ID, result.Model)
|
||||
}
|
||||
case 403, 408, 500, 502, 503, 504:
|
||||
case 408, 500, 502, 503, 504:
|
||||
auth.StatusMessage = "transient upstream error"
|
||||
auth.NextRefreshAfter = now.Add(1 * time.Minute)
|
||||
if result.Model != "" {
|
||||
shouldSuspendModel = false
|
||||
suspendReason = "forbidden"
|
||||
}
|
||||
default:
|
||||
// keep generic
|
||||
if auth.StatusMessage == "" {
|
||||
@@ -544,6 +564,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if shouldResumeModel {
|
||||
registry.GetGlobalRegistry().ResumeClientModel(result.AuthID, result.Model)
|
||||
} else if shouldSuspendModel {
|
||||
registry.GetGlobalRegistry().SuspendClientModel(result.AuthID, result.Model, suspendReason)
|
||||
}
|
||||
|
||||
m.hook.OnResult(ctx, result)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user