mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
fix(auth): fix runtime auth reload on oauth blacklist change
This commit is contained in:
@@ -30,6 +30,16 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func matchProvider(provider string, targets []string) (string, bool) {
|
||||
p := strings.ToLower(strings.TrimSpace(provider))
|
||||
for _, t := range targets {
|
||||
if strings.EqualFold(p, strings.TrimSpace(t)) {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return p, false
|
||||
}
|
||||
|
||||
// storePersister captures persistence-capable token store methods used by the watcher.
|
||||
type storePersister interface {
|
||||
PersistConfig(ctx context.Context) error
|
||||
@@ -54,6 +64,7 @@ type Watcher struct {
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
currentAuths map[string]*coreauth.Auth
|
||||
runtimeAuths map[string]*coreauth.Auth
|
||||
dispatchMu sync.Mutex
|
||||
dispatchCond *sync.Cond
|
||||
pendingUpdates map[string]AuthUpdate
|
||||
@@ -169,7 +180,7 @@ func (w *Watcher) Start(ctx context.Context) error {
|
||||
go w.processEvents(ctx)
|
||||
|
||||
// Perform an initial full reload based on current config and auth dir
|
||||
w.reloadClients(true)
|
||||
w.reloadClients(true, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -221,9 +232,57 @@ func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
|
||||
}
|
||||
}
|
||||
|
||||
// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths)
|
||||
// to push auth updates through the same queue used by file/config watchers.
|
||||
// Returns true if the update was enqueued; false if no queue is configured.
|
||||
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||
if w == nil {
|
||||
return false
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
if w.runtimeAuths == nil {
|
||||
w.runtimeAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
switch update.Action {
|
||||
case AuthUpdateActionAdd, AuthUpdateActionModify:
|
||||
if update.Auth != nil && update.Auth.ID != "" {
|
||||
clone := update.Auth.Clone()
|
||||
w.runtimeAuths[clone.ID] = clone
|
||||
if w.currentAuths == nil {
|
||||
w.currentAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.currentAuths[clone.ID] = clone.Clone()
|
||||
}
|
||||
case AuthUpdateActionDelete:
|
||||
id := update.ID
|
||||
if id == "" && update.Auth != nil {
|
||||
id = update.Auth.ID
|
||||
}
|
||||
if id != "" {
|
||||
delete(w.runtimeAuths, id)
|
||||
if w.currentAuths != nil {
|
||||
delete(w.currentAuths, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
if w.getAuthQueue() == nil {
|
||||
return false
|
||||
}
|
||||
w.dispatchAuthUpdates([]AuthUpdate{update})
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *Watcher) refreshAuthState() {
|
||||
auths := w.SnapshotCoreAuths()
|
||||
w.clientsMutex.Lock()
|
||||
if len(w.runtimeAuths) > 0 {
|
||||
for _, a := range w.runtimeAuths {
|
||||
if a != nil {
|
||||
auths = append(auths, a.Clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
updates := w.prepareAuthUpdatesLocked(auths)
|
||||
w.clientsMutex.Unlock()
|
||||
w.dispatchAuthUpdates(updates)
|
||||
@@ -472,6 +531,80 @@ func computeModelBlacklistHash(blacklist []string) string {
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
type modelBlacklistSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
func summarizeModelBlacklist(list []string) modelBlacklistSummary {
|
||||
if len(list) == 0 {
|
||||
return modelBlacklistSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return modelBlacklistSummary{
|
||||
hash: computeModelBlacklistHash(normalized),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeOAuthBlacklistMap(entries map[string][]string) map[string]modelBlacklistSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]modelBlacklistSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = summarizeModelBlacklist(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func diffOAuthBlacklistChanges(oldMap, newMap map[string][]string) ([]string, []string) {
|
||||
oldSummary := summarizeOAuthBlacklistMap(oldMap)
|
||||
newSummary := summarizeOAuthBlacklistMap(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
func applyAuthModelBlacklistMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||
if auth == nil || cfg == nil {
|
||||
return
|
||||
@@ -696,6 +829,11 @@ func (w *Watcher) reloadConfig() bool {
|
||||
w.config = newConfig
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
var affectedOAuthProviders []string
|
||||
if oldConfig != nil {
|
||||
_, affectedOAuthProviders = diffOAuthBlacklistChanges(oldConfig.OAuthModelBlacklist, newConfig.OAuthModelBlacklist)
|
||||
}
|
||||
|
||||
// Always apply the current log level based on the latest config.
|
||||
// This ensures logrus reflects the desired level even if change detection misses.
|
||||
util.SetLogLevel(newConfig)
|
||||
@@ -721,12 +859,12 @@ func (w *Watcher) reloadConfig() bool {
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
// Reload clients with new config
|
||||
w.reloadClients(authDirChanged)
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders)
|
||||
return true
|
||||
}
|
||||
|
||||
// reloadClients performs a full scan and reload of all clients.
|
||||
func (w *Watcher) reloadClients(rescanAuth bool) {
|
||||
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) {
|
||||
log.Debugf("starting full client load process")
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
@@ -738,6 +876,28 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(affectedOAuthProviders) > 0 {
|
||||
w.clientsMutex.Lock()
|
||||
if w.currentAuths != nil {
|
||||
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
|
||||
for id, auth := range w.currentAuths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if _, match := matchProvider(provider, affectedOAuthProviders); match {
|
||||
continue
|
||||
}
|
||||
filtered[id] = auth
|
||||
}
|
||||
w.currentAuths = filtered
|
||||
log.Debugf("applying oauth-model-blacklist to providers %v", affectedOAuthProviders)
|
||||
} else {
|
||||
w.currentAuths = nil
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
// Unregister all old API key clients before creating new ones
|
||||
// no legacy clients to unregister
|
||||
|
||||
@@ -1533,6 +1693,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
|
||||
newBL := summarizeModelBlacklist(n.ModelBlacklist)
|
||||
if oldBL.hash != newBL.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
|
||||
}
|
||||
}
|
||||
if !reflect.DeepEqual(trimStrings(oldCfg.GlAPIKey), trimStrings(newCfg.GlAPIKey)) {
|
||||
changes = append(changes, "generative-language-api-key: values updated (legacy view, redacted)")
|
||||
@@ -1561,6 +1726,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
|
||||
newBL := summarizeModelBlacklist(n.ModelBlacklist)
|
||||
if oldBL.hash != newBL.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1586,9 +1756,18 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
|
||||
newBL := summarizeModelBlacklist(n.ModelBlacklist)
|
||||
if oldBL.hash != newBL.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if entries, _ := diffOAuthBlacklistChanges(oldCfg.OAuthModelBlacklist, newCfg.OAuthModelBlacklist); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
||||
|
||||
Reference in New Issue
Block a user