diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index d4f84481..a2c154ca 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -625,6 +625,131 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any return models } +// GetAvailableModelsByProvider returns models available for the given provider identifier. +// Parameters: +// - provider: Provider identifier (e.g., "codex", "gemini", "antigravity") +// +// Returns: +// - []*ModelInfo: List of available models for the provider +func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" { + return nil + } + + r.mutex.RLock() + defer r.mutex.RUnlock() + + type providerModel struct { + count int + info *ModelInfo + } + + providerModels := make(map[string]*providerModel) + + for clientID, clientProvider := range r.clientProviders { + if clientProvider != provider { + continue + } + modelIDs := r.clientModels[clientID] + if len(modelIDs) == 0 { + continue + } + clientInfos := r.clientModelInfos[clientID] + for _, modelID := range modelIDs { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + continue + } + entry := providerModels[modelID] + if entry == nil { + entry = &providerModel{} + providerModels[modelID] = entry + } + entry.count++ + if entry.info == nil { + if clientInfos != nil { + if info := clientInfos[modelID]; info != nil { + entry.info = info + } + } + if entry.info == nil { + if reg, ok := r.models[modelID]; ok && reg != nil && reg.Info != nil { + entry.info = reg.Info + } + } + } + } + } + + if len(providerModels) == 0 { + return nil + } + + quotaExpiredDuration := 5 * time.Minute + now := time.Now() + result := make([]*ModelInfo, 0, len(providerModels)) + + for modelID, entry := range providerModels { + if entry == nil || entry.count <= 0 { + continue + } + registration, ok := r.models[modelID] + + expiredClients := 0 + cooldownSuspended := 0 + otherSuspended := 0 + if ok && registration != nil { + if registration.QuotaExceededClients != nil { + for clientID, quotaTime := range registration.QuotaExceededClients { + if clientID == "" { + continue + } + if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider { + continue + } + if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { + expiredClients++ + } + } + } + if registration.SuspendedClients != nil { + for clientID, reason := range registration.SuspendedClients { + if clientID == "" { + continue + } + if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider { + continue + } + if strings.EqualFold(reason, "quota") { + cooldownSuspended++ + continue + } + otherSuspended++ + } + } + } + + availableClients := entry.count + effectiveClients := availableClients - expiredClients - otherSuspended + if effectiveClients < 0 { + effectiveClients = 0 + } + + if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) { + if entry.info != nil { + result = append(result, entry.info) + continue + } + if ok && registration != nil && registration.Info != nil { + result = append(result, registration.Info) + } + } + } + + return result +} + // GetModelCount returns the number of available clients for a specific model // Parameters: // - modelID: The model ID to check diff --git a/sdk/cliproxy/model_registry.go b/sdk/cliproxy/model_registry.go index 44ef8d7d..3cd57842 100644 --- a/sdk/cliproxy/model_registry.go +++ b/sdk/cliproxy/model_registry.go @@ -13,6 +13,7 @@ type ModelRegistry interface { ClearModelQuotaExceeded(clientID, modelID string) ClientSupportsModel(clientID, modelID string) bool GetAvailableModels(handlerType string) []map[string]any + GetAvailableModelsByProvider(provider string) []*ModelInfo } // GlobalModelRegistry returns the shared registry instance.