mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
fix(auth): fix runtime auth reload on oauth blacklist change
This commit is contained in:
@@ -30,6 +30,16 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func matchProvider(provider string, targets []string) (string, bool) {
|
||||||
|
p := strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
for _, t := range targets {
|
||||||
|
if strings.EqualFold(p, strings.TrimSpace(t)) {
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, false
|
||||||
|
}
|
||||||
|
|
||||||
// storePersister captures persistence-capable token store methods used by the watcher.
|
// storePersister captures persistence-capable token store methods used by the watcher.
|
||||||
type storePersister interface {
|
type storePersister interface {
|
||||||
PersistConfig(ctx context.Context) error
|
PersistConfig(ctx context.Context) error
|
||||||
@@ -54,6 +64,7 @@ type Watcher struct {
|
|||||||
lastConfigHash string
|
lastConfigHash string
|
||||||
authQueue chan<- AuthUpdate
|
authQueue chan<- AuthUpdate
|
||||||
currentAuths map[string]*coreauth.Auth
|
currentAuths map[string]*coreauth.Auth
|
||||||
|
runtimeAuths map[string]*coreauth.Auth
|
||||||
dispatchMu sync.Mutex
|
dispatchMu sync.Mutex
|
||||||
dispatchCond *sync.Cond
|
dispatchCond *sync.Cond
|
||||||
pendingUpdates map[string]AuthUpdate
|
pendingUpdates map[string]AuthUpdate
|
||||||
@@ -169,7 +180,7 @@ func (w *Watcher) Start(ctx context.Context) error {
|
|||||||
go w.processEvents(ctx)
|
go w.processEvents(ctx)
|
||||||
|
|
||||||
// Perform an initial full reload based on current config and auth dir
|
// Perform an initial full reload based on current config and auth dir
|
||||||
w.reloadClients(true)
|
w.reloadClients(true, nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,9 +232,57 @@ func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths)
|
||||||
|
// to push auth updates through the same queue used by file/config watchers.
|
||||||
|
// Returns true if the update was enqueued; false if no queue is configured.
|
||||||
|
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||||
|
if w == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
w.clientsMutex.Lock()
|
||||||
|
if w.runtimeAuths == nil {
|
||||||
|
w.runtimeAuths = make(map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
|
switch update.Action {
|
||||||
|
case AuthUpdateActionAdd, AuthUpdateActionModify:
|
||||||
|
if update.Auth != nil && update.Auth.ID != "" {
|
||||||
|
clone := update.Auth.Clone()
|
||||||
|
w.runtimeAuths[clone.ID] = clone
|
||||||
|
if w.currentAuths == nil {
|
||||||
|
w.currentAuths = make(map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
|
w.currentAuths[clone.ID] = clone.Clone()
|
||||||
|
}
|
||||||
|
case AuthUpdateActionDelete:
|
||||||
|
id := update.ID
|
||||||
|
if id == "" && update.Auth != nil {
|
||||||
|
id = update.Auth.ID
|
||||||
|
}
|
||||||
|
if id != "" {
|
||||||
|
delete(w.runtimeAuths, id)
|
||||||
|
if w.currentAuths != nil {
|
||||||
|
delete(w.currentAuths, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
if w.getAuthQueue() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
w.dispatchAuthUpdates([]AuthUpdate{update})
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (w *Watcher) refreshAuthState() {
|
func (w *Watcher) refreshAuthState() {
|
||||||
auths := w.SnapshotCoreAuths()
|
auths := w.SnapshotCoreAuths()
|
||||||
w.clientsMutex.Lock()
|
w.clientsMutex.Lock()
|
||||||
|
if len(w.runtimeAuths) > 0 {
|
||||||
|
for _, a := range w.runtimeAuths {
|
||||||
|
if a != nil {
|
||||||
|
auths = append(auths, a.Clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
updates := w.prepareAuthUpdatesLocked(auths)
|
updates := w.prepareAuthUpdatesLocked(auths)
|
||||||
w.clientsMutex.Unlock()
|
w.clientsMutex.Unlock()
|
||||||
w.dispatchAuthUpdates(updates)
|
w.dispatchAuthUpdates(updates)
|
||||||
@@ -472,6 +531,80 @@ func computeModelBlacklistHash(blacklist []string) string {
|
|||||||
return hex.EncodeToString(sum[:])
|
return hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type modelBlacklistSummary struct {
|
||||||
|
hash string
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeModelBlacklist(list []string) modelBlacklistSummary {
|
||||||
|
if len(list) == 0 {
|
||||||
|
return modelBlacklistSummary{}
|
||||||
|
}
|
||||||
|
seen := make(map[string]struct{}, len(list))
|
||||||
|
normalized := make([]string, 0, len(list))
|
||||||
|
for _, entry := range list {
|
||||||
|
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
|
||||||
|
if _, exists := seen[trimmed]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[trimmed] = struct{}{}
|
||||||
|
normalized = append(normalized, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(normalized)
|
||||||
|
return modelBlacklistSummary{
|
||||||
|
hash: computeModelBlacklistHash(normalized),
|
||||||
|
count: len(normalized),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeOAuthBlacklistMap(entries map[string][]string) map[string]modelBlacklistSummary {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]modelBlacklistSummary, len(entries))
|
||||||
|
for k, v := range entries {
|
||||||
|
key := strings.ToLower(strings.TrimSpace(k))
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[key] = summarizeModelBlacklist(v)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func diffOAuthBlacklistChanges(oldMap, newMap map[string][]string) ([]string, []string) {
|
||||||
|
oldSummary := summarizeOAuthBlacklistMap(oldMap)
|
||||||
|
newSummary := summarizeOAuthBlacklistMap(newMap)
|
||||||
|
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||||
|
for k := range oldSummary {
|
||||||
|
keys[k] = struct{}{}
|
||||||
|
}
|
||||||
|
for k := range newSummary {
|
||||||
|
keys[k] = struct{}{}
|
||||||
|
}
|
||||||
|
changes := make([]string, 0, len(keys))
|
||||||
|
affected := make([]string, 0, len(keys))
|
||||||
|
for key := range keys {
|
||||||
|
oldInfo, okOld := oldSummary[key]
|
||||||
|
newInfo, okNew := newSummary[key]
|
||||||
|
switch {
|
||||||
|
case okOld && !okNew:
|
||||||
|
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: removed", key))
|
||||||
|
affected = append(affected, key)
|
||||||
|
case !okOld && okNew:
|
||||||
|
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: added (%d entries)", key, newInfo.count))
|
||||||
|
affected = append(affected, key)
|
||||||
|
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||||
|
changes = append(changes, fmt.Sprintf("oauth-model-blacklist[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||||
|
affected = append(affected, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(changes)
|
||||||
|
sort.Strings(affected)
|
||||||
|
return changes, affected
|
||||||
|
}
|
||||||
|
|
||||||
func applyAuthModelBlacklistMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
func applyAuthModelBlacklistMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||||
if auth == nil || cfg == nil {
|
if auth == nil || cfg == nil {
|
||||||
return
|
return
|
||||||
@@ -696,6 +829,11 @@ func (w *Watcher) reloadConfig() bool {
|
|||||||
w.config = newConfig
|
w.config = newConfig
|
||||||
w.clientsMutex.Unlock()
|
w.clientsMutex.Unlock()
|
||||||
|
|
||||||
|
var affectedOAuthProviders []string
|
||||||
|
if oldConfig != nil {
|
||||||
|
_, affectedOAuthProviders = diffOAuthBlacklistChanges(oldConfig.OAuthModelBlacklist, newConfig.OAuthModelBlacklist)
|
||||||
|
}
|
||||||
|
|
||||||
// Always apply the current log level based on the latest config.
|
// Always apply the current log level based on the latest config.
|
||||||
// This ensures logrus reflects the desired level even if change detection misses.
|
// This ensures logrus reflects the desired level even if change detection misses.
|
||||||
util.SetLogLevel(newConfig)
|
util.SetLogLevel(newConfig)
|
||||||
@@ -721,12 +859,12 @@ func (w *Watcher) reloadConfig() bool {
|
|||||||
|
|
||||||
log.Infof("config successfully reloaded, triggering client reload")
|
log.Infof("config successfully reloaded, triggering client reload")
|
||||||
// Reload clients with new config
|
// Reload clients with new config
|
||||||
w.reloadClients(authDirChanged)
|
w.reloadClients(authDirChanged, affectedOAuthProviders)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// reloadClients performs a full scan and reload of all clients.
|
// reloadClients performs a full scan and reload of all clients.
|
||||||
func (w *Watcher) reloadClients(rescanAuth bool) {
|
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) {
|
||||||
log.Debugf("starting full client load process")
|
log.Debugf("starting full client load process")
|
||||||
|
|
||||||
w.clientsMutex.RLock()
|
w.clientsMutex.RLock()
|
||||||
@@ -738,6 +876,28 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(affectedOAuthProviders) > 0 {
|
||||||
|
w.clientsMutex.Lock()
|
||||||
|
if w.currentAuths != nil {
|
||||||
|
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
|
||||||
|
for id, auth := range w.currentAuths {
|
||||||
|
if auth == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||||
|
if _, match := matchProvider(provider, affectedOAuthProviders); match {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered[id] = auth
|
||||||
|
}
|
||||||
|
w.currentAuths = filtered
|
||||||
|
log.Debugf("applying oauth-model-blacklist to providers %v", affectedOAuthProviders)
|
||||||
|
} else {
|
||||||
|
w.currentAuths = nil
|
||||||
|
}
|
||||||
|
w.clientsMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Unregister all old API key clients before creating new ones
|
// Unregister all old API key clients before creating new ones
|
||||||
// no legacy clients to unregister
|
// no legacy clients to unregister
|
||||||
|
|
||||||
@@ -1533,6 +1693,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if !equalStringMap(o.Headers, n.Headers) {
|
if !equalStringMap(o.Headers, n.Headers) {
|
||||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||||
}
|
}
|
||||||
|
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
|
||||||
|
newBL := summarizeModelBlacklist(n.ModelBlacklist)
|
||||||
|
if oldBL.hash != newBL.hash {
|
||||||
|
changes = append(changes, fmt.Sprintf("gemini[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(trimStrings(oldCfg.GlAPIKey), trimStrings(newCfg.GlAPIKey)) {
|
if !reflect.DeepEqual(trimStrings(oldCfg.GlAPIKey), trimStrings(newCfg.GlAPIKey)) {
|
||||||
changes = append(changes, "generative-language-api-key: values updated (legacy view, redacted)")
|
changes = append(changes, "generative-language-api-key: values updated (legacy view, redacted)")
|
||||||
@@ -1561,6 +1726,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if !equalStringMap(o.Headers, n.Headers) {
|
if !equalStringMap(o.Headers, n.Headers) {
|
||||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||||
}
|
}
|
||||||
|
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
|
||||||
|
newBL := summarizeModelBlacklist(n.ModelBlacklist)
|
||||||
|
if oldBL.hash != newBL.hash {
|
||||||
|
changes = append(changes, fmt.Sprintf("claude[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1586,9 +1756,18 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if !equalStringMap(o.Headers, n.Headers) {
|
if !equalStringMap(o.Headers, n.Headers) {
|
||||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||||
}
|
}
|
||||||
|
oldBL := summarizeModelBlacklist(o.ModelBlacklist)
|
||||||
|
newBL := summarizeModelBlacklist(n.ModelBlacklist)
|
||||||
|
if oldBL.hash != newBL.hash {
|
||||||
|
changes = append(changes, fmt.Sprintf("codex[%d].model-blacklist: updated (%d -> %d entries)", i, oldBL.count, newBL.count))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if entries, _ := diffOAuthBlacklistChanges(oldCfg.OAuthModelBlacklist, newCfg.OAuthModelBlacklist); len(entries) > 0 {
|
||||||
|
changes = append(changes, entries...)
|
||||||
|
}
|
||||||
|
|
||||||
// Remote management (never print the key)
|
// Remote management (never print the key)
|
||||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||||
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
||||||
|
|||||||
@@ -146,6 +146,27 @@ func (s *Service) consumeAuthUpdates(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) emitAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if s.watcher != nil && s.watcher.DispatchRuntimeAuthUpdate(update) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.authUpdates != nil {
|
||||||
|
select {
|
||||||
|
case s.authUpdates <- update:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
log.Debugf("auth update queue saturated, applying inline action=%v id=%s", update.Action, update.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.handleAuthUpdate(ctx, update)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
|
func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return
|
return
|
||||||
@@ -220,7 +241,11 @@ func (s *Service) wsOnConnected(channelID string) {
|
|||||||
Metadata: map[string]any{"email": channelID}, // metadata drives logging and usage tracking
|
Metadata: map[string]any{"email": channelID}, // metadata drives logging and usage tracking
|
||||||
}
|
}
|
||||||
log.Infof("websocket provider connected: %s", channelID)
|
log.Infof("websocket provider connected: %s", channelID)
|
||||||
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
|
s.emitAuthUpdate(context.Background(), watcher.AuthUpdate{
|
||||||
|
Action: watcher.AuthUpdateActionAdd,
|
||||||
|
ID: auth.ID,
|
||||||
|
Auth: auth,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) wsOnDisconnected(channelID string, reason error) {
|
func (s *Service) wsOnDisconnected(channelID string, reason error) {
|
||||||
@@ -237,7 +262,10 @@ func (s *Service) wsOnDisconnected(channelID string, reason error) {
|
|||||||
log.Infof("websocket provider disconnected: %s", channelID)
|
log.Infof("websocket provider disconnected: %s", channelID)
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
s.applyCoreAuthRemoval(ctx, channelID)
|
s.emitAuthUpdate(ctx, watcher.AuthUpdate{
|
||||||
|
Action: watcher.AuthUpdateActionDelete,
|
||||||
|
ID: channelID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
|
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
|
||||||
|
|||||||
@@ -83,9 +83,10 @@ type WatcherWrapper struct {
|
|||||||
start func(ctx context.Context) error
|
start func(ctx context.Context) error
|
||||||
stop func() error
|
stop func() error
|
||||||
|
|
||||||
setConfig func(cfg *config.Config)
|
setConfig func(cfg *config.Config)
|
||||||
snapshotAuths func() []*coreauth.Auth
|
snapshotAuths func() []*coreauth.Auth
|
||||||
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
|
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
|
||||||
|
dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start proxies to the underlying watcher Start implementation.
|
// Start proxies to the underlying watcher Start implementation.
|
||||||
@@ -112,6 +113,16 @@ func (w *WatcherWrapper) SetConfig(cfg *config.Config) {
|
|||||||
w.setConfig(cfg)
|
w.setConfig(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DispatchRuntimeAuthUpdate forwards runtime auth updates (e.g., websocket providers)
|
||||||
|
// into the watcher-managed auth update queue when available.
|
||||||
|
// Returns true if the update was enqueued successfully.
|
||||||
|
func (w *WatcherWrapper) DispatchRuntimeAuthUpdate(update watcher.AuthUpdate) bool {
|
||||||
|
if w == nil || w.dispatchRuntimeUpdate == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return w.dispatchRuntimeUpdate(update)
|
||||||
|
}
|
||||||
|
|
||||||
// SetClients updates the watcher file-backed clients registry.
|
// SetClients updates the watcher file-backed clients registry.
|
||||||
// SetClients and SetAPIKeyClients removed; watcher manages its own caches
|
// SetClients and SetAPIKeyClients removed; watcher manages its own caches
|
||||||
|
|
||||||
|
|||||||
@@ -28,5 +28,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi
|
|||||||
setUpdateQueue: func(queue chan<- watcher.AuthUpdate) {
|
setUpdateQueue: func(queue chan<- watcher.AuthUpdate) {
|
||||||
w.SetAuthUpdateQueue(queue)
|
w.SetAuthUpdateQueue(queue)
|
||||||
},
|
},
|
||||||
|
dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool {
|
||||||
|
return w.DispatchRuntimeAuthUpdate(update)
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user