diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 888db60c..d4f84481 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -90,6 +90,9 @@ type ModelRegistry struct { models map[string]*ModelRegistration // clientModels maps client ID to the models it provides clientModels map[string][]string + // clientModelInfos maps client ID to a map of model ID -> ModelInfo + // This preserves the original model info provided by each client + clientModelInfos map[string]map[string]*ModelInfo // clientProviders maps client ID to its provider identifier clientProviders map[string]string // mutex ensures thread-safe access to the registry @@ -104,10 +107,11 @@ var registryOnce sync.Once func GetGlobalRegistry() *ModelRegistry { registryOnce.Do(func() { globalRegistry = &ModelRegistry{ - models: make(map[string]*ModelRegistration), - clientModels: make(map[string][]string), - clientProviders: make(map[string]string), - mutex: &sync.RWMutex{}, + models: make(map[string]*ModelRegistration), + clientModels: make(map[string][]string), + clientModelInfos: make(map[string]map[string]*ModelInfo), + clientProviders: make(map[string]string), + mutex: &sync.RWMutex{}, } }) return globalRegistry @@ -144,6 +148,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ // No models supplied; unregister existing client state if present. r.unregisterClientInternal(clientID) delete(r.clientModels, clientID) + delete(r.clientModelInfos, clientID) delete(r.clientProviders, clientID) misc.LogCredentialSeparator() return @@ -152,7 +157,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ now := time.Now() oldModels, hadExisting := r.clientModels[clientID] - oldProvider, _ := r.clientProviders[clientID] + oldProvider := r.clientProviders[clientID] providerChanged := oldProvider != provider if !hadExisting { // Pure addition path. @@ -161,6 +166,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ r.addModelRegistration(modelID, provider, model, now) } r.clientModels[clientID] = append([]string(nil), rawModelIDs...) + // Store client's own model infos + clientInfos := make(map[string]*ModelInfo, len(newModels)) + for id, m := range newModels { + clientInfos[id] = cloneModelInfo(m) + } + r.clientModelInfos[clientID] = clientInfos if provider != "" { r.clientProviders[clientID] = provider } else { @@ -287,6 +298,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ if len(rawModelIDs) > 0 { r.clientModels[clientID] = append([]string(nil), rawModelIDs...) } + // Update client's own model infos + clientInfos := make(map[string]*ModelInfo, len(newModels)) + for id, m := range newModels { + clientInfos[id] = cloneModelInfo(m) + } + r.clientModelInfos[clientID] = clientInfos if provider != "" { r.clientProviders[clientID] = provider } else { @@ -436,6 +453,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) { } delete(r.clientModels, clientID) + delete(r.clientModelInfos, clientID) if hasProvider { delete(r.clientProviders, clientID) } @@ -887,6 +905,9 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo { return nil } + // Try to use client-specific model infos first + clientInfos := r.clientModelInfos[clientID] + seen := make(map[string]struct{}) result := make([]*ModelInfo, 0, len(modelIDs)) for _, modelID := range modelIDs { @@ -894,6 +915,15 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo { continue } seen[modelID] = struct{}{} + + // Prefer client's own model info to preserve original type/owned_by + if clientInfos != nil { + if info, ok := clientInfos[modelID]; ok && info != nil { + result = append(result, info) + continue + } + } + // Fallback to global registry (for backwards compatibility) if reg, ok := r.models[modelID]; ok && reg.Info != nil { result = append(result, reg.Info) }