From 6a191358affe7a46044872f354ba2ea79076ab3e Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 29 Nov 2025 20:30:11 +0800 Subject: [PATCH] fix(auth): fix runtime auth reload on oauth blacklist change --- internal/watcher/watcher.go | 185 +++++++++++++++++++++++++++++++++++- sdk/cliproxy/service.go | 32 ++++++- sdk/cliproxy/types.go | 17 +++- sdk/cliproxy/watcher.go | 3 + 4 files changed, 229 insertions(+), 8 deletions(-) diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 5f035718..c10a18a3 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -30,6 +30,16 @@ import ( log "github.com/sirupsen/logrus" ) +func matchProvider(provider string, targets []string) (string, bool) { + p := strings.ToLower(strings.TrimSpace(provider)) + for _, t := range targets { + if strings.EqualFold(p, strings.TrimSpace(t)) { + return p, true + } + } + return p, false +} + // storePersister captures persistence-capable token store methods used by the watcher. type storePersister interface { PersistConfig(ctx context.Context) error @@ -54,6 +64,7 @@ type Watcher struct { lastConfigHash string authQueue chan<- AuthUpdate currentAuths map[string]*coreauth.Auth + runtimeAuths map[string]*coreauth.Auth dispatchMu sync.Mutex dispatchCond *sync.Cond pendingUpdates map[string]AuthUpdate @@ -169,7 +180,7 @@ func (w *Watcher) Start(ctx context.Context) error { go w.processEvents(ctx) // Perform an initial full reload based on current config and auth dir - w.reloadClients(true) + w.reloadClients(true, nil) return nil } @@ -221,9 +232,57 @@ func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) { } } +// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths) +// to push auth updates through the same queue used by file/config watchers. +// Returns true if the update was enqueued; false if no queue is configured. +func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool { + if w == nil { + return false + } + w.clientsMutex.Lock() + if w.runtimeAuths == nil { + w.runtimeAuths = make(map[string]*coreauth.Auth) + } + switch update.Action { + case AuthUpdateActionAdd, AuthUpdateActionModify: + if update.Auth != nil && update.Auth.ID != "" { + clone := update.Auth.Clone() + w.runtimeAuths[clone.ID] = clone + if w.currentAuths == nil { + w.currentAuths = make(map[string]*coreauth.Auth) + } + w.currentAuths[clone.ID] = clone.Clone() + } + case AuthUpdateActionDelete: + id := update.ID + if id == "" && update.Auth != nil { + id = update.Auth.ID + } + if id != "" { + delete(w.runtimeAuths, id) + if w.currentAuths != nil { + delete(w.currentAuths, id) + } + } + } + w.clientsMutex.Unlock() + if w.getAuthQueue() == nil { + return false + } + w.dispatchAuthUpdates([]AuthUpdate{update}) + return true +} + func (w *Watcher) refreshAuthState() { auths := w.SnapshotCoreAuths() w.clientsMutex.Lock() + if len(w.runtimeAuths) > 0 { + for _, a := range w.runtimeAuths { + if a != nil { + auths = append(auths, a.Clone()) + } + } + } updates := w.prepareAuthUpdatesLocked(auths) w.clientsMutex.Unlock() w.dispatchAuthUpdates(updates) @@ -472,6 +531,80 @@ func computeModelBlacklistHash(blacklist []string) string { return hex.EncodeToString(sum[:]) } +type modelBlacklistSummary struct { + hash string + count int +} + +func summarizeModelBlacklist(list []string) modelBlacklistSummary { + if len(list) == 0 { + return modelBlacklistSummary{} + } + seen := make(map[string]struct{}, len(list)) + normalized := make([]string, 0, len(list)) + for _, entry := range list { + if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" { + if _, exists := seen[trimmed]; exists { + continue + } + seen[trimmed] = struct{}{} + normalized = append(normalized, trimmed) + } + } + sort.Strings(normalized) + return modelBlacklistSummary{ + hash: computeModelBlacklistHash(normalized), + count: len(normalized), + } +} + +func summarizeOAuthBlacklistMap(entries map[string][]string) map[string]modelBlacklistSummary { + if len(entries) == 0 { + return nil + } + out := make(map[string]modelBlacklistSummary, len(entries)) + for k, v := range entries { + key := strings.ToLower(strings.TrimSpace(k)) + if key == "" { + continue + } + out[key] = summarizeModelBlacklist(v) + } + return out +} + +func diffOAuthBlacklistChanges(oldMap, newMap map[string][]string) ([]string, []string) { + oldSummary := summarizeOAuthBlacklistMap(oldMap) + newSummary := summarizeOAuthBlacklistMap(newMap) + keys := make(map[string]struct{}, len(oldSummary)+len(newSummary)) + for k := range oldSummary { + keys[k] = struct{}{} + } + for k := range newSummary { + keys[k] = struct{}{} + } + changes := make([]string, 0, len(keys)) + affected := make([]string, 0, len(keys)) + for key := range keys { + oldInfo, okOld := oldSummary[key] + newInfo, okNew := newSummary[key] + switch { + case okOld && !okNew: + changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: removed", key)) + affected = append(affected, key) + case !okOld && okNew: + changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: added (%d entries)", key, newInfo.count)) + affected = append(affected, key) + case okOld && okNew && oldInfo.hash != newInfo.hash: + changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count)) + affected = append(affected, key) + } + } + sort.Strings(changes) + sort.Strings(affected) + return changes, affected +} + func applyAuthModelBlacklistMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) { if auth == nil || cfg == nil { return @@ -696,6 +829,11 @@ func (w *Watcher) reloadConfig() bool { w.config = newConfig w.clientsMutex.Unlock() + var affectedOAuthProviders []string + if oldConfig != nil { + _, affectedOAuthProviders = diffOAuthBlacklistChanges(oldConfig.OAuthModelBlacklist, newConfig.OAuthModelBlacklist) + } + // Always apply the current log level based on the latest config. // This ensures logrus reflects the desired level even if change detection misses. util.SetLogLevel(newConfig) @@ -721,12 +859,12 @@ func (w *Watcher) reloadConfig() bool { log.Infof("config successfully reloaded, triggering client reload") // Reload clients with new config - w.reloadClients(authDirChanged) + w.reloadClients(authDirChanged, affectedOAuthProviders) return true } // reloadClients performs a full scan and reload of all clients. -func (w *Watcher) reloadClients(rescanAuth bool) { +func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) { log.Debugf("starting full client load process") w.clientsMutex.RLock() @@ -738,6 +876,28 @@ func (w *Watcher) reloadClients(rescanAuth bool) { return } + if len(affectedOAuthProviders) > 0 { + w.clientsMutex.Lock() + if w.currentAuths != nil { + filtered := make(map[string]*coreauth.Auth, len(w.currentAuths)) + for id, auth := range w.currentAuths { + if auth == nil { + continue + } + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if _, match := matchProvider(provider, affectedOAuthProviders); match { + continue + } + filtered[id] = auth + } + w.currentAuths = filtered + log.Debugf("applying oauth-model-blacklist to providers %v", affectedOAuthProviders) + } else { + w.currentAuths = nil + } + w.clientsMutex.Unlock() + } + // Unregister all old API key clients before creating new ones // no legacy clients to unregister @@ -1533,6 +1693,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i)) } + oldBL := summarizeModelBlacklist(o.ModelBlacklist) + newBL := summarizeModelBlacklist(n.ModelBlacklist) + if oldBL.hash != newBL.hash { + changes = append(changes, fmt.Sprintf("gemini[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count)) + } } if !reflect.DeepEqual(trimStrings(oldCfg.GlAPIKey), trimStrings(newCfg.GlAPIKey)) { changes = append(changes, "generative-language-api-key: values updated (legacy view, redacted)") @@ -1561,6 +1726,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i)) } + oldBL := summarizeModelBlacklist(o.ModelBlacklist) + newBL := summarizeModelBlacklist(n.ModelBlacklist) + if oldBL.hash != newBL.hash { + changes = append(changes, fmt.Sprintf("claude[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count)) + } } } @@ -1586,9 +1756,18 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i)) } + oldBL := summarizeModelBlacklist(o.ModelBlacklist) + newBL := summarizeModelBlacklist(n.ModelBlacklist) + if oldBL.hash != newBL.hash { + changes = append(changes, fmt.Sprintf("codex[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count)) + } } } + if entries, _ := diffOAuthBlacklistChanges(oldCfg.OAuthModelBlacklist, newCfg.OAuthModelBlacklist); len(entries) > 0 { + changes = append(changes, entries...) + } + // Remote management (never print the key) if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote { changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote)) diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 4b6e70fc..9eb1d584 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -146,6 +146,27 @@ func (s *Service) consumeAuthUpdates(ctx context.Context) { } } +func (s *Service) emitAuthUpdate(ctx context.Context, update watcher.AuthUpdate) { + if s == nil { + return + } + if ctx == nil { + ctx = context.Background() + } + if s.watcher != nil && s.watcher.DispatchRuntimeAuthUpdate(update) { + return + } + if s.authUpdates != nil { + select { + case s.authUpdates <- update: + return + default: + log.Debugf("auth update queue saturated, applying inline action=%v id=%s", update.Action, update.ID) + } + } + s.handleAuthUpdate(ctx, update) +} + func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdate) { if s == nil { return @@ -220,7 +241,11 @@ func (s *Service) wsOnConnected(channelID string) { Metadata: map[string]any{"email": channelID}, // metadata drives logging and usage tracking } log.Infof("websocket provider connected: %s", channelID) - s.applyCoreAuthAddOrUpdate(context.Background(), auth) + s.emitAuthUpdate(context.Background(), watcher.AuthUpdate{ + Action: watcher.AuthUpdateActionAdd, + ID: auth.ID, + Auth: auth, + }) } func (s *Service) wsOnDisconnected(channelID string, reason error) { @@ -237,7 +262,10 @@ func (s *Service) wsOnDisconnected(channelID string, reason error) { log.Infof("websocket provider disconnected: %s", channelID) } ctx := context.Background() - s.applyCoreAuthRemoval(ctx, channelID) + s.emitAuthUpdate(ctx, watcher.AuthUpdate{ + Action: watcher.AuthUpdateActionDelete, + ID: channelID, + }) } func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) { diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index 1d577153..b44185d1 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -83,9 +83,10 @@ type WatcherWrapper struct { start func(ctx context.Context) error stop func() error - setConfig func(cfg *config.Config) - snapshotAuths func() []*coreauth.Auth - setUpdateQueue func(queue chan<- watcher.AuthUpdate) + setConfig func(cfg *config.Config) + snapshotAuths func() []*coreauth.Auth + setUpdateQueue func(queue chan<- watcher.AuthUpdate) + dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool } // Start proxies to the underlying watcher Start implementation. @@ -112,6 +113,16 @@ func (w *WatcherWrapper) SetConfig(cfg *config.Config) { w.setConfig(cfg) } +// DispatchRuntimeAuthUpdate forwards runtime auth updates (e.g., websocket providers) +// into the watcher-managed auth update queue when available. +// Returns true if the update was enqueued successfully. +func (w *WatcherWrapper) DispatchRuntimeAuthUpdate(update watcher.AuthUpdate) bool { + if w == nil || w.dispatchRuntimeUpdate == nil { + return false + } + return w.dispatchRuntimeUpdate(update) +} + // SetClients updates the watcher file-backed clients registry. // SetClients and SetAPIKeyClients removed; watcher manages its own caches diff --git a/sdk/cliproxy/watcher.go b/sdk/cliproxy/watcher.go index 81e4c18a..921e2068 100644 --- a/sdk/cliproxy/watcher.go +++ b/sdk/cliproxy/watcher.go @@ -28,5 +28,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi setUpdateQueue: func(queue chan<- watcher.AuthUpdate) { w.SetAuthUpdateQueue(queue) }, + dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool { + return w.DispatchRuntimeAuthUpdate(update) + }, }, nil }