feat(auth): add oauth provider model blacklist

This commit is contained in:
hkfires
2025-11-28 10:37:10 +08:00
parent f8cebb9343
commit 5983e3ec87
4 changed files with 102 additions and 12 deletions

View File

@@ -127,3 +127,22 @@ ws-auth: false
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex # protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
# params: # JSON path (gjson/sjson syntax) -> value # params: # JSON path (gjson/sjson syntax) -> value
# "reasoning.effort": "high" # "reasoning.effort": "high"
# OAuth provider model blacklist
#oauth-model-blacklist:
# gemini-cli:
# - "gemini-3-pro-preview"
# vertex:
# - "gemini-3-pro-preview"
# aistudio:
# - "gemini-3-pro-preview"
# antigravity:
# - "gemini-3-pro-preview"
# claude:
# - "claude-3-5-haiku-20241022"
# codex:
# - "gpt-5-codex-mini"
# qwen:
# - "vision-model"
# iflow:
# - "tstars2.0"

View File

@@ -83,6 +83,9 @@ type Config struct {
// Payload defines default and override rules for provider payload parameters. // Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"` Payload PayloadConfig `yaml:"payload" json:"payload"`
// OAuthModelBlacklist defines per-provider global model blacklists applied to OAuth/file-backed auth entries.
OAuthModelBlacklist map[string][]string `yaml:"oauth-model-blacklist,omitempty" json:"oauth-model-blacklist,omitempty"`
} }
// TLSConfig holds HTTPS server settings. // TLSConfig holds HTTPS server settings.

View File

