Files
CLIProxyAPI/internal/registry/model_registry.go
hkfires a887a337a5 fix(registry): Handle duplicate model IDs in client registration
The previous model registration logic used a set-like map to track the models associated with a client. This caused issues when a client registered multiple instances of the same model ID, as they were all treated as a single registration.

This commit refactors the registration logic to use count maps for both the old and new model lists. This allows the system to accurately track the number of instances for each model ID provided by a client.

The changes ensure that:
- When a client updates its model list, the exact number of added or removed instances for each model ID is correctly calculated.
- Provider counts are accurately incremented or decremented based on the number of model instances being added, removed, or having their provider changed.
- The registry correctly handles scenarios where a client reduces the number of duplicate model registrations (e.g., from `[A, A]` to `[A]`), properly deregistering the surplus instance.
2025-09-26 18:52:58 +08:00

758 lines
22 KiB
Go

// Package registry provides centralized model management for all AI service providers.
// It implements a dynamic model registry with reference counting to track active clients
// and automatically hide models when no clients are available or when quota is exceeded.
package registry
import (
"sort"
"strings"
"sync"
"time"
misc "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// ModelInfo describes a model advertised by a client. The registry keeps its
// own copies (made by cloneModelInfo) and renders them per API flavour in
// convertModelToMap.
//
// Field order is observable: encoding/json serializes struct fields in
// declaration order, so reordering fields changes the emitted key order. The
// JSON tags mix snake_case and camelCase; the camelCase ones (inputTokenLimit,
// outputTokenLimit, supportedGenerationMethods) match the keys emitted by the
// "gemini" branch of convertModelToMap.
type ModelInfo struct {
// ID is the unique identifier for the model; RegisterClient uses it as the key
// of the registry's models map.
ID string `json:"id"`
// Object is the object type for the model (typically "model"). Note that
// convertModelToMap emits the literal "model" instead of reading this field.
Object string `json:"object"`
// Created is the model's creation timestamp (presumably Unix seconds — TODO
// confirm); OpenAI/Claude output only includes it when it is positive.
Created int64 `json:"created"`
// OwnedBy indicates the organization that owns the model
OwnedBy string `json:"owned_by"`
// Type indicates the model type (e.g., "claude", "gemini", "openai")
Type string `json:"type"`
// DisplayName is the human-readable name for the model
DisplayName string `json:"display_name,omitempty"`
// Name is the Gemini-style model name; Gemini output falls back to ID when it
// is empty.
Name string `json:"name,omitempty"`
// Version is the model version
Version string `json:"version,omitempty"`
// Description provides detailed information about the model
Description string `json:"description,omitempty"`
// InputTokenLimit is the maximum input token limit (emitted in Gemini output)
InputTokenLimit int `json:"inputTokenLimit,omitempty"`
// OutputTokenLimit is the maximum output token limit (emitted in Gemini output)
OutputTokenLimit int `json:"outputTokenLimit,omitempty"`
// SupportedGenerationMethods lists supported generation methods (Gemini output)
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
// ContextLength is the context window size (emitted in OpenAI output)
ContextLength int `json:"context_length,omitempty"`
// MaxCompletionTokens is the maximum completion tokens (emitted in OpenAI output)
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// SupportedParameters lists supported parameters (emitted in OpenAI output)
SupportedParameters []string `json:"supported_parameters,omitempty"`
}
// ModelRegistration tracks a model's availability: how many registrations
// currently provide it, which providers they come from, and which clients are
// temporarily unable to serve it (quota exceeded or suspended).
type ModelRegistration struct {
// Info contains the model metadata (a private copy made by cloneModelInfo)
Info *ModelInfo
// Count is the number of registrations currently providing this model. It is
// raised by addModelRegistration and lowered on removal; the registration is
// deleted from the registry once it reaches zero. Availability queries
// subtract recently quota-limited and suspended clients from it.
Count int
// LastUpdated tracks when this registration was last modified
LastUpdated time.Time
// QuotaExceededClients maps client ID to the time that client exceeded quota
// for this model. Entries younger than five minutes reduce availability;
// CleanupExpiredQuotas purges older ones.
QuotaExceededClients map[string]*time.Time
// Providers maps a lower-cased provider identifier to the number of
// registrations from that provider; GetModelProviders ranks providers by it.
Providers map[string]int
// SuspendedClients maps client ID to the (possibly empty) suspension reason
// recorded by SuspendClientModel; each entry reduces availability by one until
// ResumeClientModel clears it.
SuspendedClients map[string]string
}
// ModelRegistry manages the global registry of available models.
//
// All maps are guarded by mutex. Exported methods acquire it themselves, while
// the lower-case helpers (addModelRegistration, removeModelRegistration,
// unregisterClientInternal) do no locking and expect the caller to hold it.
type ModelRegistry struct {
// models maps model ID to registration information
models map[string]*ModelRegistration
// clientModels maps client ID to the model IDs it registered
clientModels map[string][]string
// clientProviders maps client ID to its lower-cased provider identifier;
// clients registered without a provider have no entry
clientProviders map[string]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
}
// Process-wide registry singleton, created lazily by GetGlobalRegistry.
var (
	globalRegistry *ModelRegistry
	registryOnce   sync.Once
)

// GetGlobalRegistry returns the process-wide ModelRegistry, allocating it with
// empty maps on first use. Initialization runs under sync.Once, so concurrent
// first callers all observe the same instance.
func GetGlobalRegistry() *ModelRegistry {
	registryOnce.Do(func() {
		reg := new(ModelRegistry)
		reg.models = map[string]*ModelRegistration{}
		reg.clientModels = map[string][]string{}
		reg.clientProviders = map[string]string{}
		reg.mutex = new(sync.RWMutex)
		globalRegistry = reg
	})
	return globalRegistry
}
// RegisterClient registers a client and its supported models, reconciling the
// new list against whatever the same client registered before.
//
// Registrations are reference-counted per model instance. An ID that appears
// several times in models takes one registration per occurrence, and the full
// list (duplicates included) is stored in clientModels. Storing the raw list
// keeps later calls balanced:
//   - a subsequent RegisterClient computes the exact surplus or deficit per ID
//     (e.g. [A, A] -> [A] releases one registration);
//   - UnregisterClient releases one registration per stored occurrence.
//
// nil entries and entries with an empty ID are ignored. If no usable model
// remains, the client is unregistered.
//
// Parameters:
//   - clientID: Unique identifier for the client
//   - clientProvider: Provider name (e.g., "gemini", "claude", "openai"); it is
//     lower-cased before being recorded
//   - models: List of models that this client can provide
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	provider := strings.ToLower(clientProvider)

	// rawModelIDs keeps every usable instance in input order (duplicates
	// included). uniqueModelIDs keeps each ID once, at its first occurrence.
	// newModels maps an ID to the metadata of that first occurrence.
	rawModelIDs := make([]string, 0, len(models))
	uniqueModelIDs := make([]string, 0, len(models))
	newModels := make(map[string]*ModelInfo, len(models))
	newCounts := make(map[string]int, len(models))
	for _, model := range models {
		if model == nil || model.ID == "" {
			continue
		}
		rawModelIDs = append(rawModelIDs, model.ID)
		newCounts[model.ID]++
		if _, seen := newModels[model.ID]; seen {
			continue
		}
		newModels[model.ID] = model
		uniqueModelIDs = append(uniqueModelIDs, model.ID)
	}

	if len(uniqueModelIDs) == 0 {
		// No models supplied; unregister existing client state if present.
		r.unregisterClientInternal(clientID)
		delete(r.clientModels, clientID)
		delete(r.clientProviders, clientID)
		misc.LogCredentialSeparator()
		return
	}

	now := time.Now()
	oldModels, hadExisting := r.clientModels[clientID]
	oldProvider := r.clientProviders[clientID]
	providerChanged := oldProvider != provider

	if !hadExisting {
		// Pure addition path: one registration per listed instance.
		for _, modelID := range rawModelIDs {
			r.addModelRegistration(modelID, provider, newModels[modelID], now)
		}
		r.clientModels[clientID] = rawModelIDs
		if provider != "" {
			r.clientProviders[clientID] = provider
		} else {
			delete(r.clientProviders, clientID)
		}
		log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(uniqueModelIDs))
		misc.LogCredentialSeparator()
		return
	}

	oldCounts := make(map[string]int, len(oldModels))
	for _, id := range oldModels {
		oldCounts[id]++
	}

	// keptCount reports how many of id's instances survive the update, i.e.
	// min(old occurrences, new occurrences).
	keptCount := func(id string) int {
		kept := newCounts[id]
		if oldCount := oldCounts[id]; oldCount < kept {
			kept = oldCount
		}
		return kept
	}

	// added/removed list IDs that appear or disappear entirely; they only feed
	// the summary log below.
	var added, removed []string
	for _, id := range uniqueModelIDs {
		if oldCounts[id] == 0 {
			added = append(added, id)
		}
	}
	for id := range oldCounts {
		if newCounts[id] == 0 {
			removed = append(removed, id)
		}
	}

	// On a provider change, kept instances move from the old provider's tally
	// to the new one. This step releases them from the old provider; the
	// metadata pass below credits the new provider. Surplus and extra instances
	// are attributed by removeModelRegistration/addModelRegistration instead.
	if providerChanged && oldProvider != "" {
		for id := range newCounts {
			kept := keptCount(id)
			if kept <= 0 {
				continue
			}
			reg, ok := r.models[id]
			if !ok || reg.Providers == nil {
				continue
			}
			if count, has := reg.Providers[oldProvider]; has {
				if count <= kept {
					delete(reg.Providers, oldProvider)
				} else {
					reg.Providers[oldProvider] = count - kept
				}
			}
		}
	}

	// Release surplus instances first (IDs dropped entirely, or listed fewer
	// times than before) so counts stay accurate.
	// NOTE(review): removeModelRegistration also clears this client's quota and
	// suspension markers for the model, even when the client still lists it
	// (e.g. [A, A] -> [A]).
	for id, oldCount := range oldCounts {
		for surplus := oldCount - newCounts[id]; surplus > 0; surplus-- {
			r.removeModelRegistration(clientID, id, oldProvider, now)
		}
	}

	// Take registrations for instances that are new or listed more often.
	for id, newCount := range newCounts {
		for extra := newCount - oldCounts[id]; extra > 0; extra-- {
			r.addModelRegistration(id, provider, newModels[id], now)
		}
	}

	// Refresh metadata for every model the client still provides and, on a
	// provider change, credit the new provider with the kept instances.
	for _, id := range uniqueModelIDs {
		reg, ok := r.models[id]
		if !ok {
			continue
		}
		reg.Info = cloneModelInfo(newModels[id])
		reg.LastUpdated = now
		if !providerChanged || provider == "" {
			continue
		}
		kept := keptCount(id)
		if kept <= 0 {
			continue
		}
		if reg.Providers == nil {
			reg.Providers = make(map[string]int)
		}
		reg.Providers[provider] += kept
	}

	// Update client bookkeeping with the raw list so counts stay balanced.
	r.clientModels[clientID] = rawModelIDs
	if provider != "" {
		r.clientProviders[clientID] = provider
	} else {
		delete(r.clientProviders, clientID)
	}

	if len(added) == 0 && len(removed) == 0 && !providerChanged {
		// Only metadata (e.g., display name) or instance multiplicity changed.
		misc.LogCredentialSeparator()
		return
	}
	log.Debugf("Reconciled client %s (provider %s) models: +%d, -%d", clientID, provider, len(added), len(removed))
	misc.LogCredentialSeparator()
}
// addModelRegistration records one more registration of modelID on behalf of
// provider (skipped when provider is empty).
//
// On first sight of modelID it creates a fresh ModelRegistration. Otherwise it
// bumps the existing count and replaces the stored metadata with a copy of
// model. A nil model or empty modelID is ignored. It does not lock r.mutex
// itself; RegisterClient calls it with the lock held.
func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *ModelInfo, now time.Time) {
	if modelID == "" || model == nil {
		return
	}

	existing, found := r.models[modelID]
	if !found {
		fresh := &ModelRegistration{
			Info:                 cloneModelInfo(model),
			Count:                1,
			LastUpdated:          now,
			QuotaExceededClients: make(map[string]*time.Time),
			SuspendedClients:     make(map[string]string),
		}
		if provider != "" {
			fresh.Providers = map[string]int{provider: 1}
		}
		r.models[modelID] = fresh
		log.Debugf("Registered new model %s from provider %s", modelID, provider)
		return
	}

	existing.Count++
	existing.LastUpdated = now
	existing.Info = cloneModelInfo(model)
	if existing.SuspendedClients == nil {
		existing.SuspendedClients = make(map[string]string)
	}
	if provider != "" {
		if existing.Providers == nil {
			existing.Providers = make(map[string]int)
		}
		existing.Providers[provider]++
	}
	log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
}
// removeModelRegistration releases one registration of modelID held by
// clientID under provider.
//
// It also drops the client's quota-exceeded and suspension markers for the
// model and decrements the provider tally (skipped when provider is empty).
// The model is deleted once its count reaches zero, and unknown model IDs are
// ignored. It does not lock r.mutex itself.
func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider string, now time.Time) {
	reg, found := r.models[modelID]
	if !found {
		return
	}

	reg.Count--
	reg.LastUpdated = now
	// delete is a no-op on a nil map, so no nil checks are needed here.
	delete(reg.QuotaExceededClients, clientID)
	delete(reg.SuspendedClients, clientID)
	if reg.Count < 0 {
		reg.Count = 0
	}

	if n, ok := reg.Providers[provider]; ok && provider != "" {
		if n > 1 {
			reg.Providers[provider] = n - 1
		} else {
			delete(reg.Providers, provider)
		}
	}

	log.Debugf("Decremented count for model %s, now %d clients", modelID, reg.Count)
	if reg.Count > 0 {
		return
	}
	delete(r.models, modelID)
	log.Debugf("Removed model %s as no clients remain", modelID)
}
// cloneModelInfo returns a copy of model whose non-empty slice fields are
// backed by fresh arrays. This ensures the registry's stored metadata cannot
// be mutated through a caller's ModelInfo, and vice versa. It returns nil for
// a nil model.
func cloneModelInfo(model *ModelInfo) *ModelInfo {
	if model == nil {
		return nil
	}
	// Named "clone" rather than "copy" so the builtin copy is not shadowed.
	clone := *model
	if len(model.SupportedGenerationMethods) > 0 {
		clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
	}
	if len(model.SupportedParameters) > 0 {
		clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
	}
	return &clone
}
// UnregisterClient removes a client and releases every model registration it
// holds. It takes the registry's write lock for the duration of the call.
//
// Parameters:
//   - clientID: Unique identifier for the client to remove
func (r *ModelRegistry) UnregisterClient(clientID string) {
	mu := r.mutex
	mu.Lock()
	defer mu.Unlock()
	r.unregisterClientInternal(clientID)
}
// unregisterClientInternal releases every registration recorded for clientID
// and forgets the client's model list and provider. The caller must already
// hold r.mutex for writing; this function does no locking.
//
// Each recorded model ID is released once per occurrence through
// removeModelRegistration, the same helper RegisterClient uses. Sharing the
// helper means both paths adjust counts, quota/suspension markers and provider
// tallies identically instead of maintaining two copies of that logic.
func (r *ModelRegistry) unregisterClientInternal(clientID string) {
	models, exists := r.clientModels[clientID]
	// A client registered without a provider has no entry; "" makes
	// removeModelRegistration skip the provider tally.
	provider := r.clientProviders[clientID]
	if !exists {
		// Nothing registered; just drop any stale provider mapping.
		delete(r.clientProviders, clientID)
		return
	}

	now := time.Now()
	for _, modelID := range models {
		r.removeModelRegistration(clientID, modelID, provider, now)
	}

	delete(r.clientModels, clientID)
	delete(r.clientProviders, clientID)
	log.Debugf("Unregistered client %s", clientID)
	// Separator line after completing client unregistration (after the summary line)
	misc.LogCredentialSeparator()
}
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client,
// stamping the current time. The mark reduces the model's availability for
// five minutes, after which CleanupExpiredQuotas can purge it. Unknown models
// are ignored.
//
// Parameters:
//   - clientID: The client that exceeded quota
//   - modelID: The model that exceeded quota
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	registration, exists := r.models[modelID]
	if !exists || registration == nil {
		return
	}
	// Lazily allocate, as SuspendClientModel does for its map: writing into a
	// nil map would panic.
	if registration.QuotaExceededClients == nil {
		registration.QuotaExceededClients = make(map[string]*time.Time)
	}
	now := time.Now()
	registration.QuotaExceededClients[clientID] = &now
	log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
