From 7b7871ede29aad123f90f0049a9a5849bcea8bda Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 30 Nov 2025 13:38:23 +0800 Subject: [PATCH] feat(api): add oauth excluded model management --- .../api/handlers/management/config_lists.go | 89 +++++++++++++++++++ internal/api/server.go | 5 ++ internal/config/config.go | 55 ++++++++++++ 3 files changed, 149 insertions(+) diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index b4b43b0f..71193084 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -223,6 +223,7 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) { value.APIKey = strings.TrimSpace(value.APIKey) value.BaseURL = strings.TrimSpace(value.BaseURL) value.ProxyURL = strings.TrimSpace(value.ProxyURL) + value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) if value.APIKey == "" { // Treat empty API key as delete. if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { @@ -504,6 +505,91 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) { c.JSON(400, gin.H{"error": "missing name or index"}) } +// oauth-excluded-models: map[string][]string +func (h *Handler) GetOAuthExcludedModels(c *gin.Context) { + c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)}) +} + +func (h *Handler) PutOAuthExcludedModels(c *gin.Context) { + data, err := c.GetRawData() + if err != nil { + c.JSON(400, gin.H{"error": "failed to read body"}) + return + } + var entries map[string][]string + if err = json.Unmarshal(data, &entries); err != nil { + var wrapper struct { + Items map[string][]string `json:"items"` + } + if err2 := json.Unmarshal(data, &wrapper); err2 != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + entries = wrapper.Items + } + h.cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(entries) + h.persist(c) +} + +func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) { + var body struct { + Provider *string `json:"provider"` + Models []string `json:"models"` + } + if err := c.ShouldBindJSON(&body); err != nil || body.Provider == nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + provider := strings.ToLower(strings.TrimSpace(*body.Provider)) + if provider == "" { + c.JSON(400, gin.H{"error": "invalid provider"}) + return + } + normalized := config.NormalizeExcludedModels(body.Models) + if len(normalized) == 0 { + if h.cfg.OAuthExcludedModels == nil { + c.JSON(404, gin.H{"error": "provider not found"}) + return + } + if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok { + c.JSON(404, gin.H{"error": "provider not found"}) + return + } + delete(h.cfg.OAuthExcludedModels, provider) + if len(h.cfg.OAuthExcludedModels) == 0 { + h.cfg.OAuthExcludedModels = nil + } + h.persist(c) + return + } + if h.cfg.OAuthExcludedModels == nil { + h.cfg.OAuthExcludedModels = make(map[string][]string) + } + h.cfg.OAuthExcludedModels[provider] = normalized + h.persist(c) +} + +func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) { + provider := strings.ToLower(strings.TrimSpace(c.Query("provider"))) + if provider == "" { + c.JSON(400, gin.H{"error": "missing provider"}) + return + } + if h.cfg.OAuthExcludedModels == nil { + c.JSON(404, gin.H{"error": "provider not found"}) + return + } + if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok { + c.JSON(404, gin.H{"error": "provider not found"}) + return + } + delete(h.cfg.OAuthExcludedModels, provider) + if len(h.cfg.OAuthExcludedModels) == 0 { + h.cfg.OAuthExcludedModels = nil + } + h.persist(c) +} + // codex-api-key: []CodexKey func (h *Handler) GetCodexKeys(c *gin.Context) { c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey}) @@ -533,6 +619,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) { entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = config.NormalizeHeaders(entry.Headers) + entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels) if entry.BaseURL == "" { continue } @@ -557,6 +644,7 @@ func (h *Handler) PatchCodexKey(c *gin.Context) { value.BaseURL = strings.TrimSpace(value.BaseURL) value.ProxyURL = strings.TrimSpace(value.ProxyURL) value.Headers = config.NormalizeHeaders(value.Headers) + value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) // If base-url becomes empty, delete instead of update if value.BaseURL == "" { if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { @@ -694,6 +782,7 @@ func normalizeClaudeKey(entry *config.ClaudeKey) { entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = config.NormalizeHeaders(entry.Headers) + entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels) if len(entry.Models) == 0 { return } diff --git a/internal/api/server.go b/internal/api/server.go index 3dd78c93..ab9c0354 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -543,6 +543,11 @@ func (s *Server) registerManagementRoutes() { mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat) mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat) + mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels) + mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels) + mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels) + mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels) + mgmt.GET("/auth-files", s.mgmt.ListAuthFiles) mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile) mgmt.POST("/auth-files", s.mgmt.UploadAuthFile) diff --git a/internal/config/config.go b/internal/config/config.go index 10440395..97b5a0c2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -334,6 +334,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize OpenAI compatibility providers: drop entries without base-url cfg.SanitizeOpenAICompatibility() + // Normalize OAuth provider model exclusion map. + cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels) + // Return the populated configuration struct. return &cfg, nil } @@ -371,6 +374,7 @@ func (cfg *Config) SanitizeCodexKeys() { e := cfg.CodexKey[i] e.BaseURL = strings.TrimSpace(e.BaseURL) e.Headers = NormalizeHeaders(e.Headers) + e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels) if e.BaseURL == "" { continue } @@ -387,6 +391,7 @@ func (cfg *Config) SanitizeClaudeKeys() { for i := range cfg.ClaudeKey { entry := &cfg.ClaudeKey[i] entry.Headers = NormalizeHeaders(entry.Headers) + entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) } } @@ -407,6 +412,7 @@ func (cfg *Config) SanitizeGeminiKeys() { entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = NormalizeHeaders(entry.Headers) + entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) if _, exists := seen[entry.APIKey]; exists { continue } @@ -469,6 +475,55 @@ func NormalizeHeaders(headers map[string]string) map[string]string { return clean } +// NormalizeExcludedModels trims, lowercases, and deduplicates model exclusion patterns. +// It preserves the order of first occurrences and drops empty entries. +func NormalizeExcludedModels(models []string) []string { + if len(models) == 0 { + return nil + } + seen := make(map[string]struct{}, len(models)) + out := make([]string, 0, len(models)) + for _, raw := range models { + trimmed := strings.ToLower(strings.TrimSpace(raw)) + if trimmed == "" { + continue + } + if _, exists := seen[trimmed]; exists { + continue + } + seen[trimmed] = struct{}{} + out = append(out, trimmed) + } + if len(out) == 0 { + return nil + } + return out +} + +// NormalizeOAuthExcludedModels cleans provider -> excluded models mappings by normalizing provider keys +// and applying model exclusion normalization to each entry. +func NormalizeOAuthExcludedModels(entries map[string][]string) map[string][]string { + if len(entries) == 0 { + return nil + } + out := make(map[string][]string, len(entries)) + for provider, models := range entries { + key := strings.ToLower(strings.TrimSpace(provider)) + if key == "" { + continue + } + normalized := NormalizeExcludedModels(models) + if len(normalized) == 0 { + continue + } + out[key] = normalized + } + if len(out) == 0 { + return nil + } + return out +} + // hashSecret hashes the given secret using bcrypt. func hashSecret(secret string) (string, error) { // Use default cost for simplicity.