fix(registry): Handle duplicate model IDs in client registration

The previous model registration logic used a set-like map to track the models associated with a client. This caused issues when a client registered multiple instances of the same model ID, as they were all treated as a single registration.

This commit refactors the registration logic to use count maps for both the old and new model lists. This allows the system to accurately track the number of instances for each model ID provided by a client.

The changes ensure that:
- When a client updates its model list, the exact number of added or removed instances for each model ID is correctly calculated.
- Provider counts are accurately incremented or decremented based on the number of model instances being added, removed, or having their provider changed.
- The registry correctly handles scenarios where a client reduces the number of duplicate model registrations (e.g., from `[A, A]` to `[A]`), properly deregistering the surplus instance.
This commit is contained in:
hkfires
2025-09-26 18:52:58 +08:00
parent 2717ba3e50
commit a887a337a5

View File

@@ -105,10 +105,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
seen := make(map[string]struct{})
modelIDs := make([]string, 0, len(models))
newModels := make(map[string]*ModelInfo, len(models))
newCounts := make(map[string]int, len(models))
for _, model := range models {
if model == nil || model.ID == "" {
continue
}
newCounts[model.ID]++
if _, exists := seen[model.ID]; exists {
continue
}
@@ -148,36 +150,45 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
return
}
oldSet := make(map[string]struct{}, len(oldModels))
oldCounts := make(map[string]int, len(oldModels))
for _, id := range oldModels {
oldSet[id] = struct{}{}
oldCounts[id]++
}
added := make([]string, 0)
removed := make([]string, 0)
for _, id := range modelIDs {
if _, exists := oldSet[id]; !exists {
if oldCounts[id] == 0 {
added = append(added, id)
}
}
for _, id := range oldModels {
if _, exists := newModels[id]; !exists {
removed := make([]string, 0)
for id := range oldCounts {
if newCounts[id] == 0 {
removed = append(removed, id)
}
}
// Handle provider change for overlapping models before modifications.
if providerChanged && oldProvider != "" {
for _, id := range modelIDs {
if _, existed := oldSet[id]; !existed {
for id, newCount := range newCounts {
if newCount == 0 {
continue
}
oldCount := oldCounts[id]
if oldCount == 0 {
continue
}
toRemove := newCount
if oldCount < toRemove {
toRemove = oldCount
}
if reg, ok := r.models[id]; ok && reg.Providers != nil {
if count, okProv := reg.Providers[oldProvider]; okProv {
if count <= 1 {
if count <= toRemove {
delete(reg.Providers, oldProvider)
} else {
reg.Providers[oldProvider] = count - 1
reg.Providers[oldProvider] = count - toRemove
}
}
}
@@ -186,13 +197,34 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
// Apply removals first to keep counters accurate.
for _, id := range removed {
r.removeModelRegistration(clientID, id, oldProvider, now)
oldCount := oldCounts[id]
for i := 0; i < oldCount; i++ {
r.removeModelRegistration(clientID, id, oldProvider, now)
}
}
for id, oldCount := range oldCounts {
newCount := newCounts[id]
if newCount == 0 || oldCount <= newCount {
continue
}
overage := oldCount - newCount
for i := 0; i < overage; i++ {
r.removeModelRegistration(clientID, id, oldProvider, now)
}
}
// Apply additions.
for _, id := range added {
for id, newCount := range newCounts {
oldCount := oldCounts[id]
if newCount <= oldCount {
continue
}
model := newModels[id]
r.addModelRegistration(id, provider, model, now)
diff := newCount - oldCount
for i := 0; i < diff; i++ {
r.addModelRegistration(id, provider, model, now)
}
}
// Update metadata for models that remain associated with the client.
@@ -209,10 +241,17 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
if _, newlyAdded := addedSet[id]; newlyAdded {
continue
}
overlapCount := newCounts[id]
if oldCount := oldCounts[id]; oldCount < overlapCount {
overlapCount = oldCount
}
if overlapCount <= 0 {
continue
}
if reg.Providers == nil {
reg.Providers = make(map[string]int)
}
reg.Providers[provider]++
reg.Providers[provider] += overlapCount
}
}
}