diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index cc679941..dba10d9d 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -105,10 +105,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ seen := make(map[string]struct{}) modelIDs := make([]string, 0, len(models)) newModels := make(map[string]*ModelInfo, len(models)) + newCounts := make(map[string]int, len(models)) for _, model := range models { if model == nil || model.ID == "" { continue } + newCounts[model.ID]++ if _, exists := seen[model.ID]; exists { continue } @@ -148,36 +150,45 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ return } - oldSet := make(map[string]struct{}, len(oldModels)) + oldCounts := make(map[string]int, len(oldModels)) for _, id := range oldModels { - oldSet[id] = struct{}{} + oldCounts[id]++ } added := make([]string, 0) - removed := make([]string, 0) for _, id := range modelIDs { - if _, exists := oldSet[id]; !exists { + if oldCounts[id] == 0 { added = append(added, id) } } - for _, id := range oldModels { - if _, exists := newModels[id]; !exists { + + removed := make([]string, 0) + for id := range oldCounts { + if newCounts[id] == 0 { removed = append(removed, id) } } // Handle provider change for overlapping models before modifications. if providerChanged && oldProvider != "" { - for _, id := range modelIDs { - if _, existed := oldSet[id]; !existed { + for id, newCount := range newCounts { + if newCount == 0 { continue } + oldCount := oldCounts[id] + if oldCount == 0 { + continue + } + toRemove := newCount + if oldCount < toRemove { + toRemove = oldCount + } if reg, ok := r.models[id]; ok && reg.Providers != nil { if count, okProv := reg.Providers[oldProvider]; okProv { - if count <= 1 { + if count <= toRemove { delete(reg.Providers, oldProvider) } else { - reg.Providers[oldProvider] = count - 1 + reg.Providers[oldProvider] = count - toRemove } } } @@ -186,13 +197,34 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ // Apply removals first to keep counters accurate. for _, id := range removed { - r.removeModelRegistration(clientID, id, oldProvider, now) + oldCount := oldCounts[id] + for i := 0; i < oldCount; i++ { + r.removeModelRegistration(clientID, id, oldProvider, now) + } + } + + for id, oldCount := range oldCounts { + newCount := newCounts[id] + if newCount == 0 || oldCount <= newCount { + continue + } + overage := oldCount - newCount + for i := 0; i < overage; i++ { + r.removeModelRegistration(clientID, id, oldProvider, now) + } } // Apply additions. - for _, id := range added { + for id, newCount := range newCounts { + oldCount := oldCounts[id] + if newCount <= oldCount { + continue + } model := newModels[id] - r.addModelRegistration(id, provider, model, now) + diff := newCount - oldCount + for i := 0; i < diff; i++ { + r.addModelRegistration(id, provider, model, now) + } } // Update metadata for models that remain associated with the client. @@ -209,10 +241,17 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ if _, newlyAdded := addedSet[id]; newlyAdded { continue } + overlapCount := newCounts[id] + if oldCount := oldCounts[id]; oldCount < overlapCount { + overlapCount = oldCount + } + if overlapCount <= 0 { + continue + } if reg.Providers == nil { reg.Providers = make(map[string]int) } - reg.Providers[provider]++ + reg.Providers[provider] += overlapCount } } }