fix(auth): fix runtime auth reload on oauth blacklist change

This commit is contained in:
hkfires
2025-11-29 20:30:11 +08:00
parent 5983e3ec87
commit 6a191358af
4 changed files with 229 additions and 8 deletions

View File

@@ -30,6 +30,16 @@ import (
log "github.com/sirupsen/logrus"
)
func matchProvider(provider string, targets []string) (string, bool) {
p := strings.ToLower(strings.TrimSpace(provider))
for _, t := range targets {
if strings.EqualFold(p, strings.TrimSpace(t)) {
return p, true
}
}
return p, false
}
// storePersister captures persistence-capable token store methods used by the watcher.
type storePersister interface {
PersistConfig(ctx context.Context) error
@@ -54,6 +64,7 @@ type Watcher struct {
lastConfigHash string
authQueue chan<- AuthUpdate
currentAuths map[string]*coreauth.Auth
runtimeAuths map[string]*coreauth.Auth
dispatchMu sync.Mutex
dispatchCond *sync.Cond
pendingUpdates map[string]AuthUpdate
@@ -169,7 +180,7 @@ func (w *Watcher) Start(ctx context.Context) error {
go w.processEvents(ctx)
// Perform an initial full reload based on current config and auth dir
w.reloadClients(true)
w.reloadClients(true, nil)
return nil
}
@@ -221,9 +232,57 @@ func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
}
}
// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths)
// to push auth updates through the same queue used by file/config watchers.
// Returns true if the update was enqueued; false if no queue is configured.
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
if w == nil {
return false
}
w.clientsMutex.Lock()
if w.runtimeAuths == nil {
w.runtimeAuths = make(map[string]*coreauth.Auth)
}
switch update.Action {
case AuthUpdateActionAdd, AuthUpdateActionModify:
if update.Auth != nil && update.Auth.ID != "" {
clone := update.Auth.Clone()
w.runtimeAuths[clone.ID] = clone
if w.currentAuths == nil {
w.currentAuths = make(map[string]*coreauth.Auth)
}
w.currentAuths[clone.ID] = clone.Clone()
}
case AuthUpdateActionDelete:
id := update.ID
if id == "" && update.Auth != nil {
id = update.Auth.ID
}
if id != "" {
delete(w.runtimeAuths, id)
if w.currentAuths != nil {
delete(w.currentAuths, id)
}
}
}
w.clientsMutex.Unlock()
if w.getAuthQueue() == nil {
return false
}
w.dispatchAuthUpdates([]AuthUpdate{update})
return true
}
func (w *Watcher) refreshAuthState() {
auths := w.SnapshotCoreAuths()
w.clientsMutex.Lock()
if len(w.runtimeAuths) > 0 {
for _, a := range w.runtimeAuths {
if a != nil {
auths = append(auths, a.Clone())
}
}
}
updates := w.prepareAuthUpdatesLocked(auths)
w.clientsMutex.Unlock()
w.dispatchAuthUpdates(updates)
@@ -472,6 +531,80 @@ func computeModelBlacklistHash(blacklist []string) string {
return hex.EncodeToString(sum[:])
}
type modelBlacklistSummary struct {
hash string
count int
}
func summarizeModelBlacklist(list []string) modelBlacklistSummary {
if len(list) == 0 {
return modelBlacklistSummary{}
}
seen := make(map[string]struct{}, len(list))
normalized := make([]string, 0, len(list))
for _, entry := range list {
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
if _, exists := seen[trimmed]; exists {
continue
}
seen[trimmed] = struct{}{}
normalized = append(normalized, trimmed)
}
}
sort.Strings(normalized)
return modelBlacklistSummary{
hash: computeModelBlacklistHash(normalized),
count: len(normalized),
}
}
func summarizeOAuthBlacklistMap(entries map[string][]string) map[string]modelBlacklistSummary {
if len(entries) == 0 {
return nil
}
out := make(map[string]modelBlacklistSummary, len(entries))
for k, v := range entries {
key := strings.ToLower(strings.TrimSpace(k))
if key == "" {
continue
}
out[key] = summarizeModelBlacklist(v)
}
return out
}
func diffOAuthBlacklistChanges(oldMap, newMap map[string][]string) ([]string, []string) {
oldSummary := summarizeOAuthBlacklistMap(oldMap)
newSummary := summarizeOAuthBlacklistMap(newMap)
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
for k := range oldSummary {
keys[k] = struct{}{}
}
for k := range newSummary {
keys[k] = struct{}{}
}
changes := make([]string, 0, len(keys))
affected := make([]string, 0, len(keys))
for key := range keys {
oldInfo, okOld := oldSummary[key]
newInfo, okNew := newSummary[key]
switch {
case okOld && !okNew:
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: removed", key))
affected = append(affected, key)
case !okOld && okNew:
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: added (%d entries)", key, newInfo.count))
affected = append(affected, key)
case okOld && okNew && oldInfo.hash != newInfo.hash:
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
affected = append(affected, key)
}
}
sort.Strings(changes)
sort.Strings(affected)
return changes, affected
}
func applyAuthModelBlacklistMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
if auth == nil || cfg == nil {
return
@@ -696,6 +829,11 @@ func (w *Watcher) reloadConfig() bool {
w.config = newConfig
w.clientsMutex.Unlock()
var affectedOAuthProviders []string
if oldConfig != nil {
_, affectedOAuthProviders = diffOAuthBlacklistChanges(oldConfig.OAuthModelBlacklist, newConfig.OAuthModelBlacklist)
}
// Always apply the current log level based on the latest config.
// This ensures logrus reflects the desired level even if change detection misses.
util.SetLogLevel(newConfig)
@@ -721,12 +859,12 @@ func (w *Watcher) reloadConfig() bool {
log.Infof("config successfully reloaded, triggering client reload")
// Reload clients with new config
w.reloadClients(authDirChanged)
w.reloadClients(authDirChanged, affectedOAuthProviders)
return true
}
// reloadClients performs a full scan and reload of all clients.
func (w *Watcher) reloadClients(rescanAuth bool) {
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) {
log.Debugf("starting full client load process")
w.clientsMutex.RLock()
@@ -738,6 +876,28 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
return
}
if len(affectedOAuthProviders) > 0 {
w.clientsMutex.Lock()
if w.currentAuths != nil {
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
for id, auth := range w.currentAuths {
if auth == nil {
continue
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if _, match := matchProvider(provider, affectedOAuthProviders); match {
continue
}
filtered[id] = auth
}
w.currentAuths = filtered
log.Debugf("applying oauth-model-blacklist to providers %v", affectedOAuthProviders)
} else {
w.currentAuths = nil
}
w.clientsMutex.Unlock()
}
// Unregister all old API key clients before creating new ones
// no legacy clients to unregister
@@ -1533,6 +1693,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
}
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
newBL := summarizeModelBlacklist(n.ModelBlacklist)
if oldBL.hash != newBL.hash {
changes = append(changes, fmt.Sprintf("gemini[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
}
}
if !reflect.DeepEqual(trimStrings(oldCfg.GlAPIKey), trimStrings(newCfg.GlAPIKey)) {
changes = append(changes, "generative-language-api-key: values updated (legacy view, redacted)")
@@ -1561,6 +1726,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
}
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
newBL := summarizeModelBlacklist(n.ModelBlacklist)
if oldBL.hash != newBL.hash {
changes = append(changes, fmt.Sprintf("claude[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
}
}
}
@@ -1586,9 +1756,18 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
}
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
newBL := summarizeModelBlacklist(n.ModelBlacklist)
if oldBL.hash != newBL.hash {
changes = append(changes, fmt.Sprintf("codex[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
}
}
}
if entries, _ := diffOAuthBlacklistChanges(oldCfg.OAuthModelBlacklist, newCfg.OAuthModelBlacklist); len(entries) > 0 {
changes = append(changes, entries...)
}
// Remote management (never print the key)
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))