Merge pull request #577 from router-for-me/refactor-watcher-phase3

Refactor-watcher-phase3
Luis Pater
2025-12-17 17:53:04 +08:00
committed by GitHub
6 changed files with 1754 additions and 855 deletions

270
internal/watcher/clients.go Normal file

@@ -0,0 +1,270 @@
// clients.go implements watcher client lifecycle logic and persistence helpers.
// It reloads clients, handles incremental auth file changes, and persists updates when supported.
package watcher
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
log.Debugf("starting full client load process")
w.clientsMutex.RLock()
cfg := w.config
w.clientsMutex.RUnlock()
if cfg == nil {
log.Error("config is nil, cannot reload clients")
return
}
if len(affectedOAuthProviders) > 0 {
w.clientsMutex.Lock()
if w.currentAuths != nil {
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
for id, auth := range w.currentAuths {
if auth == nil {
continue
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if _, match := matchProvider(provider, affectedOAuthProviders); match {
continue
}
filtered[id] = auth
}
w.currentAuths = filtered
log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
} else {
w.currentAuths = nil
}
w.clientsMutex.Unlock()
}
geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
log.Debugf("loaded %d API key clients", totalAPIKeyClients)
var authFileCount int
if rescanAuth {
authFileCount = w.loadFileClients(cfg)
log.Debugf("loaded %d file-based clients", authFileCount)
} else {
w.clientsMutex.RLock()
authFileCount = len(w.lastAuthHashes)
w.clientsMutex.RUnlock()
log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount)
}
if rescanAuth {
w.clientsMutex.Lock()
w.lastAuthHashes = make(map[string]string)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
} else if resolvedAuthDir != "" {
_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return nil
}
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
sum := sha256.Sum256(data)
normalizedPath := w.normalizeAuthPath(path)
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
}
}
return nil
})
}
w.clientsMutex.Unlock()
}
totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
if w.reloadCallback != nil {
log.Debugf("triggering server update callback before auth refresh")
w.reloadCallback(cfg)
}
w.refreshAuthState(forceAuthRefresh)
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
totalNewClients,
authFileCount,
geminiAPIKeyCount,
vertexCompatAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,
openAICompatCount,
)
}
func (w *Watcher) addOrUpdateClient(path string) {
data, errRead := os.ReadFile(path)
if errRead != nil {
log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead)
return
}
if len(data) == 0 {
log.Debugf("ignoring empty auth file: %s", filepath.Base(path))
return
}
sum := sha256.Sum256(data)
curHash := hex.EncodeToString(sum[:])
normalized := w.normalizeAuthPath(path)
w.clientsMutex.Lock()
cfg := w.config
if cfg == nil {
log.Error("config is nil, cannot add or update client")
w.clientsMutex.Unlock()
return
}
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
w.clientsMutex.Unlock()
return
}
w.lastAuthHashes[normalized] = curHash
w.clientsMutex.Unlock() // Unlock before the callback
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after add/update")
w.reloadCallback(cfg)
}
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
}
func (w *Watcher) removeClient(path string) {
normalized := w.normalizeAuthPath(path)
w.clientsMutex.Lock()
cfg := w.config
delete(w.lastAuthHashes, normalized)
w.clientsMutex.Unlock() // Release the lock before the callback
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after removal")
w.reloadCallback(cfg)
}
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
}
func (w *Watcher) loadFileClients(cfg *config.Config) int {
authFileCount := 0
successfulAuthCount := 0
authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir)
if errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
return 0
}
if authDir == "" {
return 0
}
errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
log.Debugf("error accessing path %s: %v", path, err)
return err
}
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
authFileCount++
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 {
successfulAuthCount++
}
}
return nil
})
if errWalk != nil {
log.Errorf("error walking auth directory: %v", errWalk)
}
log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
return authFileCount
}
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
geminiAPIKeyCount := 0
vertexCompatAPIKeyCount := 0
claudeAPIKeyCount := 0
codexAPIKeyCount := 0
openAICompatCount := 0
if len(cfg.GeminiKey) > 0 {
geminiAPIKeyCount += len(cfg.GeminiKey)
}
if len(cfg.VertexCompatAPIKey) > 0 {
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
}
if len(cfg.ClaudeKey) > 0 {
claudeAPIKeyCount += len(cfg.ClaudeKey)
}
if len(cfg.CodexKey) > 0 {
codexAPIKeyCount += len(cfg.CodexKey)
}
if len(cfg.OpenAICompatibility) > 0 {
for _, compatConfig := range cfg.OpenAICompatibility {
openAICompatCount += len(compatConfig.APIKeyEntries)
}
}
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
}
func (w *Watcher) persistConfigAsync() {
if w == nil || w.storePersister == nil {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := w.storePersister.PersistConfig(ctx); err != nil {
log.Errorf("failed to persist config change: %v", err)
}
}()
}
func (w *Watcher) persistAuthAsync(message string, paths ...string) {
if w == nil || w.storePersister == nil {
return
}
filtered := make([]string, 0, len(paths))
for _, p := range paths {
if trimmed := strings.TrimSpace(p); trimmed != "" {
filtered = append(filtered, trimmed)
}
}
if len(filtered) == 0 {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := w.storePersister.PersistAuthFiles(ctx, message, filtered...); err != nil {
log.Errorf("failed to persist auth changes: %v", err)
}
}()
}
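The add/update path above skips work when file bytes are unchanged by caching a SHA-256 per normalized path. A minimal standalone sketch of that idea (the hashCache type and changed method are illustrative names, not part of this package):

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"os"
)

// hashCache remembers the last seen SHA-256 per path so rewrites of
// identical bytes can be skipped, mirroring lastAuthHashes above.
type hashCache map[string]string

// changed reports whether the file content differs from the cached hash,
// updating the cache when it does. Empty files never count as changed.
func (c hashCache) changed(path string) (bool, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return false, err
	}
	if len(data) == 0 {
		return false, nil
	}
	sum := sha256.Sum256(data)
	cur := hex.EncodeToString(sum[:])
	if prev, ok := c[path]; ok && prev == cur {
		return false, nil
	}
	c[path] = cur
	return true, nil
}

func main() {
	cache := hashCache{}
	changed, err := cache.changed("auth.json") // any JSON file path
	fmt.Println(changed, err)
}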

134
internal/watcher/config_reload.go Normal file

@@ -0,0 +1,134 @@
// config_reload.go implements debounced configuration hot reload.
// It detects material changes and reloads clients when the config changes.
package watcher
import (
"crypto/sha256"
"encoding/hex"
"os"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"gopkg.in/yaml.v3"
log "github.com/sirupsen/logrus"
)
func (w *Watcher) stopConfigReloadTimer() {
w.configReloadMu.Lock()
if w.configReloadTimer != nil {
w.configReloadTimer.Stop()
w.configReloadTimer = nil
}
w.configReloadMu.Unlock()
}
func (w *Watcher) scheduleConfigReload() {
w.configReloadMu.Lock()
defer w.configReloadMu.Unlock()
if w.configReloadTimer != nil {
w.configReloadTimer.Stop()
}
w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
w.configReloadMu.Lock()
w.configReloadTimer = nil
w.configReloadMu.Unlock()
w.reloadConfigIfChanged()
})
}
func (w *Watcher) reloadConfigIfChanged() {
data, err := os.ReadFile(w.configPath)
if err != nil {
log.Errorf("failed to read config file for hash check: %v", err)
return
}
if len(data) == 0 {
log.Debugf("ignoring empty config file write event")
return
}
sum := sha256.Sum256(data)
newHash := hex.EncodeToString(sum[:])
w.clientsMutex.RLock()
currentHash := w.lastConfigHash
w.clientsMutex.RUnlock()
if currentHash != "" && currentHash == newHash {
log.Debugf("config file content unchanged (hash match), skipping reload")
return
}
log.Infof("config file changed, reloading: %s", w.configPath)
if w.reloadConfig() {
finalHash := newHash
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
sumUpdated := sha256.Sum256(updatedData)
finalHash = hex.EncodeToString(sumUpdated[:])
} else if errRead != nil {
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
}
w.clientsMutex.Lock()
w.lastConfigHash = finalHash
w.clientsMutex.Unlock()
w.persistConfigAsync()
}
}
func (w *Watcher) reloadConfig() bool {
log.Debug("=========================== CONFIG RELOAD ============================")
log.Debugf("starting config reload from: %s", w.configPath)
newConfig, errLoadConfig := config.LoadConfig(w.configPath)
if errLoadConfig != nil {
log.Errorf("failed to reload config: %v", errLoadConfig)
return false
}
if w.mirroredAuthDir != "" {
newConfig.AuthDir = w.mirroredAuthDir
} else {
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir)
} else {
newConfig.AuthDir = resolvedAuthDir
}
}
w.clientsMutex.Lock()
var oldConfig *config.Config
_ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig)
w.oldConfigYaml, _ = yaml.Marshal(newConfig)
w.config = newConfig
w.clientsMutex.Unlock()
var affectedOAuthProviders []string
if oldConfig != nil {
_, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
}
util.SetLogLevel(newConfig)
if oldConfig != nil && oldConfig.Debug != newConfig.Debug {
log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug)
}
if oldConfig != nil {
details := diff.BuildConfigChangeDetails(oldConfig, newConfig)
if len(details) > 0 {
log.Debugf("config changes detected:")
for _, d := range details {
log.Debugf(" %s", d)
}
} else {
log.Debugf("no material config field changes detected")
}
}
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
log.Infof("config successfully reloaded, triggering client reload")
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
return true
}
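scheduleConfigReload collapses bursts of fsnotify write events into a single reload by restarting a timer on every event. A self-contained sketch of that debounce pattern, assuming a 500ms window in place of the package's configReloadDebounce constant:

package main

import (
	"fmt"
	"sync"
	"time"
)

type debouncer struct {
	mu    sync.Mutex
	timer *time.Timer
	delay time.Duration
	fn    func()
}

// trigger restarts the timer; fn runs only after events stop arriving
// for a full delay window, so a burst of writes yields one invocation.
func (d *debouncer) trigger() {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.timer != nil {
		d.timer.Stop()
	}
	d.timer = time.AfterFunc(d.delay, func() {
		d.mu.Lock()
		d.timer = nil
		d.mu.Unlock()
		d.fn()
	})
}

func main() {
	d := &debouncer{delay: 500 * time.Millisecond, fn: func() { fmt.Println("reload") }}
	for i := 0; i < 5; i++ { // five rapid events collapse into one reload
		d.trigger()
		time.Sleep(50 * time.Millisecond)
	}
	time.Sleep(time.Second)
}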

273
internal/watcher/dispatcher.go Normal file

@@ -0,0 +1,273 @@
// dispatcher.go implements auth update dispatching and queue management.
// It batches, deduplicates, and delivers auth updates to registered consumers.
package watcher
import (
"context"
"fmt"
"reflect"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
w.clientsMutex.Lock()
defer w.clientsMutex.Unlock()
w.authQueue = queue
if w.dispatchCond == nil {
w.dispatchCond = sync.NewCond(&w.dispatchMu)
}
if w.dispatchCancel != nil {
w.dispatchCancel()
if w.dispatchCond != nil {
w.dispatchMu.Lock()
w.dispatchCond.Broadcast()
w.dispatchMu.Unlock()
}
w.dispatchCancel = nil
}
if queue != nil {
ctx, cancel := context.WithCancel(context.Background())
w.dispatchCancel = cancel
go w.dispatchLoop(ctx)
}
}
func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
if w == nil {
return false
}
w.clientsMutex.Lock()
if w.runtimeAuths == nil {
w.runtimeAuths = make(map[string]*coreauth.Auth)
}
switch update.Action {
case AuthUpdateActionAdd, AuthUpdateActionModify:
if update.Auth != nil && update.Auth.ID != "" {
clone := update.Auth.Clone()
w.runtimeAuths[clone.ID] = clone
if w.currentAuths == nil {
w.currentAuths = make(map[string]*coreauth.Auth)
}
w.currentAuths[clone.ID] = clone.Clone()
}
case AuthUpdateActionDelete:
id := update.ID
if id == "" && update.Auth != nil {
id = update.Auth.ID
}
if id != "" {
delete(w.runtimeAuths, id)
if w.currentAuths != nil {
delete(w.currentAuths, id)
}
}
}
w.clientsMutex.Unlock()
if w.getAuthQueue() == nil {
return false
}
w.dispatchAuthUpdates([]AuthUpdate{update})
return true
}
func (w *Watcher) refreshAuthState(force bool) {
auths := w.SnapshotCoreAuths()
w.clientsMutex.Lock()
if len(w.runtimeAuths) > 0 {
for _, a := range w.runtimeAuths {
if a != nil {
auths = append(auths, a.Clone())
}
}
}
updates := w.prepareAuthUpdatesLocked(auths, force)
w.clientsMutex.Unlock()
w.dispatchAuthUpdates(updates)
}
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate {
newState := make(map[string]*coreauth.Auth, len(auths))
for _, auth := range auths {
if auth == nil || auth.ID == "" {
continue
}
newState[auth.ID] = auth.Clone()
}
if w.currentAuths == nil {
w.currentAuths = newState
if w.authQueue == nil {
return nil
}
updates := make([]AuthUpdate, 0, len(newState))
for id, auth := range newState {
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
}
return updates
}
if w.authQueue == nil {
w.currentAuths = newState
return nil
}
updates := make([]AuthUpdate, 0, len(newState)+len(w.currentAuths))
for id, auth := range newState {
if existing, ok := w.currentAuths[id]; !ok {
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
} else if force || !authEqual(existing, auth) {
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()})
}
}
for id := range w.currentAuths {
if _, ok := newState[id]; !ok {
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
}
}
w.currentAuths = newState
return updates
}
func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) {
if len(updates) == 0 {
return
}
queue := w.getAuthQueue()
if queue == nil {
return
}
baseTS := time.Now().UnixNano()
w.dispatchMu.Lock()
if w.pendingUpdates == nil {
w.pendingUpdates = make(map[string]AuthUpdate)
}
for idx, update := range updates {
key := w.authUpdateKey(update, baseTS+int64(idx))
if _, exists := w.pendingUpdates[key]; !exists {
w.pendingOrder = append(w.pendingOrder, key)
}
w.pendingUpdates[key] = update
}
if w.dispatchCond != nil {
w.dispatchCond.Signal()
}
w.dispatchMu.Unlock()
}
func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string {
if update.ID != "" {
return update.ID
}
return fmt.Sprintf("%s:%d", update.Action, ts)
}
func (w *Watcher) dispatchLoop(ctx context.Context) {
for {
batch, ok := w.nextPendingBatch(ctx)
if !ok {
return
}
queue := w.getAuthQueue()
if queue == nil {
if ctx.Err() != nil {
return
}
time.Sleep(10 * time.Millisecond)
continue
}
for _, update := range batch {
select {
case queue <- update:
case <-ctx.Done():
return
}
}
}
}
func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) {
w.dispatchMu.Lock()
defer w.dispatchMu.Unlock()
for len(w.pendingOrder) == 0 {
if ctx.Err() != nil {
return nil, false
}
w.dispatchCond.Wait()
if ctx.Err() != nil {
return nil, false
}
}
batch := make([]AuthUpdate, 0, len(w.pendingOrder))
for _, key := range w.pendingOrder {
batch = append(batch, w.pendingUpdates[key])
delete(w.pendingUpdates, key)
}
w.pendingOrder = w.pendingOrder[:0]
return batch, true
}
func (w *Watcher) getAuthQueue() chan<- AuthUpdate {
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
return w.authQueue
}
func (w *Watcher) stopDispatch() {
if w.dispatchCancel != nil {
w.dispatchCancel()
w.dispatchCancel = nil
}
w.dispatchMu.Lock()
w.pendingOrder = nil
w.pendingUpdates = nil
if w.dispatchCond != nil {
w.dispatchCond.Broadcast()
}
w.dispatchMu.Unlock()
w.clientsMutex.Lock()
w.authQueue = nil
w.clientsMutex.Unlock()
}
func authEqual(a, b *coreauth.Auth) bool {
return reflect.DeepEqual(normalizeAuth(a), normalizeAuth(b))
}
func normalizeAuth(a *coreauth.Auth) *coreauth.Auth {
if a == nil {
return nil
}
clone := a.Clone()
clone.CreatedAt = time.Time{}
clone.UpdatedAt = time.Time{}
clone.LastRefreshedAt = time.Time{}
clone.NextRefreshAfter = time.Time{}
clone.Runtime = nil
clone.Quota.NextRecoverAt = time.Time{}
return clone
}
func snapshotCoreAuths(cfg *config.Config, authDir string) []*coreauth.Auth {
ctx := &synthesizer.SynthesisContext{
Config: cfg,
AuthDir: authDir,
Now: time.Now(),
IDGenerator: synthesizer.NewStableIDGenerator(),
}
var out []*coreauth.Auth
configSynth := synthesizer.NewConfigSynthesizer()
if auths, err := configSynth.Synthesize(ctx); err == nil {
out = append(out, auths...)
}
fileSynth := synthesizer.NewFileSynthesizer()
if auths, err := fileSynth.Synthesize(ctx); err == nil {
out = append(out, auths...)
}
return out
}
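dispatchAuthUpdates keys pending updates by auth ID, so a burst of changes to one auth coalesces into its latest state while arrival order is preserved; nextPendingBatch then drains everything under a sync.Cond. A simplified standalone sketch of that coalescing queue (string values stand in for AuthUpdate; names are illustrative):

package main

import (
	"fmt"
	"sync"
)

type batcher struct {
	mu      sync.Mutex
	cond    *sync.Cond
	order   []string
	pending map[string]string
	closed  bool
}

func newBatcher() *batcher {
	b := &batcher{pending: make(map[string]string)}
	b.cond = sync.NewCond(&b.mu)
	return b
}

// put records an update keyed by ID; a newer value for the same key
// overwrites the pending one, so each batch carries only the latest state.
func (b *batcher) put(key, val string) {
	b.mu.Lock()
	if _, ok := b.pending[key]; !ok {
		b.order = append(b.order, key)
	}
	b.pending[key] = val
	b.cond.Signal()
	b.mu.Unlock()
}

// next blocks until updates are pending (or the batcher is closed) and
// drains them in arrival order.
func (b *batcher) next() ([]string, bool) {
	b.mu.Lock()
	defer b.mu.Unlock()
	for len(b.order) == 0 && !b.closed {
		b.cond.Wait()
	}
	if b.closed {
		return nil, false
	}
	out := make([]string, 0, len(b.order))
	for _, k := range b.order {
		out = append(out, b.pending[k])
		delete(b.pending, k)
	}
	b.order = b.order[:0]
	return out, true
}

func main() {
	b := newBatcher()
	b.put("a", "v1")
	b.put("a", "v2") // coalesces with v1
	b.put("b", "v1")
	batch, _ := b.next()
	fmt.Println(batch) // [v2 v1]
}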

194
internal/watcher/events.go Normal file

@@ -0,0 +1,194 @@
// events.go implements fsnotify event handling for config and auth file changes.
// It normalizes paths, debounces noisy events, and triggers reload/update logic.
package watcher
import (
"context"
"crypto/sha256"
"encoding/hex"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
func matchProvider(provider string, targets []string) (string, bool) {
p := strings.ToLower(strings.TrimSpace(provider))
for _, t := range targets {
if strings.EqualFold(p, strings.TrimSpace(t)) {
return p, true
}
}
return p, false
}
func (w *Watcher) start(ctx context.Context) error {
if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
return errAddConfig
}
log.Debugf("watching config file: %s", w.configPath)
if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir)
return errAddAuthDir
}
log.Debugf("watching auth directory: %s", w.authDir)
go w.processEvents(ctx)
w.reloadClients(true, nil, false)
return nil
}
func (w *Watcher) processEvents(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case event, ok := <-w.watcher.Events:
if !ok {
return
}
w.handleEvent(event)
case errWatch, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Errorf("file watcher error: %v", errWatch)
}
}
}
func (w *Watcher) handleEvent(event fsnotify.Event) {
// Filter only relevant events: config file or auth-dir JSON files.
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
normalizedName := w.normalizeAuthPath(event.Name)
normalizedConfigPath := w.normalizeAuthPath(w.configPath)
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
if !isConfigEvent && !isAuthJSON {
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
return
}
now := time.Now()
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
// Handle config file changes
if isConfigEvent {
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
w.scheduleConfigReload()
return
}
// Handle auth directory changes incrementally (.json only)
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
if w.shouldDebounceRemove(normalizedName, now) {
log.Debugf("debouncing remove event for %s", filepath.Base(event.Name))
return
}
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
// Wait briefly; if the path exists again, treat as an update instead of removal.
time.Sleep(replaceCheckDelay)
if _, statErr := os.Stat(event.Name); statErr == nil {
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
return
}
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
w.addOrUpdateClient(event.Name)
return
}
if !w.isKnownAuthFile(event.Name) {
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
return
}
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
w.removeClient(event.Name)
return
}
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
return
}
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
w.addOrUpdateClient(event.Name)
}
}
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
data, errRead := os.ReadFile(path)
if errRead != nil {
return false, errRead
}
if len(data) == 0 {
return false, nil
}
sum := sha256.Sum256(data)
curHash := hex.EncodeToString(sum[:])
normalized := w.normalizeAuthPath(path)
w.clientsMutex.RLock()
prevHash, ok := w.lastAuthHashes[normalized]
w.clientsMutex.RUnlock()
if ok && prevHash == curHash {
return true, nil
}
return false, nil
}
func (w *Watcher) isKnownAuthFile(path string) bool {
normalized := w.normalizeAuthPath(path)
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
_, ok := w.lastAuthHashes[normalized]
return ok
}
func (w *Watcher) normalizeAuthPath(path string) string {
trimmed := strings.TrimSpace(path)
if trimmed == "" {
return ""
}
cleaned := filepath.Clean(trimmed)
if runtime.GOOS == "windows" {
cleaned = strings.TrimPrefix(cleaned, `\\?\`)
cleaned = strings.ToLower(cleaned)
}
return cleaned
}
func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool {
if normalizedPath == "" {
return false
}
w.clientsMutex.Lock()
if w.lastRemoveTimes == nil {
w.lastRemoveTimes = make(map[string]time.Time)
}
if last, ok := w.lastRemoveTimes[normalizedPath]; ok {
if now.Sub(last) < authRemoveDebounceWindow {
w.clientsMutex.Unlock()
return true
}
}
w.lastRemoveTimes[normalizedPath] = now
if len(w.lastRemoveTimes) > 128 {
cutoff := now.Add(-2 * authRemoveDebounceWindow)
for p, t := range w.lastRemoveTimes {
if t.Before(cutoff) {
delete(w.lastRemoveTimes, p)
}
}
}
w.clientsMutex.Unlock()
return false
}
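The Remove/Rename branch in handleEvent exists because atomic replaces (write a temp file, rename it over the target) can surface as a removal before the new file lands. A minimal sketch of that heuristic; the 50ms pause is an assumed stand-in for the package's replaceCheckDelay constant:

package main

import (
	"fmt"
	"os"
	"time"
)

// classifyRemove waits briefly after a remove/rename event and re-stats
// the path: if the file reappeared, the event was an atomic replace and
// should be treated as an update rather than a deletion.
func classifyRemove(path string) string {
	time.Sleep(50 * time.Millisecond) // assumed replaceCheckDelay
	if _, err := os.Stat(path); err == nil {
		return "update" // file reappeared: atomic replace
	}
	return "remove" // still gone: a real deletion
}

func main() {
	fmt.Println(classifyRemove("auth/a.json"))
}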

internal/watcher/watcher.go

@@ -1,45 +1,22 @@
// Package watcher provides file system monitoring functionality for the CLI Proxy API.
// It watches configuration files and authentication directories for changes,
// automatically reloading clients and configuration when files are modified.
// The package handles cross-platform file system events and supports hot-reloading.
// Package watcher watches config/auth files and triggers hot reloads.
// It supports cross-platform fsnotify event handling.
package watcher
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io/fs"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
"gopkg.in/yaml.v3"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
func matchProvider(provider string, targets []string) (string, bool) {
p := strings.ToLower(strings.TrimSpace(provider))
for _, t := range targets {
if strings.EqualFold(p, strings.TrimSpace(t)) {
return p, true
}
}
return p, false
}
// storePersister captures persistence-capable token store methods used by the watcher.
type storePersister interface {
PersistConfig(ctx context.Context) error
@@ -131,26 +108,7 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
// Start begins watching the configuration file and authentication directory
func (w *Watcher) Start(ctx context.Context) error {
// Watch the config file
if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
return errAddConfig
}
log.Debugf("watching config file: %s", w.configPath)
// Watch the auth directory
if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir)
return errAddAuthDir
}
log.Debugf("watching auth directory: %s", w.authDir)
// Start the event processing goroutine
go w.processEvents(ctx)
// Perform an initial full reload based on current config and auth dir
w.reloadClients(true, nil, false)
return nil
return w.start(ctx)
}
// Stop stops the file watcher
@@ -160,15 +118,6 @@ func (w *Watcher) Stop() error {
return w.watcher.Close()
}
func (w *Watcher) stopConfigReloadTimer() {
w.configReloadMu.Lock()
if w.configReloadTimer != nil {
w.configReloadTimer.Stop()
w.configReloadTimer = nil
}
w.configReloadMu.Unlock()
}
// SetConfig updates the current configuration
func (w *Watcher) SetConfig(cfg *config.Config) {
w.clientsMutex.Lock()
@@ -179,818 +128,20 @@ func (w *Watcher) SetConfig(cfg *config.Config) {
// SetAuthUpdateQueue sets the queue used to emit auth updates.
func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
w.clientsMutex.Lock()
defer w.clientsMutex.Unlock()
w.authQueue = queue
if w.dispatchCond == nil {
w.dispatchCond = sync.NewCond(&w.dispatchMu)
}
if w.dispatchCancel != nil {
w.dispatchCancel()
if w.dispatchCond != nil {
w.dispatchMu.Lock()
w.dispatchCond.Broadcast()
w.dispatchMu.Unlock()
}
w.dispatchCancel = nil
}
if queue != nil {
ctx, cancel := context.WithCancel(context.Background())
w.dispatchCancel = cancel
go w.dispatchLoop(ctx)
}
w.setAuthUpdateQueue(queue)
}
// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths)
// to push auth updates through the same queue used by file/config watchers.
// Returns true if the update was enqueued; false if no queue is configured.
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
if w == nil {
return false
}
w.clientsMutex.Lock()
if w.runtimeAuths == nil {
w.runtimeAuths = make(map[string]*coreauth.Auth)
}
switch update.Action {
case AuthUpdateActionAdd, AuthUpdateActionModify:
if update.Auth != nil && update.Auth.ID != "" {
clone := update.Auth.Clone()
w.runtimeAuths[clone.ID] = clone
if w.currentAuths == nil {
w.currentAuths = make(map[string]*coreauth.Auth)
}
w.currentAuths[clone.ID] = clone.Clone()
}
case AuthUpdateActionDelete:
id := update.ID
if id == "" && update.Auth != nil {
id = update.Auth.ID
}
if id != "" {
delete(w.runtimeAuths, id)
if w.currentAuths != nil {
delete(w.currentAuths, id)
}
}
}
w.clientsMutex.Unlock()
if w.getAuthQueue() == nil {
return false
}
w.dispatchAuthUpdates([]AuthUpdate{update})
return true
return w.dispatchRuntimeAuthUpdate(update)
}
func (w *Watcher) refreshAuthState(force bool) {
auths := w.SnapshotCoreAuths()
w.clientsMutex.Lock()
if len(w.runtimeAuths) > 0 {
for _, a := range w.runtimeAuths {
if a != nil {
auths = append(auths, a.Clone())
}
}
}
updates := w.prepareAuthUpdatesLocked(auths, force)
w.clientsMutex.Unlock()
w.dispatchAuthUpdates(updates)
}
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate {
newState := make(map[string]*coreauth.Auth, len(auths))
for _, auth := range auths {
if auth == nil || auth.ID == "" {
continue
}
newState[auth.ID] = auth.Clone()
}
if w.currentAuths == nil {
w.currentAuths = newState
if w.authQueue == nil {
return nil
}
updates := make([]AuthUpdate, 0, len(newState))
for id, auth := range newState {
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
}
return updates
}
if w.authQueue == nil {
w.currentAuths = newState
return nil
}
updates := make([]AuthUpdate, 0, len(newState)+len(w.currentAuths))
for id, auth := range newState {
if existing, ok := w.currentAuths[id]; !ok {
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
} else if force || !authEqual(existing, auth) {
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()})
}
}
for id := range w.currentAuths {
if _, ok := newState[id]; !ok {
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
}
}
w.currentAuths = newState
return updates
}
func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) {
if len(updates) == 0 {
return
}
queue := w.getAuthQueue()
if queue == nil {
return
}
baseTS := time.Now().UnixNano()
w.dispatchMu.Lock()
if w.pendingUpdates == nil {
w.pendingUpdates = make(map[string]AuthUpdate)
}
for idx, update := range updates {
key := w.authUpdateKey(update, baseTS+int64(idx))
if _, exists := w.pendingUpdates[key]; !exists {
w.pendingOrder = append(w.pendingOrder, key)
}
w.pendingUpdates[key] = update
}
if w.dispatchCond != nil {
w.dispatchCond.Signal()
}
w.dispatchMu.Unlock()
}
func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string {
if update.ID != "" {
return update.ID
}
return fmt.Sprintf("%s:%d", update.Action, ts)
}
func (w *Watcher) dispatchLoop(ctx context.Context) {
for {
batch, ok := w.nextPendingBatch(ctx)
if !ok {
return
}
queue := w.getAuthQueue()
if queue == nil {
if ctx.Err() != nil {
return
}
time.Sleep(10 * time.Millisecond)
continue
}
for _, update := range batch {
select {
case queue <- update:
case <-ctx.Done():
return
}
}
}
}
func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) {
w.dispatchMu.Lock()
defer w.dispatchMu.Unlock()
for len(w.pendingOrder) == 0 {
if ctx.Err() != nil {
return nil, false
}
w.dispatchCond.Wait()
if ctx.Err() != nil {
return nil, false
}
}
batch := make([]AuthUpdate, 0, len(w.pendingOrder))
for _, key := range w.pendingOrder {
batch = append(batch, w.pendingUpdates[key])
delete(w.pendingUpdates, key)
}
w.pendingOrder = w.pendingOrder[:0]
return batch, true
}
func (w *Watcher) getAuthQueue() chan<- AuthUpdate {
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
return w.authQueue
}
func (w *Watcher) stopDispatch() {
if w.dispatchCancel != nil {
w.dispatchCancel()
w.dispatchCancel = nil
}
w.dispatchMu.Lock()
w.pendingOrder = nil
w.pendingUpdates = nil
if w.dispatchCond != nil {
w.dispatchCond.Broadcast()
}
w.dispatchMu.Unlock()
w.clientsMutex.Lock()
w.authQueue = nil
w.clientsMutex.Unlock()
}
func (w *Watcher) persistConfigAsync() {
if w == nil || w.storePersister == nil {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := w.storePersister.PersistConfig(ctx); err != nil {
log.Errorf("failed to persist config change: %v", err)
}
}()
}
func (w *Watcher) persistAuthAsync(message string, paths ...string) {
if w == nil || w.storePersister == nil {
return
}
filtered := make([]string, 0, len(paths))
for _, p := range paths {
if trimmed := strings.TrimSpace(p); trimmed != "" {
filtered = append(filtered, trimmed)
}
}
if len(filtered) == 0 {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := w.storePersister.PersistAuthFiles(ctx, message, filtered...); err != nil {
log.Errorf("failed to persist auth changes: %v", err)
}
}()
}
func authEqual(a, b *coreauth.Auth) bool {
return reflect.DeepEqual(normalizeAuth(a), normalizeAuth(b))
}
func normalizeAuth(a *coreauth.Auth) *coreauth.Auth {
if a == nil {
return nil
}
clone := a.Clone()
clone.CreatedAt = time.Time{}
clone.UpdatedAt = time.Time{}
clone.LastRefreshedAt = time.Time{}
clone.NextRefreshAfter = time.Time{}
clone.Runtime = nil
clone.Quota.NextRecoverAt = time.Time{}
return clone
}
// SetClients sets the file-based clients.
// SetClients removed
// SetAPIKeyClients removed
// processEvents handles file system events
func (w *Watcher) processEvents(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case event, ok := <-w.watcher.Events:
if !ok {
return
}
w.handleEvent(event)
case errWatch, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Errorf("file watcher error: %v", errWatch)
}
}
}
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
data, errRead := os.ReadFile(path)
if errRead != nil {
return false, errRead
}
if len(data) == 0 {
return false, nil
}
sum := sha256.Sum256(data)
curHash := hex.EncodeToString(sum[:])
normalized := w.normalizeAuthPath(path)
w.clientsMutex.RLock()
prevHash, ok := w.lastAuthHashes[normalized]
w.clientsMutex.RUnlock()
if ok && prevHash == curHash {
return true, nil
}
return false, nil
}
func (w *Watcher) isKnownAuthFile(path string) bool {
normalized := w.normalizeAuthPath(path)
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
_, ok := w.lastAuthHashes[normalized]
return ok
}
func (w *Watcher) normalizeAuthPath(path string) string {
trimmed := strings.TrimSpace(path)
if trimmed == "" {
return ""
}
cleaned := filepath.Clean(trimmed)
if runtime.GOOS == "windows" {
cleaned = strings.TrimPrefix(cleaned, `\\?\`)
cleaned = strings.ToLower(cleaned)
}
return cleaned
}
func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool {
if normalizedPath == "" {
return false
}
w.clientsMutex.Lock()
if w.lastRemoveTimes == nil {
w.lastRemoveTimes = make(map[string]time.Time)
}
if last, ok := w.lastRemoveTimes[normalizedPath]; ok {
if now.Sub(last) < authRemoveDebounceWindow {
w.clientsMutex.Unlock()
return true
}
}
w.lastRemoveTimes[normalizedPath] = now
if len(w.lastRemoveTimes) > 128 {
cutoff := now.Add(-2 * authRemoveDebounceWindow)
for p, t := range w.lastRemoveTimes {
if t.Before(cutoff) {
delete(w.lastRemoveTimes, p)
}
}
}
w.clientsMutex.Unlock()
return false
}
// handleEvent processes individual file system events
func (w *Watcher) handleEvent(event fsnotify.Event) {
// Filter only relevant events: config file or auth-dir JSON files.
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
normalizedName := w.normalizeAuthPath(event.Name)
normalizedConfigPath := w.normalizeAuthPath(w.configPath)
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
if !isConfigEvent && !isAuthJSON {
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
return
}
now := time.Now()
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
// Handle config file changes
if isConfigEvent {
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
w.scheduleConfigReload()
return
}
// Handle auth directory changes incrementally (.json only)
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
if w.shouldDebounceRemove(normalizedName, now) {
log.Debugf("debouncing remove event for %s", filepath.Base(event.Name))
return
}
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
// Wait briefly; if the path exists again, treat as an update instead of removal.
time.Sleep(replaceCheckDelay)
if _, statErr := os.Stat(event.Name); statErr == nil {
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
return
}
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
w.addOrUpdateClient(event.Name)
return
}
if !w.isKnownAuthFile(event.Name) {
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
return
}
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
w.removeClient(event.Name)
return
}
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
return
}
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
w.addOrUpdateClient(event.Name)
}
}
func (w *Watcher) scheduleConfigReload() {
w.configReloadMu.Lock()
defer w.configReloadMu.Unlock()
if w.configReloadTimer != nil {
w.configReloadTimer.Stop()
}
w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
w.configReloadMu.Lock()
w.configReloadTimer = nil
w.configReloadMu.Unlock()
w.reloadConfigIfChanged()
})
}
func (w *Watcher) reloadConfigIfChanged() {
data, err := os.ReadFile(w.configPath)
if err != nil {
log.Errorf("failed to read config file for hash check: %v", err)
return
}
if len(data) == 0 {
log.Debugf("ignoring empty config file write event")
return
}
sum := sha256.Sum256(data)
newHash := hex.EncodeToString(sum[:])
w.clientsMutex.RLock()
currentHash := w.lastConfigHash
w.clientsMutex.RUnlock()
if currentHash != "" && currentHash == newHash {
log.Debugf("config file content unchanged (hash match), skipping reload")
return
}
log.Infof("config file changed, reloading: %s", w.configPath)
if w.reloadConfig() {
finalHash := newHash
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
sumUpdated := sha256.Sum256(updatedData)
finalHash = hex.EncodeToString(sumUpdated[:])
} else if errRead != nil {
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
}
w.clientsMutex.Lock()
w.lastConfigHash = finalHash
w.clientsMutex.Unlock()
w.persistConfigAsync()
}
}
// reloadConfig reloads the configuration and triggers a full reload
func (w *Watcher) reloadConfig() bool {
log.Debug("=========================== CONFIG RELOAD ============================")
log.Debugf("starting config reload from: %s", w.configPath)
newConfig, errLoadConfig := config.LoadConfig(w.configPath)
if errLoadConfig != nil {
log.Errorf("failed to reload config: %v", errLoadConfig)
return false
}
if w.mirroredAuthDir != "" {
newConfig.AuthDir = w.mirroredAuthDir
} else {
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir)
} else {
newConfig.AuthDir = resolvedAuthDir
}
}
w.clientsMutex.Lock()
var oldConfig *config.Config
_ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig)
w.oldConfigYaml, _ = yaml.Marshal(newConfig)
w.config = newConfig
w.clientsMutex.Unlock()
var affectedOAuthProviders []string
if oldConfig != nil {
_, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
}
// Always apply the current log level based on the latest config.
// This ensures logrus reflects the desired level even if change detection misses.
util.SetLogLevel(newConfig)
// Additional debug for visibility when the flag actually changes.
if oldConfig != nil && oldConfig.Debug != newConfig.Debug {
log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug)
}
// Log configuration changes in debug mode, only when there are material diffs
if oldConfig != nil {
details := diff.BuildConfigChangeDetails(oldConfig, newConfig)
if len(details) > 0 {
log.Debugf("config changes detected:")
for _, d := range details {
log.Debugf(" %s", d)
}
} else {
log.Debugf("no material config field changes detected")
}
}
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
log.Infof("config successfully reloaded, triggering client reload")
// Reload clients with new config
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
return true
}
// reloadClients performs a full scan and reload of all clients.
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
log.Debugf("starting full client load process")
w.clientsMutex.RLock()
cfg := w.config
w.clientsMutex.RUnlock()
if cfg == nil {
log.Error("config is nil, cannot reload clients")
return
}
if len(affectedOAuthProviders) > 0 {
w.clientsMutex.Lock()
if w.currentAuths != nil {
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
for id, auth := range w.currentAuths {
if auth == nil {
continue
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if _, match := matchProvider(provider, affectedOAuthProviders); match {
continue
}
filtered[id] = auth
}
w.currentAuths = filtered
log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
} else {
w.currentAuths = nil
}
w.clientsMutex.Unlock()
}
// Unregister all old API key clients before creating new ones
// no legacy clients to unregister
// Create new API key clients based on the new config
geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
log.Debugf("loaded %d API key clients", totalAPIKeyClients)
var authFileCount int
if rescanAuth {
// Load file-based clients when explicitly requested (startup or authDir change)
authFileCount = w.loadFileClients(cfg)
log.Debugf("loaded %d file-based clients", authFileCount)
} else {
// Preserve existing auth hashes and only report current known count to avoid redundant scans.
w.clientsMutex.RLock()
authFileCount = len(w.lastAuthHashes)
w.clientsMutex.RUnlock()
log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount)
}
// no legacy file-based clients to unregister
// Update client maps
if rescanAuth {
w.clientsMutex.Lock()
// Rebuild auth file hash cache for current clients
w.lastAuthHashes = make(map[string]string)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
} else if resolvedAuthDir != "" {
_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return nil
}
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
sum := sha256.Sum256(data)
normalizedPath := w.normalizeAuthPath(path)
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
}
}
return nil
})
}
w.clientsMutex.Unlock()
}
totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
// Ensure consumers observe the new configuration before auth updates dispatch.
if w.reloadCallback != nil {
log.Debugf("triggering server update callback before auth refresh")
w.reloadCallback(cfg)
}
w.refreshAuthState(forceAuthRefresh)
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
totalNewClients,
authFileCount,
geminiAPIKeyCount,
vertexCompatAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,
openAICompatCount,
)
}
// createClientFromFile creates a single client instance from a given token file path.
// createClientFromFile removed (legacy)
// addOrUpdateClient handles the addition or update of a single client.
func (w *Watcher) addOrUpdateClient(path string) {
data, errRead := os.ReadFile(path)
if errRead != nil {
log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead)
return
}
if len(data) == 0 {
log.Debugf("ignoring empty auth file: %s", filepath.Base(path))
return
}
sum := sha256.Sum256(data)
curHash := hex.EncodeToString(sum[:])
normalized := w.normalizeAuthPath(path)
w.clientsMutex.Lock()
cfg := w.config
if cfg == nil {
log.Error("config is nil, cannot add or update client")
w.clientsMutex.Unlock()
return
}
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
w.clientsMutex.Unlock()
return
}
// Update hash cache
w.lastAuthHashes[normalized] = curHash
w.clientsMutex.Unlock() // Unlock before the callback
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after add/update")
w.reloadCallback(cfg)
}
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
}
// removeClient handles the removal of a single client.
func (w *Watcher) removeClient(path string) {
normalized := w.normalizeAuthPath(path)
w.clientsMutex.Lock()
cfg := w.config
delete(w.lastAuthHashes, normalized)
w.clientsMutex.Unlock() // Release the lock before the callback
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after removal")
w.reloadCallback(cfg)
}
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
}
// SnapshotCombinedClients returns a snapshot of current combined clients.
// SnapshotCombinedClients removed
// SnapshotCoreAuths converts current clients snapshot into core auth entries.
func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
w.clientsMutex.RLock()
cfg := w.config
w.clientsMutex.RUnlock()
ctx := &synthesizer.SynthesisContext{
Config: cfg,
AuthDir: w.authDir,
Now: time.Now(),
IDGenerator: synthesizer.NewStableIDGenerator(),
}
var out []*coreauth.Auth
// Use ConfigSynthesizer for API key auth entries
configSynth := synthesizer.NewConfigSynthesizer()
if auths, err := configSynth.Synthesize(ctx); err == nil {
out = append(out, auths...)
}
// Use FileSynthesizer for file-based OAuth auth entries
fileSynth := synthesizer.NewFileSynthesizer()
if auths, err := fileSynth.Synthesize(ctx); err == nil {
out = append(out, auths...)
}
return out
}
// buildCombinedClientMap merges file-based clients with API key clients from the cache.
// buildCombinedClientMap removed
// unregisterClientWithReason attempts to call client-specific unregister hooks with context.
// unregisterClientWithReason removed
// loadFileClients scans the auth directory and creates clients from .json files.
func (w *Watcher) loadFileClients(cfg *config.Config) int {
authFileCount := 0
successfulAuthCount := 0
authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir)
if errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
return 0
}
if authDir == "" {
return 0
}
errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
log.Debugf("error accessing path %s: %v", path, err)
return err
}
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
authFileCount++
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
// Count readable JSON files as successful auth entries
if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 {
successfulAuthCount++
}
}
return nil
})
if errWalk != nil {
log.Errorf("error walking auth directory: %v", errWalk)
}
log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
return authFileCount
}
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
geminiAPIKeyCount := 0
vertexCompatAPIKeyCount := 0
claudeAPIKeyCount := 0
codexAPIKeyCount := 0
openAICompatCount := 0
if len(cfg.GeminiKey) > 0 {
// Stateless executor handles Gemini API keys; avoid constructing legacy clients.
geminiAPIKeyCount += len(cfg.GeminiKey)
}
if len(cfg.VertexCompatAPIKey) > 0 {
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
}
if len(cfg.ClaudeKey) > 0 {
claudeAPIKeyCount += len(cfg.ClaudeKey)
}
if len(cfg.CodexKey) > 0 {
codexAPIKeyCount += len(cfg.CodexKey)
}
if len(cfg.OpenAICompatibility) > 0 {
// Do not construct legacy clients for OpenAI-compat providers; these are handled by the stateless executor.
for _, compatConfig := range cfg.OpenAICompatibility {
openAICompatCount += len(compatConfig.APIKeyEntries)
}
}
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
return snapshotCoreAuths(cfg, w.authDir)
}

internal/watcher/watcher_test.go

@@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -16,6 +17,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"gopkg.in/yaml.v3"
)
@@ -489,6 +491,28 @@ func TestAuthFileUnchangedUsesHash(t *testing.T) {
}
}
func TestAuthFileUnchangedEmptyAndMissing(t *testing.T) {
tmpDir := t.TempDir()
emptyFile := filepath.Join(tmpDir, "empty.json")
if err := os.WriteFile(emptyFile, []byte(""), 0o644); err != nil {
t.Fatalf("failed to write empty auth file: %v", err)
}
w := &Watcher{lastAuthHashes: make(map[string]string)}
unchanged, err := w.authFileUnchanged(emptyFile)
if err != nil {
t.Fatalf("unexpected error for empty file: %v", err)
}
if unchanged {
t.Fatal("expected empty file to be treated as changed")
}
_, err = w.authFileUnchanged(filepath.Join(tmpDir, "missing.json"))
if err == nil {
t.Fatal("expected error for missing auth file")
}
}
func TestReloadClientsCachesAuthHashes(t *testing.T) {
tmpDir := t.TempDir()
authFile := filepath.Join(tmpDir, "one.json")
@@ -528,6 +552,23 @@ func TestReloadClientsLogsConfigDiffs(t *testing.T) {
w.reloadClients(false, nil, false)
}
func TestReloadClientsHandlesNilConfig(t *testing.T) {
w := &Watcher{}
w.reloadClients(true, nil, false)
}
func TestReloadClientsFiltersProvidersWithNilCurrentAuths(t *testing.T) {
tmp := t.TempDir()
w := &Watcher{
authDir: tmp,
config: &config.Config{AuthDir: tmp},
}
w.reloadClients(false, []string{"match"}, false)
if w.currentAuths != nil && len(w.currentAuths) != 0 {
t.Fatalf("expected currentAuths to be nil or empty, got %d", len(w.currentAuths))
}
}
func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
w := &Watcher{}
queue := make(chan AuthUpdate, 1)
@@ -541,6 +582,45 @@ func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
}
}
func TestPersistAsyncEarlyReturns(t *testing.T) {
var nilWatcher *Watcher
nilWatcher.persistConfigAsync()
nilWatcher.persistAuthAsync("msg", "a")
w := &Watcher{}
w.persistConfigAsync()
w.persistAuthAsync("msg", " ", "")
}
type errorPersister struct {
configCalls int32
authCalls int32
}
func (p *errorPersister) PersistConfig(context.Context) error {
atomic.AddInt32(&p.configCalls, 1)
return fmt.Errorf("persist config error")
}
func (p *errorPersister) PersistAuthFiles(context.Context, string, ...string) error {
atomic.AddInt32(&p.authCalls, 1)
return fmt.Errorf("persist auth error")
}
func TestPersistAsyncErrorPaths(t *testing.T) {
p := &errorPersister{}
w := &Watcher{storePersister: p}
w.persistConfigAsync()
w.persistAuthAsync("msg", "a")
time.Sleep(30 * time.Millisecond)
if atomic.LoadInt32(&p.configCalls) != 1 {
t.Fatalf("expected PersistConfig to be called once, got %d", p.configCalls)
}
if atomic.LoadInt32(&p.authCalls) != 1 {
t.Fatalf("expected PersistAuthFiles to be called once, got %d", p.authCalls)
}
}
func TestStopConfigReloadTimerSafeWhenNil(t *testing.T) {
w := &Watcher{}
w.stopConfigReloadTimer()
@@ -608,6 +688,803 @@ func TestDispatchAuthUpdatesFlushesQueue(t *testing.T) {
}
}
func TestDispatchLoopExitsOnContextDoneWhileSending(t *testing.T) {
queue := make(chan AuthUpdate) // unbuffered to block sends
w := &Watcher{
authQueue: queue,
pendingUpdates: map[string]AuthUpdate{
"k": {Action: AuthUpdateActionAdd, ID: "k"},
},
pendingOrder: []string{"k"},
}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
w.dispatchLoop(ctx)
close(done)
}()
time.Sleep(30 * time.Millisecond)
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("expected dispatchLoop to exit after ctx canceled while blocked on send")
}
}
func TestProcessEventsHandlesEventErrorAndChannelClose(t *testing.T) {
w := &Watcher{
watcher: &fsnotify.Watcher{
Events: make(chan fsnotify.Event, 2),
Errors: make(chan error, 2),
},
configPath: "config.yaml",
authDir: "auth",
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
go func() {
w.processEvents(ctx)
close(done)
}()
w.watcher.Events <- fsnotify.Event{Name: "unrelated.txt", Op: fsnotify.Write}
w.watcher.Errors <- fmt.Errorf("watcher error")
time.Sleep(20 * time.Millisecond)
close(w.watcher.Events)
close(w.watcher.Errors)
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("processEvents did not exit after channels closed")
}
}
func TestProcessEventsReturnsWhenErrorsChannelClosed(t *testing.T) {
w := &Watcher{
watcher: &fsnotify.Watcher{
Events: nil,
Errors: make(chan error),
},
}
close(w.watcher.Errors)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
go func() {
w.processEvents(ctx)
close(done)
}()
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("processEvents did not exit after errors channel closed")
}
}
func TestHandleEventIgnoresUnrelatedFiles(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
var reloads int32
w := &Watcher{
authDir: authDir,
configPath: configPath,
lastAuthHashes: make(map[string]string),
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.handleEvent(fsnotify.Event{Name: filepath.Join(tmpDir, "note.txt"), Op: fsnotify.Write})
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected no reloads for unrelated file, got %d", reloads)
}
}
func TestHandleEventConfigChangeSchedulesReload(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
var reloads int32
w := &Watcher{
authDir: authDir,
configPath: configPath,
lastAuthHashes: make(map[string]string),
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.handleEvent(fsnotify.Event{Name: configPath, Op: fsnotify.Write})
time.Sleep(400 * time.Millisecond)
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected config change to trigger reload once, got %d", reloads)
}
}
func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
authFile := filepath.Join(authDir, "a.json")
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
var reloads int32
w := &Watcher{
authDir: authDir,
configPath: configPath,
lastAuthHashes: make(map[string]string),
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected auth write to trigger reload callback, got %d", reloads)
}
}
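
// TestHandleEventRemoveDebounceSkips asserts that a remove event arriving
// within the debounce window for the same path is ignored.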
func TestHandleEventRemoveDebounceSkips(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
authFile := filepath.Join(authDir, "remove.json")
var reloads int32
w := &Watcher{
authDir: authDir,
configPath: configPath,
lastAuthHashes: make(map[string]string),
lastRemoveTimes: map[string]time.Time{
filepath.Clean(authFile): time.Now(),
},
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected remove to be debounced, got %d", reloads)
}
}
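
// TestHandleEventAtomicReplaceUnchangedSkips asserts that a rename (atomic
// replace) whose on-disk content matches the cached hash is skipped.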
func TestHandleEventAtomicReplaceUnchangedSkips(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
authFile := filepath.Join(authDir, "same.json")
content := []byte(`{"type":"demo"}`)
if err := os.WriteFile(authFile, content, 0o644); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
sum := sha256.Sum256(content)
var reloads int32
w := &Watcher{
authDir: authDir,
configPath: configPath,
lastAuthHashes: make(map[string]string),
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(sum[:])
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected unchanged atomic replace to be skipped, got %d", reloads)
}
}
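
// TestHandleEventAtomicReplaceChangedTriggersUpdate asserts that a rename
// whose on-disk content differs from the cached hash triggers a reload.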
func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
authFile := filepath.Join(authDir, "change.json")
oldContent := []byte(`{"type":"demo","v":1}`)
newContent := []byte(`{"type":"demo","v":2}`)
if err := os.WriteFile(authFile, newContent, 0o644); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
oldSum := sha256.Sum256(oldContent)
var reloads int32
w := &Watcher{
authDir: authDir,
configPath: configPath,
lastAuthHashes: make(map[string]string),
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads)
}
}
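
// TestHandleEventRemoveUnknownFileIgnored asserts that removing an auth file
// with no cached hash does not trigger a reload.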
func TestHandleEventRemoveUnknownFileIgnored(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
authFile := filepath.Join(authDir, "unknown.json")
var reloads int32
w := &Watcher{
authDir: authDir,
configPath: configPath,
lastAuthHashes: make(map[string]string),
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected unknown remove to be ignored, got %d", reloads)
}
}
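
// TestHandleEventRemoveKnownFileDeletes asserts that removing a known auth
// file triggers a reload and evicts its cached hash.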
func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
authFile := filepath.Join(authDir, "known.json")
var reloads int32
w := &Watcher{
authDir: authDir,
configPath: configPath,
lastAuthHashes: make(map[string]string),
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected known remove to trigger reload, got %d", reloads)
}
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected known auth hash to be deleted")
}
}
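
// TestNormalizeAuthPathAndDebounceCleanup covers normalizeAuthPath trimming
// and cleaning, and checks that shouldDebounceRemove prunes stale entries
// from lastRemoveTimes.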
func TestNormalizeAuthPathAndDebounceCleanup(t *testing.T) {
w := &Watcher{}
if got := w.normalizeAuthPath(" "); got != "" {
t.Fatalf("expected empty normalize result, got %q", got)
}
if got := w.normalizeAuthPath(" a/../b "); got != filepath.Clean("a/../b") {
t.Fatalf("unexpected normalize result: %q", got)
}
w.clientsMutex.Lock()
w.lastRemoveTimes = make(map[string]time.Time, 140)
old := time.Now().Add(-3 * authRemoveDebounceWindow)
for i := 0; i < 129; i++ {
w.lastRemoveTimes[fmt.Sprintf("old-%d", i)] = old
}
w.clientsMutex.Unlock()
w.shouldDebounceRemove("new-path", time.Now())
w.clientsMutex.Lock()
gotLen := len(w.lastRemoveTimes)
w.clientsMutex.Unlock()
if gotLen >= 129 {
t.Fatalf("expected debounce cleanup to shrink map, got %d", gotLen)
}
}
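
// TestRefreshAuthStateDispatchesRuntimeAuths asserts that refreshAuthState
// skips nil runtime auths and enqueues an add update for the valid one.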
func TestRefreshAuthStateDispatchesRuntimeAuths(t *testing.T) {
queue := make(chan AuthUpdate, 8)
w := &Watcher{
authDir: t.TempDir(),
lastAuthHashes: make(map[string]string),
}
w.SetConfig(&config.Config{AuthDir: w.authDir})
w.SetAuthUpdateQueue(queue)
defer w.stopDispatch()
w.clientsMutex.Lock()
w.runtimeAuths = map[string]*coreauth.Auth{
"nil": nil,
"r1": {ID: "r1", Provider: "runtime"},
}
w.clientsMutex.Unlock()
w.refreshAuthState(false)
select {
case u := <-queue:
if u.Action != AuthUpdateActionAdd || u.ID != "r1" {
t.Fatalf("unexpected auth update: %+v", u)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for runtime auth update")
}
}
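
// TestAddOrUpdateClientEdgeCases asserts that missing files, empty files, and
// a nil config neither trigger reloads nor populate the hash cache.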
func TestAddOrUpdateClientEdgeCases(t *testing.T) {
tmpDir := t.TempDir()
authDir := tmpDir
authFile := filepath.Join(tmpDir, "edge.json")
if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
emptyFile := filepath.Join(tmpDir, "empty.json")
if err := os.WriteFile(emptyFile, []byte(""), 0o644); err != nil {
t.Fatalf("failed to write empty auth file: %v", err)
}
var reloads int32
w := &Watcher{
authDir: authDir,
lastAuthHashes: make(map[string]string),
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.addOrUpdateClient(filepath.Join(tmpDir, "missing.json"))
w.addOrUpdateClient(emptyFile)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected no reloads for missing/empty file, got %d", reloads)
}
w.addOrUpdateClient(authFile) // config nil -> should not panic or update
if len(w.lastAuthHashes) != 0 {
t.Fatalf("expected no hash entries without config, got %d", len(w.lastAuthHashes))
}
}
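
// TestLoadFileClientsWalkError asserts that loadFileClients reports zero
// clients when the auth directory walk hits an unreadable entry.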
func TestLoadFileClientsWalkError(t *testing.T) {
tmpDir := t.TempDir()
noAccessDir := filepath.Join(tmpDir, "0noaccess")
if err := os.MkdirAll(noAccessDir, 0o755); err != nil {
t.Fatalf("failed to create noaccess dir: %v", err)
}
if err := os.Chmod(noAccessDir, 0); err != nil {
t.Skipf("chmod not supported: %v", err)
}
defer func() { _ = os.Chmod(noAccessDir, 0o755) }()
cfg := &config.Config{AuthDir: tmpDir}
w := &Watcher{}
w.SetConfig(cfg)
count := w.loadFileClients(cfg)
if count != 0 {
t.Fatalf("expected count 0 due to walk error, got %d", count)
}
}
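
// TestReloadConfigIfChangedHandlesMissingAndEmpty asserts that a missing or
// empty config file is tolerated and returns early without reloading.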
func TestReloadConfigIfChangedHandlesMissingAndEmpty(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
w := &Watcher{
configPath: filepath.Join(tmpDir, "missing.yaml"),
authDir: authDir,
}
w.reloadConfigIfChanged() // missing file -> log + return
emptyPath := filepath.Join(tmpDir, "empty.yaml")
if err := os.WriteFile(emptyPath, []byte(""), 0o644); err != nil {
t.Fatalf("failed to write empty config: %v", err)
}
w.configPath = emptyPath
w.reloadConfigIfChanged() // empty file -> early return
}
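
// TestReloadConfigUsesMirroredAuthDir asserts that reloadConfig overrides the
// configured auth_dir with mirroredAuthDir when one is set.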
func TestReloadConfigUsesMirroredAuthDir(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+filepath.Join(tmpDir, "other")+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
w := &Watcher{
configPath: configPath,
authDir: authDir,
mirroredAuthDir: authDir,
lastAuthHashes: make(map[string]string),
}
w.SetConfig(&config.Config{AuthDir: authDir})
if ok := w.reloadConfig(); !ok {
t.Fatal("expected reloadConfig to succeed")
}
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
if w.config == nil || w.config.AuthDir != authDir {
t.Fatalf("expected AuthDir to be overridden by mirroredAuthDir %s, got %+v", authDir, w.config)
}
}
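
// TestReloadConfigFiltersAffectedOAuthProviders asserts that changing
// oauth-excluded-models drops cached auths for the affected provider while
// auths for unaffected providers survive the reload.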
func TestReloadConfigFiltersAffectedOAuthProviders(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
// Ensure SnapshotCoreAuths yields a provider that is NOT affected, so we can assert it survives.
if err := os.WriteFile(filepath.Join(authDir, "provider-b.json"), []byte(`{"type":"provider-b","email":"b@example.com"}`), 0o644); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
oldCfg := &config.Config{
AuthDir: authDir,
OAuthExcludedModels: map[string][]string{
"provider-a": {"m1"},
},
}
newCfg := &config.Config{
AuthDir: authDir,
OAuthExcludedModels: map[string][]string{
"provider-a": {"m2"},
},
}
data, err := yaml.Marshal(newCfg)
if err != nil {
t.Fatalf("failed to marshal config: %v", err)
}
if err = os.WriteFile(configPath, data, 0o644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
w := &Watcher{
configPath: configPath,
authDir: authDir,
lastAuthHashes: make(map[string]string),
currentAuths: map[string]*coreauth.Auth{
"a": {ID: "a", Provider: "provider-a"},
},
}
w.SetConfig(oldCfg)
if ok := w.reloadConfig(); !ok {
t.Fatal("expected reloadConfig to succeed")
}
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
for _, auth := range w.currentAuths {
if auth != nil && auth.Provider == "provider-a" {
t.Fatal("expected affected provider auth to be filtered")
}
}
foundB := false
for _, auth := range w.currentAuths {
if auth != nil && auth.Provider == "provider-b" {
foundB = true
break
}
}
if !foundB {
t.Fatal("expected unaffected provider auth to remain")
}
}
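
// TestStartFailsWhenAuthDirMissing asserts that Start returns an error when
// the auth directory does not exist.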
func TestStartFailsWhenAuthDirMissing(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+filepath.Join(tmpDir, "missing-auth")+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
authDir := filepath.Join(tmpDir, "missing-auth")
w, err := NewWatcher(configPath, authDir, nil)
if err != nil {
t.Fatalf("failed to create watcher: %v", err)
}
defer w.Stop()
w.SetConfig(&config.Config{AuthDir: authDir})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := w.Start(ctx); err == nil {
t.Fatal("expected Start to fail for missing auth dir")
}
}
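
// TestDispatchRuntimeAuthUpdateReturnsFalseWithoutQueue asserts that add and
// delete dispatches fail fast when no auth update queue is configured.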
func TestDispatchRuntimeAuthUpdateReturnsFalseWithoutQueue(t *testing.T) {
w := &Watcher{}
if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionAdd, Auth: &coreauth.Auth{ID: "a"}}); ok {
t.Fatal("expected DispatchRuntimeAuthUpdate to return false when no queue configured")
}
if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionDelete, Auth: &coreauth.Auth{ID: "a"}}); ok {
t.Fatal("expected DispatchRuntimeAuthUpdate delete to return false when no queue configured")
}
}
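
// TestNormalizeAuthNil asserts that normalizeAuth passes nil through unchanged.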
func TestNormalizeAuthNil(t *testing.T) {
if normalizeAuth(nil) != nil {
t.Fatal("expected normalizeAuth(nil) to return nil")
}
}
// stubStore implements coreauth.Store plus watcher-specific persistence helpers.
type stubStore struct {
authDir string
cfgPersisted int32
authPersisted int32
lastAuthMessage string
lastAuthPaths []string
}
func (s *stubStore) List(context.Context) ([]*coreauth.Auth, error) { return nil, nil }
func (s *stubStore) Save(context.Context, *coreauth.Auth) (string, error) {
return "", nil
}
func (s *stubStore) Delete(context.Context, string) error { return nil }
func (s *stubStore) PersistConfig(context.Context) error {
atomic.AddInt32(&s.cfgPersisted, 1)
return nil
}
func (s *stubStore) PersistAuthFiles(_ context.Context, message string, paths ...string) error {
atomic.AddInt32(&s.authPersisted, 1)
s.lastAuthMessage = message
s.lastAuthPaths = paths
return nil
}
func (s *stubStore) AuthDir() string { return s.authDir }
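
// TestNewWatcherDetectsPersisterAndAuthDir asserts that NewWatcher adopts the
// registered token store as its persister and mirrors the store's auth directory.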
func TestNewWatcherDetectsPersisterAndAuthDir(t *testing.T) {
tmp := t.TempDir()
store := &stubStore{authDir: tmp}
orig := sdkAuth.GetTokenStore()
sdkAuth.RegisterTokenStore(store)
defer sdkAuth.RegisterTokenStore(orig)
w, err := NewWatcher("config.yaml", "auth", nil)
if err != nil {
t.Fatalf("NewWatcher failed: %v", err)
}
if w.storePersister == nil {
t.Fatal("expected storePersister to be set from token store")
}
if w.mirroredAuthDir != tmp {
t.Fatalf("expected mirroredAuthDir %s, got %s", tmp, w.mirroredAuthDir)
}
}
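
// TestPersistConfigAndAuthAsyncInvokePersister asserts that the async
// persistence helpers call the store exactly once each, trimming whitespace
// and dropping empty auth paths.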
func TestPersistConfigAndAuthAsyncInvokePersister(t *testing.T) {
w := &Watcher{
storePersister: &stubStore{},
}
w.persistConfigAsync()
w.persistAuthAsync("msg", " a ", "", "b ")
time.Sleep(30 * time.Millisecond)
store := w.storePersister.(*stubStore)
if atomic.LoadInt32(&store.cfgPersisted) != 1 {
t.Fatalf("expected PersistConfig to be called once, got %d", store.cfgPersisted)
}
if atomic.LoadInt32(&store.authPersisted) != 1 {
t.Fatalf("expected PersistAuthFiles to be called once, got %d", store.authPersisted)
}
if store.lastAuthMessage != "msg" {
t.Fatalf("unexpected auth message: %s", store.lastAuthMessage)
}
if len(store.lastAuthPaths) != 2 || store.lastAuthPaths[0] != "a" || store.lastAuthPaths[1] != "b" {
t.Fatalf("unexpected filtered paths: %#v", store.lastAuthPaths)
}
}
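
// TestScheduleConfigReloadDebounces asserts that back-to-back
// scheduleConfigReload calls collapse into a single reload and record the
// config hash.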
func TestScheduleConfigReloadDebounces(t *testing.T) {
tmp := t.TempDir()
authDir := tmp
cfgPath := filepath.Join(tmp, "config.yaml")
if err := os.WriteFile(cfgPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
var reloads int32
w := &Watcher{
configPath: cfgPath,
authDir: authDir,
reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
}
w.SetConfig(&config.Config{AuthDir: authDir})
w.scheduleConfigReload()
w.scheduleConfigReload()
time.Sleep(400 * time.Millisecond)
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected single debounced reload, got %d", reloads)
}
if w.lastConfigHash == "" {
t.Fatal("expected lastConfigHash to be set after reload")
}
}
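
// TestPrepareAuthUpdatesLockedForceAndDelete covers the modify, forced
// modify, and delete updates produced by prepareAuthUpdatesLocked.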
func TestPrepareAuthUpdatesLockedForceAndDelete(t *testing.T) {
w := &Watcher{
currentAuths: map[string]*coreauth.Auth{
"a": {ID: "a", Provider: "p1"},
},
authQueue: make(chan AuthUpdate, 4),
}
updates := w.prepareAuthUpdatesLocked([]*coreauth.Auth{{ID: "a", Provider: "p2"}}, false)
if len(updates) != 1 || updates[0].Action != AuthUpdateActionModify || updates[0].ID != "a" {
t.Fatalf("unexpected modify updates: %+v", updates)
}
updates = w.prepareAuthUpdatesLocked([]*coreauth.Auth{{ID: "a", Provider: "p2"}}, true)
if len(updates) != 1 || updates[0].Action != AuthUpdateActionModify {
t.Fatalf("expected force modify, got %+v", updates)
}
updates = w.prepareAuthUpdatesLocked([]*coreauth.Auth{}, false)
if len(updates) != 1 || updates[0].Action != AuthUpdateActionDelete || updates[0].ID != "a" {
t.Fatalf("expected delete for missing auth, got %+v", updates)
}
}
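
// TestAuthEqualIgnoresTemporalFields asserts that authEqual treats auths
// differing only in CreatedAt as equal.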
func TestAuthEqualIgnoresTemporalFields(t *testing.T) {
now := time.Now()
a := &coreauth.Auth{ID: "x", CreatedAt: now}
b := &coreauth.Auth{ID: "x", CreatedAt: now.Add(5 * time.Second)}
if !authEqual(a, b) {
t.Fatal("expected authEqual to ignore temporal differences")
}
}
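
// TestDispatchLoopExitsWhenQueueNilAndContextCanceled asserts that
// dispatchLoop exits on context cancellation even with pending updates and
// no queue to drain them into.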
func TestDispatchLoopExitsWhenQueueNilAndContextCanceled(t *testing.T) {
w := &Watcher{
pendingUpdates: map[string]AuthUpdate{"k": {ID: "k"}},
pendingOrder: []string{"k"},
}
w.dispatchCond = sync.NewCond(&w.dispatchMu)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
w.dispatchLoop(ctx)
close(done)
}()
time.Sleep(20 * time.Millisecond)
cancel()
w.dispatchMu.Lock()
w.dispatchCond.Broadcast()
w.dispatchMu.Unlock()
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("dispatchLoop did not exit after context cancel")
}
}
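
// TestReloadClientsFiltersOAuthProvidersWithoutRescan asserts that
// reloadClients drops auths for affected providers case-insensitively while
// retaining the existing hash cache when the auth rescan is skipped.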
func TestReloadClientsFiltersOAuthProvidersWithoutRescan(t *testing.T) {
tmp := t.TempDir()
w := &Watcher{
authDir: tmp,
config: &config.Config{AuthDir: tmp},
currentAuths: map[string]*coreauth.Auth{
"a": {ID: "a", Provider: "Match"},
"b": {ID: "b", Provider: "other"},
},
lastAuthHashes: map[string]string{"cached": "hash"},
}
w.reloadClients(false, []string{"match"}, false)
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
if _, ok := w.currentAuths["a"]; ok {
t.Fatal("expected filtered provider to be removed")
}
if len(w.lastAuthHashes) != 1 {
t.Fatalf("expected existing hash cache to be retained, got %d", len(w.lastAuthHashes))
}
}
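
// TestScheduleProcessEventsStopsOnContextDone asserts that processEvents
// returns promptly once its context is canceled.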
func TestScheduleProcessEventsStopsOnContextDone(t *testing.T) {
w := &Watcher{
watcher: &fsnotify.Watcher{
Events: make(chan fsnotify.Event, 1),
Errors: make(chan error, 1),
},
configPath: "config.yaml",
authDir: "auth",
}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
w.processEvents(ctx)
close(done)
}()
cancel()
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("processEvents did not exit on context cancel")
}
}
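
// hexString returns the lowercase hex encoding of data, matching the hash
// format the watcher stores in lastAuthHashes.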
func hexString(data []byte) string {
return strings.ToLower(fmt.Sprintf("%x", data))
}