mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(auth): enable model suspension and resumption logic in AuthManager
- Added model suspension with reason tracking for 401 (unauthorized) and 402/403 (payment-related) errors. - Implemented resumption logic upon model quota recovery or auth state changes. - Enhanced registry to manage suspended clients, including counts and observability data. - Updated availability computation to exclude suspended clients, ensuring accurate client model tracking.
This commit is contained in:
@@ -58,6 +58,8 @@ type ModelRegistration struct {
|
||||
QuotaExceededClients map[string]*time.Time
|
||||
// Providers tracks available clients grouped by provider identifier
|
||||
Providers map[string]int
|
||||
// SuspendedClients tracks temporarily disabled clients keyed by client ID
|
||||
SuspendedClients map[string]string
|
||||
}
|
||||
|
||||
// ModelRegistry manages the global registry of available models
|
||||
@@ -112,6 +114,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
// Model already exists, increment count
|
||||
existing.Count++
|
||||
existing.LastUpdated = now
|
||||
if existing.SuspendedClients == nil {
|
||||
existing.SuspendedClients = make(map[string]string)
|
||||
}
|
||||
if provider != "" {
|
||||
if existing.Providers == nil {
|
||||
existing.Providers = make(map[string]int)
|
||||
@@ -126,6 +131,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
Count: 1,
|
||||
LastUpdated: now,
|
||||
QuotaExceededClients: make(map[string]*time.Time),
|
||||
SuspendedClients: make(map[string]string),
|
||||
}
|
||||
if provider != "" {
|
||||
registration.Providers = map[string]int{provider: 1}
|
||||
@@ -172,6 +178,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
|
||||
// Remove quota tracking for this client
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
if registration.SuspendedClients != nil {
|
||||
delete(registration.SuspendedClients, clientID)
|
||||
}
|
||||
|
||||
if hasProvider && registration.Providers != nil {
|
||||
if count, ok := registration.Providers[provider]; ok {
|
||||
@@ -229,6 +238,60 @@ func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// SuspendClientModel marks a client's model as temporarily unavailable until explicitly resumed.
|
||||
// Parameters:
|
||||
// - clientID: The client to suspend
|
||||
// - modelID: The model affected by the suspension
|
||||
// - reason: Optional description for observability
|
||||
func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
||||
if clientID == "" || modelID == "" {
|
||||
return
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil {
|
||||
return
|
||||
}
|
||||
if registration.SuspendedClients == nil {
|
||||
registration.SuspendedClients = make(map[string]string)
|
||||
}
|
||||
if _, already := registration.SuspendedClients[clientID]; already {
|
||||
return
|
||||
}
|
||||
registration.SuspendedClients[clientID] = reason
|
||||
registration.LastUpdated = time.Now()
|
||||
if reason != "" {
|
||||
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
||||
} else {
|
||||
log.Debugf("Suspended client %s for model %s", clientID, modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// ResumeClientModel clears a previous suspension so the client counts toward availability again.
|
||||
// Parameters:
|
||||
// - clientID: The client to resume
|
||||
// - modelID: The model being resumed
|
||||
func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
||||
if clientID == "" || modelID == "" {
|
||||
return
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
registration, exists := r.models[modelID]
|
||||
if !exists || registration == nil || registration.SuspendedClients == nil {
|
||||
return
|
||||
}
|
||||
if _, ok := registration.SuspendedClients[clientID]; !ok {
|
||||
return
|
||||
}
|
||||
delete(registration.SuspendedClients, clientID)
|
||||
registration.LastUpdated = time.Now()
|
||||
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
||||
}
|
||||
|
||||
// GetAvailableModels returns all models that have at least one available client
|
||||
// Parameters:
|
||||
// - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
|
||||
@@ -255,7 +318,14 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
}
|
||||
}
|
||||
|
||||
effectiveClients := availableClients - expiredClients
|
||||
suspendedClients := 0
|
||||
if registration.SuspendedClients != nil {
|
||||
suspendedClients = len(registration.SuspendedClients)
|
||||
}
|
||||
effectiveClients := availableClients - expiredClients - suspendedClients
|
||||
if effectiveClients < 0 {
|
||||
effectiveClients = 0
|
||||
}
|
||||
|
||||
// Only include models that have available clients
|
||||
if effectiveClients > 0 {
|
||||
@@ -290,8 +360,15 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
|
||||
return registration.Count - expiredClients
|
||||
suspendedClients := 0
|
||||
if registration.SuspendedClients != nil {
|
||||
suspendedClients = len(registration.SuspendedClients)
|
||||
}
|
||||
result := registration.Count - expiredClients - suspendedClients
|
||||
if result < 0 {
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -316,11 +393,23 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
|
||||
count int
|
||||
}
|
||||
providers := make([]providerCount, 0, len(registration.Providers))
|
||||
suspendedByProvider := make(map[string]int)
|
||||
if registration.SuspendedClients != nil {
|
||||
for clientID := range registration.SuspendedClients {
|
||||
if provider, ok := r.clientProviders[clientID]; ok && provider != "" {
|
||||
suspendedByProvider[provider]++
|
||||
}
|
||||
}
|
||||
}
|
||||
for name, count := range registration.Providers {
|
||||
if count <= 0 {
|
||||
continue
|
||||
}
|
||||
providers = append(providers, providerCount{name: name, count: count})
|
||||
adjusted := count - suspendedByProvider[name]
|
||||
if adjusted <= 0 {
|
||||
continue
|
||||
}
|
||||
providers = append(providers, providerCount{name: name, count: adjusted})
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user