diff --git a/config.example.yaml b/config.example.yaml index 8457e103..02e085b1 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -62,6 +62,8 @@ ws-auth: false # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" +# model-blacklist: +# - "gemini-2.0-pro-exp" # exclude specific models from this provider # - api-key: "AIzaSy...02" # API keys for official Generative Language API (legacy compatibility) @@ -76,6 +78,8 @@ ws-auth: false # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override +# model-blacklist: +# - "gpt-5" # exclude specific models from this provider # Claude API keys #claude-api-key: @@ -88,6 +92,8 @@ ws-auth: false # models: # - name: "claude-3-5-sonnet-20241022" # upstream model name # alias: "claude-sonnet-latest" # client alias mapped to the upstream model +# model-blacklist: +# - "claude-3-5-sonnet-20241022" # exclude specific models from this provider # OpenAI compatibility providers #openai-compatibility: diff --git a/internal/config/config.go b/internal/config/config.go index 31920075..0e1e3966 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -157,6 +157,9 @@ type ClaudeKey struct { // Headers optionally adds extra HTTP headers for requests sent with this key. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // ModelBlacklist lists model IDs that should be excluded for this provider. + ModelBlacklist []string `yaml:"model-blacklist,omitempty" json:"model-blacklist,omitempty"` } // ClaudeModel describes a mapping between an alias and the actual upstream model name. @@ -183,6 +186,9 @@ type CodexKey struct { // Headers optionally adds extra HTTP headers for requests sent with this key. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // ModelBlacklist lists model IDs that should be excluded for this provider. + ModelBlacklist []string `yaml:"model-blacklist,omitempty" json:"model-blacklist,omitempty"` } // GeminiKey represents the configuration for a Gemini API key, @@ -199,6 +205,9 @@ type GeminiKey struct { // Headers optionally adds extra HTTP headers for requests sent with this key. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // ModelBlacklist lists model IDs that should be excluded for this provider. + ModelBlacklist []string `yaml:"model-blacklist,omitempty" json:"model-blacklist,omitempty"` } // OpenAICompatibility represents the configuration for OpenAI API compatibility diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 0d955064..54789524 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -450,6 +450,28 @@ func computeClaudeModelsHash(models []config.ClaudeModel) string { return hex.EncodeToString(sum[:]) } +func computeModelBlacklistHash(blacklist []string) string { + if len(blacklist) == 0 { + return "" + } + normalized := make([]string, 0, len(blacklist)) + for _, entry := range blacklist { + if trimmed := strings.TrimSpace(entry); trimmed != "" { + normalized = append(normalized, strings.ToLower(trimmed)) + } + } + if len(normalized) == 0 { + return "" + } + sort.Strings(normalized) + data, err := json.Marshal(normalized) + if err != nil || len(data) == 0 { + return "" + } + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + // SetClients sets the file-based clients. // SetClients removed // SetAPIKeyClients removed @@ -838,6 +860,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if base != "" { attrs["base_url"] = base } + if hash := computeModelBlacklistHash(entry.ModelBlacklist); hash != "" { + attrs["model_blacklist_hash"] = hash + } addConfigHeadersToAttrs(entry.Headers, attrs) a := &coreauth.Auth{ ID: id, @@ -870,6 +895,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if hash := computeClaudeModelsHash(ck.Models); hash != "" { attrs["models_hash"] = hash } + if hash := computeModelBlacklistHash(ck.ModelBlacklist); hash != "" { + attrs["model_blacklist_hash"] = hash + } addConfigHeadersToAttrs(ck.Headers, attrs) proxyURL := strings.TrimSpace(ck.ProxyURL) a := &coreauth.Auth{ @@ -899,6 +927,9 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if ck.BaseURL != "" { attrs["base_url"] = ck.BaseURL } + if hash := computeModelBlacklistHash(ck.ModelBlacklist); hash != "" { + attrs["model_blacklist_hash"] = hash + } addConfigHeadersToAttrs(ck.Headers, attrs) proxyURL := strings.TrimSpace(ck.ProxyURL) a := &coreauth.Auth{ diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 6e303ed2..0405c6ae 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -640,6 +640,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { switch provider { case "gemini": models = registry.GetGeminiModels() + if entry := s.resolveConfigGeminiKey(a); entry != nil { + models = applyModelBlacklist(models, entry.ModelBlacklist) + } case "vertex": // Vertex AI Gemini supports the same model identifiers as Gemini. models = registry.GetGeminiVertexModels() @@ -653,11 +656,17 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { cancel() case "claude": models = registry.GetClaudeModels() - if entry := s.resolveConfigClaudeKey(a); entry != nil && len(entry.Models) > 0 { - models = buildClaudeConfigModels(entry) + if entry := s.resolveConfigClaudeKey(a); entry != nil { + if len(entry.Models) > 0 { + models = buildClaudeConfigModels(entry) + } + models = applyModelBlacklist(models, entry.ModelBlacklist) } case "codex": models = registry.GetOpenAIModels() + if entry := s.resolveConfigCodexKey(a); entry != nil { + models = applyModelBlacklist(models, entry.ModelBlacklist) + } case "qwen": models = registry.GetQwenModels() case "iflow": @@ -749,7 +758,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { key = strings.ToLower(strings.TrimSpace(a.Provider)) } GlobalModelRegistry().RegisterClient(a.ID, key, models) + return } + + GlobalModelRegistry().UnregisterClient(a.ID) } func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey { @@ -791,6 +803,84 @@ func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey return nil } +func (s *Service) resolveConfigGeminiKey(auth *coreauth.Auth) *config.GeminiKey { + if auth == nil || s.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range s.cfg.GeminiKey { + entry := &s.cfg.GeminiKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + return nil +} + +func (s *Service) resolveConfigCodexKey(auth *coreauth.Auth) *config.CodexKey { + if auth == nil || s.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range s.cfg.CodexKey { + entry := &s.cfg.CodexKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + return nil +} + +func applyModelBlacklist(models []*ModelInfo, blacklist []string) []*ModelInfo { + if len(models) == 0 || len(blacklist) == 0 { + return models + } + blocked := make(map[string]struct{}, len(blacklist)) + for _, item := range blacklist { + if trimmed := strings.TrimSpace(item); trimmed != "" { + blocked[strings.ToLower(trimmed)] = struct{}{} + } + } + if len(blocked) == 0 { + return models + } + filtered := make([]*ModelInfo, 0, len(models)) + for _, model := range models { + if model == nil { + continue + } + if _, blockedModel := blocked[strings.ToLower(strings.TrimSpace(model.ID))]; blockedModel { + continue + } + filtered = append(filtered, model) + } + return filtered +} + func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { if entry == nil || len(entry.Models) == 0 { return nil