feat(auth): enable model suspension and resumption logic in AuthManager

- Added model suspension with reason tracking for 401 (unauthorized) and 402/403 (payment-related) errors.
- Implemented resumption logic upon model quota recovery or auth state changes.
- Enhanced registry to manage suspended clients, including counts and observability data.
- Updated availability computation to exclude suspended clients, ensuring accurate client model tracking.
This commit is contained in:
Luis Pater
2025-09-23 09:24:55 +08:00
parent ec08500924
commit e68a6037e2
2 changed files with 124 additions and 9 deletions

View File

@@ -58,6 +58,8 @@ type ModelRegistration struct {
QuotaExceededClients map[string]*time.Time
// Providers tracks available clients grouped by provider identifier
Providers map[string]int
// SuspendedClients tracks temporarily disabled clients keyed by client ID
SuspendedClients map[string]string
}
// ModelRegistry manages the global registry of available models
@@ -112,6 +114,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
// Model already exists, increment count
existing.Count++
existing.LastUpdated = now
if existing.SuspendedClients == nil {
existing.SuspendedClients = make(map[string]string)
}
if provider != "" {
if existing.Providers == nil {
existing.Providers = make(map[string]int)
@@ -126,6 +131,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
Count: 1,
LastUpdated: now,
QuotaExceededClients: make(map[string]*time.Time),
SuspendedClients: make(map[string]string),
}
if provider != "" {
registration.Providers = map[string]int{provider: 1}
@@ -172,6 +178,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
// Remove quota tracking for this client
delete(registration.QuotaExceededClients, clientID)
if registration.SuspendedClients != nil {
delete(registration.SuspendedClients, clientID)
}
if hasProvider && registration.Providers != nil {
if count, ok := registration.Providers[provider]; ok {
@@ -229,6 +238,60 @@ func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
}
}
// SuspendClientModel marks a client's model as temporarily unavailable until explicitly resumed.
// Parameters:
// - clientID: The client to suspend
// - modelID: The model affected by the suspension
// - reason: Optional description for observability
func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
if clientID == "" || modelID == "" {
return
}
r.mutex.Lock()
defer r.mutex.Unlock()
registration, exists := r.models[modelID]
if !exists || registration == nil {
return
}
if registration.SuspendedClients == nil {
registration.SuspendedClients = make(map[string]string)
}
if _, already := registration.SuspendedClients[clientID]; already {
return
}
registration.SuspendedClients[clientID] = reason
registration.LastUpdated = time.Now()
if reason != "" {
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
} else {
log.Debugf("Suspended client %s for model %s", clientID, modelID)
}
}
// ResumeClientModel clears a previous suspension so the client counts toward availability again.
// Parameters:
// - clientID: The client to resume
// - modelID: The model being resumed
func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
if clientID == "" || modelID == "" {
return
}
r.mutex.Lock()
defer r.mutex.Unlock()
registration, exists := r.models[modelID]
if !exists || registration == nil || registration.SuspendedClients == nil {
return
}
if _, ok := registration.SuspendedClients[clientID]; !ok {
return
}
delete(registration.SuspendedClients, clientID)
registration.LastUpdated = time.Now()
log.Debugf("Resumed client %s for model %s", clientID, modelID)
}
// GetAvailableModels returns all models that have at least one available client
// Parameters:
// - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
@@ -255,7 +318,14 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
}
}
effectiveClients := availableClients - expiredClients
suspendedClients := 0
if registration.SuspendedClients != nil {
suspendedClients = len(registration.SuspendedClients)
}
effectiveClients := availableClients - expiredClients - suspendedClients
if effectiveClients < 0 {
effectiveClients = 0
}
// Only include models that have available clients
if effectiveClients > 0 {
@@ -290,8 +360,15 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
expiredClients++
}
}
return registration.Count - expiredClients
suspendedClients := 0
if registration.SuspendedClients != nil {
suspendedClients = len(registration.SuspendedClients)
}
result := registration.Count - expiredClients - suspendedClients
if result < 0 {
return 0
}
return result
}
return 0
}
@@ -316,11 +393,23 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
count int
}
providers := make([]providerCount, 0, len(registration.Providers))
suspendedByProvider := make(map[string]int)
if registration.SuspendedClients != nil {
for clientID := range registration.SuspendedClients {
if provider, ok := r.clientProviders[clientID]; ok && provider != "" {
suspendedByProvider[provider]++
}
}
}
for name, count := range registration.Providers {
if count <= 0 {
continue
}
providers = append(providers, providerCount{name: name, count: count})
adjusted := count - suspendedByProvider[name]
if adjusted <= 0 {
continue
}
providers = append(providers, providerCount{name: name, count: adjusted})
}
if len(providers) == 0 {
return nil

View File

@@ -486,6 +486,9 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
return
}
// Update in-memory auth status based on result.
shouldResumeModel := false
shouldSuspendModel := false
suspendReason := ""
m.mu.Lock()
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
now := time.Now()
@@ -501,6 +504,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
auth.UpdatedAt = now
if result.Model != "" {
registry.GetGlobalRegistry().ClearModelQuotaExceeded(auth.ID, result.Model)
shouldResumeModel = true
}
} else {
// Default transient error state.
@@ -511,7 +515,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
auth.LastError = &Error{Code: result.Error.Code, Message: result.Error.Message, Retryable: result.Error.Retryable}
}
// If the error carries a status code, adjust backoff/quota accordingly.
// 401 -> auth issue; 402/429 -> quota; 5xx -> transient.
// 401 -> auth issue; 402 -> billing; 403 -> forbidden; 429 -> quota; 5xx -> transient.
var statusCode int
if se, isOk := any(result.Error).(interface{ StatusCode() int }); isOk && se != nil {
statusCode = se.StatusCode()
@@ -519,19 +523,35 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
switch statusCode {
case 401:
auth.StatusMessage = "unauthorized"
auth.NextRefreshAfter = now.Add(5 * time.Minute)
case 402, 429:
auth.NextRefreshAfter = now.Add(30 * time.Minute)
if result.Model != "" {
shouldSuspendModel = true
suspendReason = "unauthorized"
}
case 402, 403:
auth.StatusMessage = "payment_required"
auth.NextRefreshAfter = now.Add(30 * time.Minute)
if result.Model != "" {
shouldSuspendModel = true
suspendReason = "payment_required"
}
case 429:
auth.StatusMessage = "quota exhausted"
auth.Quota.Exceeded = true
auth.Quota.Reason = "quota"
auth.Quota.NextRecoverAt = now.Add(10 * time.Minute)
auth.Quota.NextRecoverAt = now.Add(30 * time.Minute)
auth.NextRefreshAfter = auth.Quota.NextRecoverAt
if result.Model != "" {
shouldSuspendModel = true
registry.GetGlobalRegistry().SetModelQuotaExceeded(auth.ID, result.Model)
}
case 403, 408, 500, 502, 503, 504:
case 408, 500, 502, 503, 504:
auth.StatusMessage = "transient upstream error"
auth.NextRefreshAfter = now.Add(1 * time.Minute)
if result.Model != "" {
shouldSuspendModel = false
suspendReason = "forbidden"
}
default:
// keep generic
if auth.StatusMessage == "" {
@@ -544,6 +564,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
}
m.mu.Unlock()
if shouldResumeModel {
registry.GetGlobalRegistry().ResumeClientModel(result.AuthID, result.Model)
} else if shouldSuspendModel {
registry.GetGlobalRegistry().SuspendClientModel(result.AuthID, result.Model, suspendReason)
}
m.hook.OnResult(ctx, result)
}