// ClearModelQuotaExceeded lifts the quota-exceeded mark that clientID holds
// on modelID, if any. Unknown models are ignored.
//
// Parameters:
//   - clientID: The client to clear quota status for
//   - modelID: The model to clear quota status for
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	registration, exists := r.models[modelID]
	if !exists {
		return
	}
	// delete is a no-op when the map is nil or holds no entry for the client.
	delete(registration.QuotaExceededClients, clientID)
}
// SuspendClientModel marks a client's model as temporarily unavailable until
// ResumeClientModel is called. Empty IDs, unknown models and clients that are
// already suspended are ignored; an existing reason is never overwritten.
//
// Parameters:
//   - clientID: The client to suspend
//   - modelID: The model affected by the suspension
//   - reason: Optional description for observability
func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
	if clientID == "" || modelID == "" {
		return
	}
	r.mutex.Lock()
	defer r.mutex.Unlock()

	reg := r.models[modelID]
	if reg == nil {
		return
	}
	if _, already := reg.SuspendedClients[clientID]; already {
		return
	}
	if reg.SuspendedClients == nil {
		reg.SuspendedClients = make(map[string]string)
	}
	reg.SuspendedClients[clientID] = reason
	reg.LastUpdated = time.Now()

	if reason == "" {
		log.Debugf("Suspended client %s for model %s", clientID, modelID)
		return
	}
	log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
}
// ResumeClientModel clears a previous suspension so the client counts toward
// the model's availability again. It is a no-op for empty IDs, unknown models,
// or clients that are not suspended.
//
// Parameters:
//   - clientID: The client to resume
//   - modelID: The model being resumed
func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
	if clientID == "" || modelID == "" {
		return
	}
	r.mutex.Lock()
	defer r.mutex.Unlock()

	reg := r.models[modelID]
	if reg == nil {
		return
	}
	// Looking up a nil map yields false, so this also covers a nil map.
	if _, suspended := reg.SuspendedClients[clientID]; !suspended {
		return
	}
	delete(reg.SuspendedClients, clientID)
	reg.LastUpdated = time.Now()
	log.Debugf("Resumed client %s for model %s", clientID, modelID)
}
// GetAvailableModels returns every model that currently has at least one
// usable client, rendered for handlerType by convertModelToMap.
//
// A model's usable clients are its registration count, minus clients whose
// quota was exceeded less than five minutes ago, minus suspended clients. The
// result order follows map iteration and is therefore not stable.
//
// Parameters:
//   - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
//
// Returns:
//   - []map[string]any: List of available models in the requested format (empty, never nil)
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	const quotaExpiredDuration = 5 * time.Minute
	// Judge every model against the same instant; time.Now is loop-invariant.
	now := time.Now()

	models := make([]map[string]any, 0, len(r.models))
	for _, registration := range r.models {
		// Clients that exceeded quota recently and have not recovered yet.
		quotaLimited := 0
		for _, quotaTime := range registration.QuotaExceededClients {
			if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
				quotaLimited++
			}
		}
		// len of a nil map is 0, so SuspendedClients needs no nil check.
		effectiveClients := registration.Count - quotaLimited - len(registration.SuspendedClients)
		if effectiveClients <= 0 {
			continue
		}
		if model := r.convertModelToMap(registration.Info, handlerType); model != nil {
			models = append(models, model)
		}
	}
	return models
}
// GetModelCount returns the number of available clients for a specific model.
// This is the registration count, less clients that exceeded quota within the
// last five minutes, less suspended clients, floored at zero. Unknown models
// report zero.
//
// Parameters:
//   - modelID: The model ID to check
//
// Returns:
//   - int: Number of available clients for the model
func (r *ModelRegistry) GetModelCount(modelID string) int {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	registration, exists := r.models[modelID]
	if !exists {
		return 0
	}

	const quotaWindow = 5 * time.Minute
	now := time.Now()
	available := registration.Count - len(registration.SuspendedClients)
	for _, hitAt := range registration.QuotaExceededClients {
		if hitAt != nil && now.Sub(*hitAt) < quotaWindow {
			available--
		}
	}
	if available < 0 {
		return 0
	}
	return available
}
// GetModelProviders returns provider identifiers that currently supply the
// given model, ordered by registration count (descending) with ties broken by
// name. The counts are the raw per-provider tallies; suspended clients are not
// subtracted. It returns nil when the model is unknown or has no provider with
// a positive count.
//
// Parameters:
//   - modelID: The model ID to check
//
// Returns:
//   - []string: Provider identifiers ordered by availability count (descending)
func (r *ModelRegistry) GetModelProviders(modelID string) []string {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	registration := r.models[modelID]
	if registration == nil || len(registration.Providers) == 0 {
		return nil
	}

	type ranked struct {
		name  string
		count int
	}
	candidates := make([]ranked, 0, len(registration.Providers))
	for name, count := range registration.Providers {
		if count > 0 {
			candidates = append(candidates, ranked{name: name, count: count})
		}
	}
	if len(candidates) == 0 {
		return nil
	}

	// Highest count first; names are unique map keys, so the order is total.
	sort.Slice(candidates, func(i, j int) bool {
		a, b := candidates[i], candidates[j]
		if a.count != b.count {
			return a.count > b.count
		}
		return a.name < b.name
	})

	names := make([]string, len(candidates))
	for i, c := range candidates {
		names[i] = c.name
	}
	return names
}
// convertModelToMap renders model in the listing format expected by
// handlerType:
//   - "openai" and "claude" share id/object/owned_by plus optional created,
//     type and display_name; "openai" additionally carries version,
//     description, context_length, max_completion_tokens and
//     supported_parameters.
//   - "gemini" uses camelCase keys, with name falling back to ID.
//   - Anything else gets a minimal generic shape.
//
// Optional fields are only set when non-zero/non-empty. A nil model yields nil.
func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) map[string]any {
	if model == nil {
		return nil
	}

	switch handlerType {
	case "openai", "claude":
		out := map[string]any{
			"id":       model.ID,
			"object":   "model",
			"owned_by": model.OwnedBy,
		}
		if model.Created > 0 {
			out["created"] = model.Created
		}
		if model.Type != "" {
			out["type"] = model.Type
		}
		if model.DisplayName != "" {
			out["display_name"] = model.DisplayName
		}
		if handlerType == "claude" {
			return out
		}
		// OpenAI listings additionally expose version, description and limits.
		if model.Version != "" {
			out["version"] = model.Version
		}
		if model.Description != "" {
			out["description"] = model.Description
		}
		if model.ContextLength > 0 {
			out["context_length"] = model.ContextLength
		}
		if model.MaxCompletionTokens > 0 {
			out["max_completion_tokens"] = model.MaxCompletionTokens
		}
		if len(model.SupportedParameters) > 0 {
			out["supported_parameters"] = model.SupportedParameters
		}
		return out

	case "gemini":
		name := model.Name
		if name == "" {
			name = model.ID
		}
		out := map[string]any{"name": name}
		if model.Version != "" {
			out["version"] = model.Version
		}
		if model.DisplayName != "" {
			out["displayName"] = model.DisplayName
		}
		if model.Description != "" {
			out["description"] = model.Description
		}
		if model.InputTokenLimit > 0 {
			out["inputTokenLimit"] = model.InputTokenLimit
		}
		if model.OutputTokenLimit > 0 {
			out["outputTokenLimit"] = model.OutputTokenLimit
		}
		if len(model.SupportedGenerationMethods) > 0 {
			out["supportedGenerationMethods"] = model.SupportedGenerationMethods
		}
		return out

	default:
		// Generic format
		out := map[string]any{"id": model.ID, "object": "model"}
		if model.OwnedBy != "" {
			out["owned_by"] = model.OwnedBy
		}
		if model.Type != "" {
			out["type"] = model.Type
		}
		return out
	}
}
// CleanupExpiredQuotas purges quota-exceeded markers, across all models, that
// are at least five minutes old. Markers with a nil timestamp are left in
// place.
func (r *ModelRegistry) CleanupExpiredQuotas() {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	const quotaWindow = 5 * time.Minute
	now := time.Now()
	for modelID, reg := range r.models {
		// Deleting entries while ranging over the same map is permitted in Go.
		for clientID, hitAt := range reg.QuotaExceededClients {
			if hitAt == nil || now.Sub(*hitAt) < quotaWindow {
				continue
			}
			delete(reg.QuotaExceededClients, clientID)
			log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
		}
	}
}