@@ -472,6 +472,46 @@ func computeModelBlacklistHash(blacklist []string) string {
return hex.EncodeToString(sum[:]) return hex.EncodeToString(sum[:])
} }
func applyAuthModelBlacklistMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
if auth == nil || cfg == nil {
return
}
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
seen := make(map[string]struct{})
add := func(list []string) {
for _, entry := range list {
if trimmed := strings.TrimSpace(entry); trimmed != "" {
key := strings.ToLower(trimmed)
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
}
}
}
if authKindKey == "apikey" {
add(perKey)
} else if cfg.OAuthModelBlacklist != nil {
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
add(cfg.OAuthModelBlacklist[providerKey])
}
combined := make([]string, 0, len(seen))
for k := range seen {
combined = append(combined, k)
}
sort.Strings(combined)
hash := computeModelBlacklistHash(combined)
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
if hash != "" {
auth.Attributes["model_blacklist_hash"] = hash
}
if authKind != "" {
auth.Attributes["auth_kind"] = authKind
}
}
// SetClients sets the file-based clients. // SetClients sets the file-based clients.
// SetClients removed // SetClients removed
// SetAPIKeyClients removed // SetAPIKeyClients removed
@@ -860,9 +900,6 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
if base != "" { if base != "" {
attrs["base_url"] = base attrs["base_url"] = base
} }
if hash := computeModelBlacklistHash(entry.ModelBlacklist); hash != "" {
attrs["model_blacklist_hash"] = hash
}
addConfigHeadersToAttrs(entry.Headers, attrs) addConfigHeadersToAttrs(entry.Headers, attrs)
a := &coreauth.Auth{ a := &coreauth.Auth{
ID: id, ID: id,
@@ -874,6 +911,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
} }
applyAuthModelBlacklistMeta(a, cfg, entry.ModelBlacklist, "apikey")
out = append(out, a) out = append(out, a)
} }
// Claude API keys -> synthesize auths // Claude API keys -> synthesize auths
@@ -895,9 +933,6 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
if hash := computeClaudeModelsHash(ck.Models); hash != "" { if hash := computeClaudeModelsHash(ck.Models); hash != "" {
attrs["models_hash"] = hash attrs["models_hash"] = hash
} }
if hash := computeModelBlacklistHash(ck.ModelBlacklist); hash != "" {
attrs["model_blacklist_hash"] = hash
}
addConfigHeadersToAttrs(ck.Headers, attrs) addConfigHeadersToAttrs(ck.Headers, attrs)
proxyURL := strings.TrimSpace(ck.ProxyURL) proxyURL := strings.TrimSpace(ck.ProxyURL)
a := &coreauth.Auth{ a := &coreauth.Auth{
@@ -910,6 +945,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
} }
applyAuthModelBlacklistMeta(a, cfg, ck.ModelBlacklist, "apikey")
out = append(out, a) out = append(out, a)
} }
// Codex API keys -> synthesize auths // Codex API keys -> synthesize auths
@@ -927,9 +963,6 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
if ck.BaseURL != "" { if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL attrs["base_url"] = ck.BaseURL
} }
if hash := computeModelBlacklistHash(ck.ModelBlacklist); hash != "" {
attrs["model_blacklist_hash"] = hash
}
addConfigHeadersToAttrs(ck.Headers, attrs) addConfigHeadersToAttrs(ck.Headers, attrs)
proxyURL := strings.TrimSpace(ck.ProxyURL) proxyURL := strings.TrimSpace(ck.ProxyURL)
a := &coreauth.Auth{ a := &coreauth.Auth{
@@ -942,6 +975,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
} }
applyAuthModelBlacklistMeta(a, cfg, ck.ModelBlacklist, "apikey")
out = append(out, a) out = append(out, a)
} }
for i := range cfg.OpenAICompatibility { for i := range cfg.OpenAICompatibility {
@@ -1102,8 +1136,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
} }
applyAuthModelBlacklistMeta(a, cfg, nil, "oauth")
if provider == "gemini-cli" { if provider == "gemini-cli" {
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
applyAuthModelBlacklistMeta(v, cfg, nil, "oauth")
}
out = append(out, a) out = append(out, a)
out = append(out, virtuals...) out = append(out, virtuals...)
continue continue

View File

@@ -617,6 +617,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if a == nil || a.ID == "" { if a == nil || a.ID == "" {
return return
} }
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
if a.Attributes != nil { if a.Attributes != nil {
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") { if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
GlobalModelRegistry().UnregisterClient(a.ID) GlobalModelRegistry().UnregisterClient(a.ID)
@@ -636,41 +637,57 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if compatDetected { if compatDetected {
provider = "openai-compatibility" provider = "openai-compatibility"
} }
blacklist := s.oauthBlacklist(provider, authKind)
var models []*ModelInfo var models []*ModelInfo
switch provider { switch provider {
case "gemini": case "gemini":
models = registry.GetGeminiModels() models = registry.GetGeminiModels()
if entry := s.resolveConfigGeminiKey(a); entry != nil { if entry := s.resolveConfigGeminiKey(a); entry != nil {
models = applyModelBlacklist(models, entry.ModelBlacklist) if authKind == "apikey" {
blacklist = entry.ModelBlacklist
} }
}
models = applyModelBlacklist(models, blacklist)
case "vertex": case "vertex":
// Vertex AI Gemini supports the same model identifiers as Gemini. // Vertex AI Gemini supports the same model identifiers as Gemini.
models = registry.GetGeminiVertexModels() models = registry.GetGeminiVertexModels()
models = applyModelBlacklist(models, blacklist)
case "gemini-cli": case "gemini-cli":
models = registry.GetGeminiCLIModels() models = registry.GetGeminiCLIModels()
models = applyModelBlacklist(models, blacklist)
case "aistudio": case "aistudio":
models = registry.GetAIStudioModels() models = registry.GetAIStudioModels()
models = applyModelBlacklist(models, blacklist)
case "antigravity": case "antigravity":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
models = executor.FetchAntigravityModels(ctx, a, s.cfg) models = executor.FetchAntigravityModels(ctx, a, s.cfg)
cancel() cancel()
models = applyModelBlacklist(models, blacklist)
case "claude": case "claude":
models = registry.GetClaudeModels() models = registry.GetClaudeModels()
if entry := s.resolveConfigClaudeKey(a); entry != nil { if entry := s.resolveConfigClaudeKey(a); entry != nil {
if len(entry.Models) > 0 { if len(entry.Models) > 0 {
models = buildClaudeConfigModels(entry) models = buildClaudeConfigModels(entry)
} }
models = applyModelBlacklist(models, entry.ModelBlacklist) if authKind == "apikey" {
blacklist = entry.ModelBlacklist
} }
}
models = applyModelBlacklist(models, blacklist)
case "codex": case "codex":
models = registry.GetOpenAIModels() models = registry.GetOpenAIModels()
if entry := s.resolveConfigCodexKey(a); entry != nil { if entry := s.resolveConfigCodexKey(a); entry != nil {
models = applyModelBlacklist(models, entry.ModelBlacklist) if authKind == "apikey" {
blacklist = entry.ModelBlacklist
} }
}
models = applyModelBlacklist(models, blacklist)
case "qwen": case "qwen":
models = registry.GetQwenModels() models = registry.GetQwenModels()
models = applyModelBlacklist(models, blacklist)
case "iflow": case "iflow":
models = registry.GetIFlowModels() models = registry.GetIFlowModels()
models = applyModelBlacklist(models, blacklist)
default: default:
// Handle OpenAI-compatibility providers by name using config // Handle OpenAI-compatibility providers by name using config
if s.cfg != nil { if s.cfg != nil {
@@ -855,6 +872,19 @@ func (s *Service) resolveConfigCodexKey(auth *coreauth.Auth) *config.CodexKey {
return nil return nil
} }
func (s *Service) oauthBlacklist(provider, authKind string) []string {
cfg := s.cfg
if cfg == nil {
return nil
}
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
providerKey := strings.ToLower(strings.TrimSpace(provider))
if authKindKey == "apikey" {
return nil
}
return cfg.OAuthModelBlacklist[providerKey]
}
func applyModelBlacklist(models []*ModelInfo, blacklist []string) []*ModelInfo { func applyModelBlacklist(models []*ModelInfo, blacklist []string) []*ModelInfo {
if len(models) == 0 || len(blacklist) == 0 { if len(models) == 0 || len(blacklist) == 0 {
return models return models