diff --git a/internal/watcher/clients.go b/internal/watcher/clients.go new file mode 100644 index 00000000..5cd8b6e6 --- /dev/null +++ b/internal/watcher/clients.go @@ -0,0 +1,270 @@ +// clients.go implements watcher client lifecycle logic and persistence helpers. +// It reloads clients, handles incremental auth file changes, and persists updates when supported. +package watcher + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) { + log.Debugf("starting full client load process") + + w.clientsMutex.RLock() + cfg := w.config + w.clientsMutex.RUnlock() + + if cfg == nil { + log.Error("config is nil, cannot reload clients") + return + } + + if len(affectedOAuthProviders) > 0 { + w.clientsMutex.Lock() + if w.currentAuths != nil { + filtered := make(map[string]*coreauth.Auth, len(w.currentAuths)) + for id, auth := range w.currentAuths { + if auth == nil { + continue + } + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if _, match := matchProvider(provider, affectedOAuthProviders); match { + continue + } + filtered[id] = auth + } + w.currentAuths = filtered + log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders) + } else { + w.currentAuths = nil + } + w.clientsMutex.Unlock() + } + + geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) + totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + log.Debugf("loaded %d API key clients", totalAPIKeyClients) + + var authFileCount int + if rescanAuth { + authFileCount = w.loadFileClients(cfg) + log.Debugf("loaded %d file-based clients", authFileCount) + } else { + w.clientsMutex.RLock() + authFileCount = len(w.lastAuthHashes) + w.clientsMutex.RUnlock() + log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount) + } + + if rescanAuth { + w.clientsMutex.Lock() + + w.lastAuthHashes = make(map[string]string) + if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil { + log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir) + } else if resolvedAuthDir != "" { + _ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return nil + } + if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") { + if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 { + sum := sha256.Sum256(data) + normalizedPath := w.normalizeAuthPath(path) + w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:]) + } + } + return nil + }) + } + w.clientsMutex.Unlock() + } + + totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + + if w.reloadCallback != nil { + log.Debugf("triggering server update callback before auth refresh") + w.reloadCallback(cfg) + } + + w.refreshAuthState(forceAuthRefresh) + + log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", + totalNewClients, + authFileCount, + geminiAPIKeyCount, + vertexCompatAPIKeyCount, + claudeAPIKeyCount, + codexAPIKeyCount, + openAICompatCount, + ) +} + +func (w *Watcher) addOrUpdateClient(path string) { + data, errRead := os.ReadFile(path) + if errRead != nil { + log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead) + return + } + if len(data) == 0 { + log.Debugf("ignoring empty auth file: %s", filepath.Base(path)) + return + } + + sum := sha256.Sum256(data) + curHash := hex.EncodeToString(sum[:]) + normalized := w.normalizeAuthPath(path) + + w.clientsMutex.Lock() + + cfg := w.config + if cfg == nil { + log.Error("config is nil, cannot add or update client") + w.clientsMutex.Unlock() + return + } + if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash { + log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path)) + w.clientsMutex.Unlock() + return + } + + w.lastAuthHashes[normalized] = curHash + + w.clientsMutex.Unlock() // Unlock before the callback + + w.refreshAuthState(false) + + if w.reloadCallback != nil { + log.Debugf("triggering server update callback after add/update") + w.reloadCallback(cfg) + } + w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path) +} + +func (w *Watcher) removeClient(path string) { + normalized := w.normalizeAuthPath(path) + w.clientsMutex.Lock() + + cfg := w.config + delete(w.lastAuthHashes, normalized) + + w.clientsMutex.Unlock() // Release the lock before the callback + + w.refreshAuthState(false) + + if w.reloadCallback != nil { + log.Debugf("triggering server update callback after removal") + w.reloadCallback(cfg) + } + w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path) +} + +func (w *Watcher) loadFileClients(cfg *config.Config) int { + authFileCount := 0 + successfulAuthCount := 0 + + authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir) + if errResolveAuthDir != nil { + log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir) + return 0 + } + if authDir == "" { + return 0 + } + + errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error { + if err != nil { + log.Debugf("error accessing path %s: %v", path, err) + return err + } + if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") { + authFileCount++ + log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path)) + if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 { + successfulAuthCount++ + } + } + return nil + }) + + if errWalk != nil { + log.Errorf("error walking auth directory: %v", errWalk) + } + log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount) + return authFileCount +} + +func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) { + geminiAPIKeyCount := 0 + vertexCompatAPIKeyCount := 0 + claudeAPIKeyCount := 0 + codexAPIKeyCount := 0 + openAICompatCount := 0 + + if len(cfg.GeminiKey) > 0 { + geminiAPIKeyCount += len(cfg.GeminiKey) + } + if len(cfg.VertexCompatAPIKey) > 0 { + vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey) + } + if len(cfg.ClaudeKey) > 0 { + claudeAPIKeyCount += len(cfg.ClaudeKey) + } + if len(cfg.CodexKey) > 0 { + codexAPIKeyCount += len(cfg.CodexKey) + } + if len(cfg.OpenAICompatibility) > 0 { + for _, compatConfig := range cfg.OpenAICompatibility { + openAICompatCount += len(compatConfig.APIKeyEntries) + } + } + return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount +} + +func (w *Watcher) persistConfigAsync() { + if w == nil || w.storePersister == nil { + return + } + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := w.storePersister.PersistConfig(ctx); err != nil { + log.Errorf("failed to persist config change: %v", err) + } + }() +} + +func (w *Watcher) persistAuthAsync(message string, paths ...string) { + if w == nil || w.storePersister == nil { + return + } + filtered := make([]string, 0, len(paths)) + for _, p := range paths { + if trimmed := strings.TrimSpace(p); trimmed != "" { + filtered = append(filtered, trimmed) + } + } + if len(filtered) == 0 { + return + } + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := w.storePersister.PersistAuthFiles(ctx, message, filtered...); err != nil { + log.Errorf("failed to persist auth changes: %v", err) + } + }() +} diff --git a/internal/watcher/config_reload.go b/internal/watcher/config_reload.go new file mode 100644 index 00000000..244f738e --- /dev/null +++ b/internal/watcher/config_reload.go @@ -0,0 +1,134 @@ +// config_reload.go implements debounced configuration hot reload. +// It detects material changes and reloads clients when the config changes. +package watcher + +import ( + "crypto/sha256" + "encoding/hex" + "os" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" + "gopkg.in/yaml.v3" + + log "github.com/sirupsen/logrus" +) + +func (w *Watcher) stopConfigReloadTimer() { + w.configReloadMu.Lock() + if w.configReloadTimer != nil { + w.configReloadTimer.Stop() + w.configReloadTimer = nil + } + w.configReloadMu.Unlock() +} + +func (w *Watcher) scheduleConfigReload() { + w.configReloadMu.Lock() + defer w.configReloadMu.Unlock() + if w.configReloadTimer != nil { + w.configReloadTimer.Stop() + } + w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() { + w.configReloadMu.Lock() + w.configReloadTimer = nil + w.configReloadMu.Unlock() + w.reloadConfigIfChanged() + }) +} + +func (w *Watcher) reloadConfigIfChanged() { + data, err := os.ReadFile(w.configPath) + if err != nil { + log.Errorf("failed to read config file for hash check: %v", err) + return + } + if len(data) == 0 { + log.Debugf("ignoring empty config file write event") + return + } + sum := sha256.Sum256(data) + newHash := hex.EncodeToString(sum[:]) + + w.clientsMutex.RLock() + currentHash := w.lastConfigHash + w.clientsMutex.RUnlock() + + if currentHash != "" && currentHash == newHash { + log.Debugf("config file content unchanged (hash match), skipping reload") + return + } + log.Infof("config file changed, reloading: %s", w.configPath) + if w.reloadConfig() { + finalHash := newHash + if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 { + sumUpdated := sha256.Sum256(updatedData) + finalHash = hex.EncodeToString(sumUpdated[:]) + } else if errRead != nil { + log.WithError(errRead).Debug("failed to compute updated config hash after reload") + } + w.clientsMutex.Lock() + w.lastConfigHash = finalHash + w.clientsMutex.Unlock() + w.persistConfigAsync() + } +} + +func (w *Watcher) reloadConfig() bool { + log.Debug("=========================== CONFIG RELOAD ============================") + log.Debugf("starting config reload from: %s", w.configPath) + + newConfig, errLoadConfig := config.LoadConfig(w.configPath) + if errLoadConfig != nil { + log.Errorf("failed to reload config: %v", errLoadConfig) + return false + } + + if w.mirroredAuthDir != "" { + newConfig.AuthDir = w.mirroredAuthDir + } else { + if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil { + log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir) + } else { + newConfig.AuthDir = resolvedAuthDir + } + } + + w.clientsMutex.Lock() + var oldConfig *config.Config + _ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig) + w.oldConfigYaml, _ = yaml.Marshal(newConfig) + w.config = newConfig + w.clientsMutex.Unlock() + + var affectedOAuthProviders []string + if oldConfig != nil { + _, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels) + } + + util.SetLogLevel(newConfig) + if oldConfig != nil && oldConfig.Debug != newConfig.Debug { + log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug) + } + + if oldConfig != nil { + details := diff.BuildConfigChangeDetails(oldConfig, newConfig) + if len(details) > 0 { + log.Debugf("config changes detected:") + for _, d := range details { + log.Debugf(" %s", d) + } + } else { + log.Debugf("no material config field changes detected") + } + } + + authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir + forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix + + log.Infof("config successfully reloaded, triggering client reload") + w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh) + return true +} diff --git a/internal/watcher/dispatcher.go b/internal/watcher/dispatcher.go new file mode 100644 index 00000000..ff3c5b63 --- /dev/null +++ b/internal/watcher/dispatcher.go @@ -0,0 +1,273 @@ +// dispatcher.go implements auth update dispatching and queue management. +// It batches, deduplicates, and delivers auth updates to registered consumers. +package watcher + +import ( + "context" + "fmt" + "reflect" + "sync" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) { + w.clientsMutex.Lock() + defer w.clientsMutex.Unlock() + w.authQueue = queue + if w.dispatchCond == nil { + w.dispatchCond = sync.NewCond(&w.dispatchMu) + } + if w.dispatchCancel != nil { + w.dispatchCancel() + if w.dispatchCond != nil { + w.dispatchMu.Lock() + w.dispatchCond.Broadcast() + w.dispatchMu.Unlock() + } + w.dispatchCancel = nil + } + if queue != nil { + ctx, cancel := context.WithCancel(context.Background()) + w.dispatchCancel = cancel + go w.dispatchLoop(ctx) + } +} + +func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool { + if w == nil { + return false + } + w.clientsMutex.Lock() + if w.runtimeAuths == nil { + w.runtimeAuths = make(map[string]*coreauth.Auth) + } + switch update.Action { + case AuthUpdateActionAdd, AuthUpdateActionModify: + if update.Auth != nil && update.Auth.ID != "" { + clone := update.Auth.Clone() + w.runtimeAuths[clone.ID] = clone + if w.currentAuths == nil { + w.currentAuths = make(map[string]*coreauth.Auth) + } + w.currentAuths[clone.ID] = clone.Clone() + } + case AuthUpdateActionDelete: + id := update.ID + if id == "" && update.Auth != nil { + id = update.Auth.ID + } + if id != "" { + delete(w.runtimeAuths, id) + if w.currentAuths != nil { + delete(w.currentAuths, id) + } + } + } + w.clientsMutex.Unlock() + if w.getAuthQueue() == nil { + return false + } + w.dispatchAuthUpdates([]AuthUpdate{update}) + return true +} + +func (w *Watcher) refreshAuthState(force bool) { + auths := w.SnapshotCoreAuths() + w.clientsMutex.Lock() + if len(w.runtimeAuths) > 0 { + for _, a := range w.runtimeAuths { + if a != nil { + auths = append(auths, a.Clone()) + } + } + } + updates := w.prepareAuthUpdatesLocked(auths, force) + w.clientsMutex.Unlock() + w.dispatchAuthUpdates(updates) +} + +func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate { + newState := make(map[string]*coreauth.Auth, len(auths)) + for _, auth := range auths { + if auth == nil || auth.ID == "" { + continue + } + newState[auth.ID] = auth.Clone() + } + if w.currentAuths == nil { + w.currentAuths = newState + if w.authQueue == nil { + return nil + } + updates := make([]AuthUpdate, 0, len(newState)) + for id, auth := range newState { + updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()}) + } + return updates + } + if w.authQueue == nil { + w.currentAuths = newState + return nil + } + updates := make([]AuthUpdate, 0, len(newState)+len(w.currentAuths)) + for id, auth := range newState { + if existing, ok := w.currentAuths[id]; !ok { + updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()}) + } else if force || !authEqual(existing, auth) { + updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()}) + } + } + for id := range w.currentAuths { + if _, ok := newState[id]; !ok { + updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id}) + } + } + w.currentAuths = newState + return updates +} + +func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) { + if len(updates) == 0 { + return + } + queue := w.getAuthQueue() + if queue == nil { + return + } + baseTS := time.Now().UnixNano() + w.dispatchMu.Lock() + if w.pendingUpdates == nil { + w.pendingUpdates = make(map[string]AuthUpdate) + } + for idx, update := range updates { + key := w.authUpdateKey(update, baseTS+int64(idx)) + if _, exists := w.pendingUpdates[key]; !exists { + w.pendingOrder = append(w.pendingOrder, key) + } + w.pendingUpdates[key] = update + } + if w.dispatchCond != nil { + w.dispatchCond.Signal() + } + w.dispatchMu.Unlock() +} + +func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string { + if update.ID != "" { + return update.ID + } + return fmt.Sprintf("%s:%d", update.Action, ts) +} + +func (w *Watcher) dispatchLoop(ctx context.Context) { + for { + batch, ok := w.nextPendingBatch(ctx) + if !ok { + return + } + queue := w.getAuthQueue() + if queue == nil { + if ctx.Err() != nil { + return + } + time.Sleep(10 * time.Millisecond) + continue + } + for _, update := range batch { + select { + case queue <- update: + case <-ctx.Done(): + return + } + } + } +} + +func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) { + w.dispatchMu.Lock() + defer w.dispatchMu.Unlock() + for len(w.pendingOrder) == 0 { + if ctx.Err() != nil { + return nil, false + } + w.dispatchCond.Wait() + if ctx.Err() != nil { + return nil, false + } + } + batch := make([]AuthUpdate, 0, len(w.pendingOrder)) + for _, key := range w.pendingOrder { + batch = append(batch, w.pendingUpdates[key]) + delete(w.pendingUpdates, key) + } + w.pendingOrder = w.pendingOrder[:0] + return batch, true +} + +func (w *Watcher) getAuthQueue() chan<- AuthUpdate { + w.clientsMutex.RLock() + defer w.clientsMutex.RUnlock() + return w.authQueue +} + +func (w *Watcher) stopDispatch() { + if w.dispatchCancel != nil { + w.dispatchCancel() + w.dispatchCancel = nil + } + w.dispatchMu.Lock() + w.pendingOrder = nil + w.pendingUpdates = nil + if w.dispatchCond != nil { + w.dispatchCond.Broadcast() + } + w.dispatchMu.Unlock() + w.clientsMutex.Lock() + w.authQueue = nil + w.clientsMutex.Unlock() +} + +func authEqual(a, b *coreauth.Auth) bool { + return reflect.DeepEqual(normalizeAuth(a), normalizeAuth(b)) +} + +func normalizeAuth(a *coreauth.Auth) *coreauth.Auth { + if a == nil { + return nil + } + clone := a.Clone() + clone.CreatedAt = time.Time{} + clone.UpdatedAt = time.Time{} + clone.LastRefreshedAt = time.Time{} + clone.NextRefreshAfter = time.Time{} + clone.Runtime = nil + clone.Quota.NextRecoverAt = time.Time{} + return clone +} + +func snapshotCoreAuths(cfg *config.Config, authDir string) []*coreauth.Auth { + ctx := &synthesizer.SynthesisContext{ + Config: cfg, + AuthDir: authDir, + Now: time.Now(), + IDGenerator: synthesizer.NewStableIDGenerator(), + } + + var out []*coreauth.Auth + + configSynth := synthesizer.NewConfigSynthesizer() + if auths, err := configSynth.Synthesize(ctx); err == nil { + out = append(out, auths...) + } + + fileSynth := synthesizer.NewFileSynthesizer() + if auths, err := fileSynth.Synthesize(ctx); err == nil { + out = append(out, auths...) + } + + return out +} diff --git a/internal/watcher/events.go b/internal/watcher/events.go new file mode 100644 index 00000000..250cf75c --- /dev/null +++ b/internal/watcher/events.go @@ -0,0 +1,194 @@ +// events.go implements fsnotify event handling for config and auth file changes. +// It normalizes paths, debounces noisy events, and triggers reload/update logic. +package watcher + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/fsnotify/fsnotify" + log "github.com/sirupsen/logrus" +) + +func matchProvider(provider string, targets []string) (string, bool) { + p := strings.ToLower(strings.TrimSpace(provider)) + for _, t := range targets { + if strings.EqualFold(p, strings.TrimSpace(t)) { + return p, true + } + } + return p, false +} + +func (w *Watcher) start(ctx context.Context) error { + if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil { + log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig) + return errAddConfig + } + log.Debugf("watching config file: %s", w.configPath) + + if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil { + log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir) + return errAddAuthDir + } + log.Debugf("watching auth directory: %s", w.authDir) + + go w.processEvents(ctx) + + w.reloadClients(true, nil, false) + return nil +} + +func (w *Watcher) processEvents(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case event, ok := <-w.watcher.Events: + if !ok { + return + } + w.handleEvent(event) + case errWatch, ok := <-w.watcher.Errors: + if !ok { + return + } + log.Errorf("file watcher error: %v", errWatch) + } + } +} + +func (w *Watcher) handleEvent(event fsnotify.Event) { + // Filter only relevant events: config file or auth-dir JSON files. + configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename + normalizedName := w.normalizeAuthPath(event.Name) + normalizedConfigPath := w.normalizeAuthPath(w.configPath) + normalizedAuthDir := w.normalizeAuthPath(w.authDir) + isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0 + authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename + isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0 + if !isConfigEvent && !isAuthJSON { + // Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise. + return + } + + now := time.Now() + log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name) + + // Handle config file changes + if isConfigEvent { + log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000")) + w.scheduleConfigReload() + return + } + + // Handle auth directory changes incrementally (.json only) + if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 { + if w.shouldDebounceRemove(normalizedName, now) { + log.Debugf("debouncing remove event for %s", filepath.Base(event.Name)) + return + } + // Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready. + // Wait briefly; if the path exists again, treat as an update instead of removal. + time.Sleep(replaceCheckDelay) + if _, statErr := os.Stat(event.Name); statErr == nil { + if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged { + log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name)) + return + } + log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) + w.addOrUpdateClient(event.Name) + return + } + if !w.isKnownAuthFile(event.Name) { + log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name)) + return + } + log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) + w.removeClient(event.Name) + return + } + if event.Op&(fsnotify.Create|fsnotify.Write) != 0 { + if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged { + log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name)) + return + } + log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) + w.addOrUpdateClient(event.Name) + } +} + +func (w *Watcher) authFileUnchanged(path string) (bool, error) { + data, errRead := os.ReadFile(path) + if errRead != nil { + return false, errRead + } + if len(data) == 0 { + return false, nil + } + sum := sha256.Sum256(data) + curHash := hex.EncodeToString(sum[:]) + + normalized := w.normalizeAuthPath(path) + w.clientsMutex.RLock() + prevHash, ok := w.lastAuthHashes[normalized] + w.clientsMutex.RUnlock() + if ok && prevHash == curHash { + return true, nil + } + return false, nil +} + +func (w *Watcher) isKnownAuthFile(path string) bool { + normalized := w.normalizeAuthPath(path) + w.clientsMutex.RLock() + defer w.clientsMutex.RUnlock() + _, ok := w.lastAuthHashes[normalized] + return ok +} + +func (w *Watcher) normalizeAuthPath(path string) string { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "" + } + cleaned := filepath.Clean(trimmed) + if runtime.GOOS == "windows" { + cleaned = strings.TrimPrefix(cleaned, `\\?\`) + cleaned = strings.ToLower(cleaned) + } + return cleaned +} + +func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool { + if normalizedPath == "" { + return false + } + w.clientsMutex.Lock() + if w.lastRemoveTimes == nil { + w.lastRemoveTimes = make(map[string]time.Time) + } + if last, ok := w.lastRemoveTimes[normalizedPath]; ok { + if now.Sub(last) < authRemoveDebounceWindow { + w.clientsMutex.Unlock() + return true + } + } + w.lastRemoveTimes[normalizedPath] = now + if len(w.lastRemoveTimes) > 128 { + cutoff := now.Add(-2 * authRemoveDebounceWindow) + for p, t := range w.lastRemoveTimes { + if t.Before(cutoff) { + delete(w.lastRemoveTimes, p) + } + } + } + w.clientsMutex.Unlock() + return false +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 61e5b378..77006cf8 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1,45 +1,22 @@ -// Package watcher provides file system monitoring functionality for the CLI Proxy API. -// It watches configuration files and authentication directories for changes, -// automatically reloading clients and configuration when files are modified. -// The package handles cross-platform file system events and supports hot-reloading. +// Package watcher watches config/auth files and triggers hot reloads. +// It supports cross-platform fsnotify event handling. package watcher import ( "context" - "crypto/sha256" - "encoding/hex" - "fmt" - "io/fs" - "os" - "path/filepath" - "reflect" - "runtime" "strings" "sync" "time" "github.com/fsnotify/fsnotify" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" - "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer" "gopkg.in/yaml.v3" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" ) -func matchProvider(provider string, targets []string) (string, bool) { - p := strings.ToLower(strings.TrimSpace(provider)) - for _, t := range targets { - if strings.EqualFold(p, strings.TrimSpace(t)) { - return p, true - } - } - return p, false -} - // storePersister captures persistence-capable token store methods used by the watcher. type storePersister interface { PersistConfig(ctx context.Context) error @@ -131,26 +108,7 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config)) // Start begins watching the configuration file and authentication directory func (w *Watcher) Start(ctx context.Context) error { - // Watch the config file - if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil { - log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig) - return errAddConfig - } - log.Debugf("watching config file: %s", w.configPath) - - // Watch the auth directory - if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil { - log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir) - return errAddAuthDir - } - log.Debugf("watching auth directory: %s", w.authDir) - - // Start the event processing goroutine - go w.processEvents(ctx) - - // Perform an initial full reload based on current config and auth dir - w.reloadClients(true, nil, false) - return nil + return w.start(ctx) } // Stop stops the file watcher @@ -160,15 +118,6 @@ func (w *Watcher) Stop() error { return w.watcher.Close() } -func (w *Watcher) stopConfigReloadTimer() { - w.configReloadMu.Lock() - if w.configReloadTimer != nil { - w.configReloadTimer.Stop() - w.configReloadTimer = nil - } - w.configReloadMu.Unlock() -} - // SetConfig updates the current configuration func (w *Watcher) SetConfig(cfg *config.Config) { w.clientsMutex.Lock() @@ -179,818 +128,20 @@ func (w *Watcher) SetConfig(cfg *config.Config) { // SetAuthUpdateQueue sets the queue used to emit auth updates. func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) { - w.clientsMutex.Lock() - defer w.clientsMutex.Unlock() - w.authQueue = queue - if w.dispatchCond == nil { - w.dispatchCond = sync.NewCond(&w.dispatchMu) - } - if w.dispatchCancel != nil { - w.dispatchCancel() - if w.dispatchCond != nil { - w.dispatchMu.Lock() - w.dispatchCond.Broadcast() - w.dispatchMu.Unlock() - } - w.dispatchCancel = nil - } - if queue != nil { - ctx, cancel := context.WithCancel(context.Background()) - w.dispatchCancel = cancel - go w.dispatchLoop(ctx) - } + w.setAuthUpdateQueue(queue) } // DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths) // to push auth updates through the same queue used by file/config watchers. // Returns true if the update was enqueued; false if no queue is configured. func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool { - if w == nil { - return false - } - w.clientsMutex.Lock() - if w.runtimeAuths == nil { - w.runtimeAuths = make(map[string]*coreauth.Auth) - } - switch update.Action { - case AuthUpdateActionAdd, AuthUpdateActionModify: - if update.Auth != nil && update.Auth.ID != "" { - clone := update.Auth.Clone() - w.runtimeAuths[clone.ID] = clone - if w.currentAuths == nil { - w.currentAuths = make(map[string]*coreauth.Auth) - } - w.currentAuths[clone.ID] = clone.Clone() - } - case AuthUpdateActionDelete: - id := update.ID - if id == "" && update.Auth != nil { - id = update.Auth.ID - } - if id != "" { - delete(w.runtimeAuths, id) - if w.currentAuths != nil { - delete(w.currentAuths, id) - } - } - } - w.clientsMutex.Unlock() - if w.getAuthQueue() == nil { - return false - } - w.dispatchAuthUpdates([]AuthUpdate{update}) - return true + return w.dispatchRuntimeAuthUpdate(update) } -func (w *Watcher) refreshAuthState(force bool) { - auths := w.SnapshotCoreAuths() - w.clientsMutex.Lock() - if len(w.runtimeAuths) > 0 { - for _, a := range w.runtimeAuths { - if a != nil { - auths = append(auths, a.Clone()) - } - } - } - updates := w.prepareAuthUpdatesLocked(auths, force) - w.clientsMutex.Unlock() - w.dispatchAuthUpdates(updates) -} - -func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate { - newState := make(map[string]*coreauth.Auth, len(auths)) - for _, auth := range auths { - if auth == nil || auth.ID == "" { - continue - } - newState[auth.ID] = auth.Clone() - } - if w.currentAuths == nil { - w.currentAuths = newState - if w.authQueue == nil { - return nil - } - updates := make([]AuthUpdate, 0, len(newState)) - for id, auth := range newState { - updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()}) - } - return updates - } - if w.authQueue == nil { - w.currentAuths = newState - return nil - } - updates := make([]AuthUpdate, 0, len(newState)+len(w.currentAuths)) - for id, auth := range newState { - if existing, ok := w.currentAuths[id]; !ok { - updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()}) - } else if force || !authEqual(existing, auth) { - updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()}) - } - } - for id := range w.currentAuths { - if _, ok := newState[id]; !ok { - updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id}) - } - } - w.currentAuths = newState - return updates -} - -func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) { - if len(updates) == 0 { - return - } - queue := w.getAuthQueue() - if queue == nil { - return - } - baseTS := time.Now().UnixNano() - w.dispatchMu.Lock() - if w.pendingUpdates == nil { - w.pendingUpdates = make(map[string]AuthUpdate) - } - for idx, update := range updates { - key := w.authUpdateKey(update, baseTS+int64(idx)) - if _, exists := w.pendingUpdates[key]; !exists { - w.pendingOrder = append(w.pendingOrder, key) - } - w.pendingUpdates[key] = update - } - if w.dispatchCond != nil { - w.dispatchCond.Signal() - } - w.dispatchMu.Unlock() -} - -func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string { - if update.ID != "" { - return update.ID - } - return fmt.Sprintf("%s:%d", update.Action, ts) -} - -func (w *Watcher) dispatchLoop(ctx context.Context) { - for { - batch, ok := w.nextPendingBatch(ctx) - if !ok { - return - } - queue := w.getAuthQueue() - if queue == nil { - if ctx.Err() != nil { - return - } - time.Sleep(10 * time.Millisecond) - continue - } - for _, update := range batch { - select { - case queue <- update: - case <-ctx.Done(): - return - } - } - } -} - -func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) { - w.dispatchMu.Lock() - defer w.dispatchMu.Unlock() - for len(w.pendingOrder) == 0 { - if ctx.Err() != nil { - return nil, false - } - w.dispatchCond.Wait() - if ctx.Err() != nil { - return nil, false - } - } - batch := make([]AuthUpdate, 0, len(w.pendingOrder)) - for _, key := range w.pendingOrder { - batch = append(batch, w.pendingUpdates[key]) - delete(w.pendingUpdates, key) - } - w.pendingOrder = w.pendingOrder[:0] - return batch, true -} - -func (w *Watcher) getAuthQueue() chan<- AuthUpdate { - w.clientsMutex.RLock() - defer w.clientsMutex.RUnlock() - return w.authQueue -} - -func (w *Watcher) stopDispatch() { - if w.dispatchCancel != nil { - w.dispatchCancel() - w.dispatchCancel = nil - } - w.dispatchMu.Lock() - w.pendingOrder = nil - w.pendingUpdates = nil - if w.dispatchCond != nil { - w.dispatchCond.Broadcast() - } - w.dispatchMu.Unlock() - w.clientsMutex.Lock() - w.authQueue = nil - w.clientsMutex.Unlock() -} - -func (w *Watcher) persistConfigAsync() { - if w == nil || w.storePersister == nil { - return - } - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if err := w.storePersister.PersistConfig(ctx); err != nil { - log.Errorf("failed to persist config change: %v", err) - } - }() -} - -func (w *Watcher) persistAuthAsync(message string, paths ...string) { - if w == nil || w.storePersister == nil { - return - } - filtered := make([]string, 0, len(paths)) - for _, p := range paths { - if trimmed := strings.TrimSpace(p); trimmed != "" { - filtered = append(filtered, trimmed) - } - } - if len(filtered) == 0 { - return - } - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if err := w.storePersister.PersistAuthFiles(ctx, message, filtered...); err != nil { - log.Errorf("failed to persist auth changes: %v", err) - } - }() -} - -func authEqual(a, b *coreauth.Auth) bool { - return reflect.DeepEqual(normalizeAuth(a), normalizeAuth(b)) -} - -func normalizeAuth(a *coreauth.Auth) *coreauth.Auth { - if a == nil { - return nil - } - clone := a.Clone() - clone.CreatedAt = time.Time{} - clone.UpdatedAt = time.Time{} - clone.LastRefreshedAt = time.Time{} - clone.NextRefreshAfter = time.Time{} - clone.Runtime = nil - clone.Quota.NextRecoverAt = time.Time{} - return clone -} - -// SetClients sets the file-based clients. -// SetClients removed -// SetAPIKeyClients removed - -// processEvents handles file system events -func (w *Watcher) processEvents(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case event, ok := <-w.watcher.Events: - if !ok { - return - } - w.handleEvent(event) - case errWatch, ok := <-w.watcher.Errors: - if !ok { - return - } - log.Errorf("file watcher error: %v", errWatch) - } - } -} - -func (w *Watcher) authFileUnchanged(path string) (bool, error) { - data, errRead := os.ReadFile(path) - if errRead != nil { - return false, errRead - } - if len(data) == 0 { - return false, nil - } - sum := sha256.Sum256(data) - curHash := hex.EncodeToString(sum[:]) - - normalized := w.normalizeAuthPath(path) - w.clientsMutex.RLock() - prevHash, ok := w.lastAuthHashes[normalized] - w.clientsMutex.RUnlock() - if ok && prevHash == curHash { - return true, nil - } - return false, nil -} - -func (w *Watcher) isKnownAuthFile(path string) bool { - normalized := w.normalizeAuthPath(path) - w.clientsMutex.RLock() - defer w.clientsMutex.RUnlock() - _, ok := w.lastAuthHashes[normalized] - return ok -} - -func (w *Watcher) normalizeAuthPath(path string) string { - trimmed := strings.TrimSpace(path) - if trimmed == "" { - return "" - } - cleaned := filepath.Clean(trimmed) - if runtime.GOOS == "windows" { - cleaned = strings.TrimPrefix(cleaned, `\\?\`) - cleaned = strings.ToLower(cleaned) - } - return cleaned -} - -func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool { - if normalizedPath == "" { - return false - } - w.clientsMutex.Lock() - if w.lastRemoveTimes == nil { - w.lastRemoveTimes = make(map[string]time.Time) - } - if last, ok := w.lastRemoveTimes[normalizedPath]; ok { - if now.Sub(last) < authRemoveDebounceWindow { - w.clientsMutex.Unlock() - return true - } - } - w.lastRemoveTimes[normalizedPath] = now - if len(w.lastRemoveTimes) > 128 { - cutoff := now.Add(-2 * authRemoveDebounceWindow) - for p, t := range w.lastRemoveTimes { - if t.Before(cutoff) { - delete(w.lastRemoveTimes, p) - } - } - } - w.clientsMutex.Unlock() - return false -} - -// handleEvent processes individual file system events -func (w *Watcher) handleEvent(event fsnotify.Event) { - // Filter only relevant events: config file or auth-dir JSON files. - configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename - normalizedName := w.normalizeAuthPath(event.Name) - normalizedConfigPath := w.normalizeAuthPath(w.configPath) - normalizedAuthDir := w.normalizeAuthPath(w.authDir) - isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0 - authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename - isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0 - if !isConfigEvent && !isAuthJSON { - // Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise. - return - } - - now := time.Now() - log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name) - - // Handle config file changes - if isConfigEvent { - log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000")) - w.scheduleConfigReload() - return - } - - // Handle auth directory changes incrementally (.json only) - if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 { - if w.shouldDebounceRemove(normalizedName, now) { - log.Debugf("debouncing remove event for %s", filepath.Base(event.Name)) - return - } - // Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready. - // Wait briefly; if the path exists again, treat as an update instead of removal. - time.Sleep(replaceCheckDelay) - if _, statErr := os.Stat(event.Name); statErr == nil { - if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged { - log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name)) - return - } - log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) - w.addOrUpdateClient(event.Name) - return - } - if !w.isKnownAuthFile(event.Name) { - log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name)) - return - } - log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) - w.removeClient(event.Name) - return - } - if event.Op&(fsnotify.Create|fsnotify.Write) != 0 { - if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged { - log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name)) - return - } - log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) - w.addOrUpdateClient(event.Name) - } -} - -func (w *Watcher) scheduleConfigReload() { - w.configReloadMu.Lock() - defer w.configReloadMu.Unlock() - if w.configReloadTimer != nil { - w.configReloadTimer.Stop() - } - w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() { - w.configReloadMu.Lock() - w.configReloadTimer = nil - w.configReloadMu.Unlock() - w.reloadConfigIfChanged() - }) -} - -func (w *Watcher) reloadConfigIfChanged() { - data, err := os.ReadFile(w.configPath) - if err != nil { - log.Errorf("failed to read config file for hash check: %v", err) - return - } - if len(data) == 0 { - log.Debugf("ignoring empty config file write event") - return - } - sum := sha256.Sum256(data) - newHash := hex.EncodeToString(sum[:]) - - w.clientsMutex.RLock() - currentHash := w.lastConfigHash - w.clientsMutex.RUnlock() - - if currentHash != "" && currentHash == newHash { - log.Debugf("config file content unchanged (hash match), skipping reload") - return - } - log.Infof("config file changed, reloading: %s", w.configPath) - if w.reloadConfig() { - finalHash := newHash - if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 { - sumUpdated := sha256.Sum256(updatedData) - finalHash = hex.EncodeToString(sumUpdated[:]) - } else if errRead != nil { - log.WithError(errRead).Debug("failed to compute updated config hash after reload") - } - w.clientsMutex.Lock() - w.lastConfigHash = finalHash - w.clientsMutex.Unlock() - w.persistConfigAsync() - } -} - -// reloadConfig reloads the configuration and triggers a full reload -func (w *Watcher) reloadConfig() bool { - log.Debug("=========================== CONFIG RELOAD ============================") - log.Debugf("starting config reload from: %s", w.configPath) - - newConfig, errLoadConfig := config.LoadConfig(w.configPath) - if errLoadConfig != nil { - log.Errorf("failed to reload config: %v", errLoadConfig) - return false - } - - if w.mirroredAuthDir != "" { - newConfig.AuthDir = w.mirroredAuthDir - } else { - if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil { - log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir) - } else { - newConfig.AuthDir = resolvedAuthDir - } - } - - w.clientsMutex.Lock() - var oldConfig *config.Config - _ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig) - w.oldConfigYaml, _ = yaml.Marshal(newConfig) - w.config = newConfig - w.clientsMutex.Unlock() - - var affectedOAuthProviders []string - if oldConfig != nil { - _, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels) - } - - // Always apply the current log level based on the latest config. - // This ensures logrus reflects the desired level even if change detection misses. - util.SetLogLevel(newConfig) - // Additional debug for visibility when the flag actually changes. - if oldConfig != nil && oldConfig.Debug != newConfig.Debug { - log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug) - } - - // Log configuration changes in debug mode, only when there are material diffs - if oldConfig != nil { - details := diff.BuildConfigChangeDetails(oldConfig, newConfig) - if len(details) > 0 { - log.Debugf("config changes detected:") - for _, d := range details { - log.Debugf(" %s", d) - } - } else { - log.Debugf("no material config field changes detected") - } - } - - authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir - forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix - - log.Infof("config successfully reloaded, triggering client reload") - // Reload clients with new config - w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh) - return true -} - -// reloadClients performs a full scan and reload of all clients. -func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) { - log.Debugf("starting full client load process") - - w.clientsMutex.RLock() - cfg := w.config - w.clientsMutex.RUnlock() - - if cfg == nil { - log.Error("config is nil, cannot reload clients") - return - } - - if len(affectedOAuthProviders) > 0 { - w.clientsMutex.Lock() - if w.currentAuths != nil { - filtered := make(map[string]*coreauth.Auth, len(w.currentAuths)) - for id, auth := range w.currentAuths { - if auth == nil { - continue - } - provider := strings.ToLower(strings.TrimSpace(auth.Provider)) - if _, match := matchProvider(provider, affectedOAuthProviders); match { - continue - } - filtered[id] = auth - } - w.currentAuths = filtered - log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders) - } else { - w.currentAuths = nil - } - w.clientsMutex.Unlock() - } - - // Unregister all old API key clients before creating new ones - // no legacy clients to unregister - - // Create new API key clients based on the new config - geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) - totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount - log.Debugf("loaded %d API key clients", totalAPIKeyClients) - - var authFileCount int - if rescanAuth { - // Load file-based clients when explicitly requested (startup or authDir change) - authFileCount = w.loadFileClients(cfg) - log.Debugf("loaded %d file-based clients", authFileCount) - } else { - // Preserve existing auth hashes and only report current known count to avoid redundant scans. - w.clientsMutex.RLock() - authFileCount = len(w.lastAuthHashes) - w.clientsMutex.RUnlock() - log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount) - } - - // no legacy file-based clients to unregister - - // Update client maps - if rescanAuth { - w.clientsMutex.Lock() - - // Rebuild auth file hash cache for current clients - w.lastAuthHashes = make(map[string]string) - if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil { - log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir) - } else if resolvedAuthDir != "" { - _ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error { - if err != nil { - return nil - } - if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") { - if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 { - sum := sha256.Sum256(data) - normalizedPath := w.normalizeAuthPath(path) - w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:]) - } - } - return nil - }) - } - w.clientsMutex.Unlock() - } - - totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount - - // Ensure consumers observe the new configuration before auth updates dispatch. - if w.reloadCallback != nil { - log.Debugf("triggering server update callback before auth refresh") - w.reloadCallback(cfg) - } - - w.refreshAuthState(forceAuthRefresh) - - log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", - totalNewClients, - authFileCount, - geminiAPIKeyCount, - vertexCompatAPIKeyCount, - claudeAPIKeyCount, - codexAPIKeyCount, - openAICompatCount, - ) -} - -// createClientFromFile creates a single client instance from a given token file path. -// createClientFromFile removed (legacy) - -// addOrUpdateClient handles the addition or update of a single client. -func (w *Watcher) addOrUpdateClient(path string) { - data, errRead := os.ReadFile(path) - if errRead != nil { - log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead) - return - } - if len(data) == 0 { - log.Debugf("ignoring empty auth file: %s", filepath.Base(path)) - return - } - - sum := sha256.Sum256(data) - curHash := hex.EncodeToString(sum[:]) - normalized := w.normalizeAuthPath(path) - - w.clientsMutex.Lock() - - cfg := w.config - if cfg == nil { - log.Error("config is nil, cannot add or update client") - w.clientsMutex.Unlock() - return - } - if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash { - log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path)) - w.clientsMutex.Unlock() - return - } - - // Update hash cache - w.lastAuthHashes[normalized] = curHash - - w.clientsMutex.Unlock() // Unlock before the callback - - w.refreshAuthState(false) - - if w.reloadCallback != nil { - log.Debugf("triggering server update callback after add/update") - w.reloadCallback(cfg) - } - w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path) -} - -// removeClient handles the removal of a single client. -func (w *Watcher) removeClient(path string) { - normalized := w.normalizeAuthPath(path) - w.clientsMutex.Lock() - - cfg := w.config - delete(w.lastAuthHashes, normalized) - - w.clientsMutex.Unlock() // Release the lock before the callback - - w.refreshAuthState(false) - - if w.reloadCallback != nil { - log.Debugf("triggering server update callback after removal") - w.reloadCallback(cfg) - } - w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path) -} - -// SnapshotCombinedClients returns a snapshot of current combined clients. -// SnapshotCombinedClients removed - // SnapshotCoreAuths converts current clients snapshot into core auth entries. func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { w.clientsMutex.RLock() cfg := w.config w.clientsMutex.RUnlock() - - ctx := &synthesizer.SynthesisContext{ - Config: cfg, - AuthDir: w.authDir, - Now: time.Now(), - IDGenerator: synthesizer.NewStableIDGenerator(), - } - - var out []*coreauth.Auth - - // Use ConfigSynthesizer for API key auth entries - configSynth := synthesizer.NewConfigSynthesizer() - if auths, err := configSynth.Synthesize(ctx); err == nil { - out = append(out, auths...) - } - - // Use FileSynthesizer for file-based OAuth auth entries - fileSynth := synthesizer.NewFileSynthesizer() - if auths, err := fileSynth.Synthesize(ctx); err == nil { - out = append(out, auths...) - } - - return out -} - -// buildCombinedClientMap merges file-based clients with API key clients from the cache. -// buildCombinedClientMap removed - -// unregisterClientWithReason attempts to call client-specific unregister hooks with context. -// unregisterClientWithReason removed - -// loadFileClients scans the auth directory and creates clients from .json files. -func (w *Watcher) loadFileClients(cfg *config.Config) int { - authFileCount := 0 - successfulAuthCount := 0 - - authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir) - if errResolveAuthDir != nil { - log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir) - return 0 - } - if authDir == "" { - return 0 - } - - errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error { - if err != nil { - log.Debugf("error accessing path %s: %v", path, err) - return err - } - if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") { - authFileCount++ - log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path)) - // Count readable JSON files as successful auth entries - if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 { - successfulAuthCount++ - } - } - return nil - }) - - if errWalk != nil { - log.Errorf("error walking auth directory: %v", errWalk) - } - log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount) - return authFileCount -} - -func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) { - geminiAPIKeyCount := 0 - vertexCompatAPIKeyCount := 0 - claudeAPIKeyCount := 0 - codexAPIKeyCount := 0 - openAICompatCount := 0 - - if len(cfg.GeminiKey) > 0 { - // Stateless executor handles Gemini API keys; avoid constructing legacy clients. - geminiAPIKeyCount += len(cfg.GeminiKey) - } - if len(cfg.VertexCompatAPIKey) > 0 { - vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey) - } - if len(cfg.ClaudeKey) > 0 { - claudeAPIKeyCount += len(cfg.ClaudeKey) - } - if len(cfg.CodexKey) > 0 { - codexAPIKeyCount += len(cfg.CodexKey) - } - if len(cfg.OpenAICompatibility) > 0 { - // Do not construct legacy clients for OpenAI-compat providers; these are handled by the stateless executor. - for _, compatConfig := range cfg.OpenAICompatibility { - openAICompatCount += len(compatConfig.APIKeyEntries) - } - } - return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount + return snapshotCoreAuths(cfg, w.authDir) } diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index 770b5242..29113f59 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "strings" + "sync" "sync/atomic" "testing" "time" @@ -16,6 +17,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "gopkg.in/yaml.v3" ) @@ -489,6 +491,28 @@ func TestAuthFileUnchangedUsesHash(t *testing.T) { } } +func TestAuthFileUnchangedEmptyAndMissing(t *testing.T) { + tmpDir := t.TempDir() + emptyFile := filepath.Join(tmpDir, "empty.json") + if err := os.WriteFile(emptyFile, []byte(""), 0o644); err != nil { + t.Fatalf("failed to write empty auth file: %v", err) + } + + w := &Watcher{lastAuthHashes: make(map[string]string)} + unchanged, err := w.authFileUnchanged(emptyFile) + if err != nil { + t.Fatalf("unexpected error for empty file: %v", err) + } + if unchanged { + t.Fatal("expected empty file to be treated as changed") + } + + _, err = w.authFileUnchanged(filepath.Join(tmpDir, "missing.json")) + if err == nil { + t.Fatal("expected error for missing auth file") + } +} + func TestReloadClientsCachesAuthHashes(t *testing.T) { tmpDir := t.TempDir() authFile := filepath.Join(tmpDir, "one.json") @@ -528,6 +552,23 @@ func TestReloadClientsLogsConfigDiffs(t *testing.T) { w.reloadClients(false, nil, false) } +func TestReloadClientsHandlesNilConfig(t *testing.T) { + w := &Watcher{} + w.reloadClients(true, nil, false) +} + +func TestReloadClientsFiltersProvidersWithNilCurrentAuths(t *testing.T) { + tmp := t.TempDir() + w := &Watcher{ + authDir: tmp, + config: &config.Config{AuthDir: tmp}, + } + w.reloadClients(false, []string{"match"}, false) + if w.currentAuths != nil && len(w.currentAuths) != 0 { + t.Fatalf("expected currentAuths to be nil or empty, got %d", len(w.currentAuths)) + } +} + func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) { w := &Watcher{} queue := make(chan AuthUpdate, 1) @@ -541,6 +582,45 @@ func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) { } } +func TestPersistAsyncEarlyReturns(t *testing.T) { + var nilWatcher *Watcher + nilWatcher.persistConfigAsync() + nilWatcher.persistAuthAsync("msg", "a") + + w := &Watcher{} + w.persistConfigAsync() + w.persistAuthAsync("msg", " ", "") +} + +type errorPersister struct { + configCalls int32 + authCalls int32 +} + +func (p *errorPersister) PersistConfig(context.Context) error { + atomic.AddInt32(&p.configCalls, 1) + return fmt.Errorf("persist config error") +} + +func (p *errorPersister) PersistAuthFiles(context.Context, string, ...string) error { + atomic.AddInt32(&p.authCalls, 1) + return fmt.Errorf("persist auth error") +} + +func TestPersistAsyncErrorPaths(t *testing.T) { + p := &errorPersister{} + w := &Watcher{storePersister: p} + w.persistConfigAsync() + w.persistAuthAsync("msg", "a") + time.Sleep(30 * time.Millisecond) + if atomic.LoadInt32(&p.configCalls) != 1 { + t.Fatalf("expected PersistConfig to be called once, got %d", p.configCalls) + } + if atomic.LoadInt32(&p.authCalls) != 1 { + t.Fatalf("expected PersistAuthFiles to be called once, got %d", p.authCalls) + } +} + func TestStopConfigReloadTimerSafeWhenNil(t *testing.T) { w := &Watcher{} w.stopConfigReloadTimer() @@ -608,6 +688,803 @@ func TestDispatchAuthUpdatesFlushesQueue(t *testing.T) { } } +func TestDispatchLoopExitsOnContextDoneWhileSending(t *testing.T) { + queue := make(chan AuthUpdate) // unbuffered to block sends + w := &Watcher{ + authQueue: queue, + pendingUpdates: map[string]AuthUpdate{ + "k": {Action: AuthUpdateActionAdd, ID: "k"}, + }, + pendingOrder: []string{"k"}, + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + w.dispatchLoop(ctx) + close(done) + }() + + time.Sleep(30 * time.Millisecond) + cancel() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("expected dispatchLoop to exit after ctx canceled while blocked on send") + } +} + +func TestProcessEventsHandlesEventErrorAndChannelClose(t *testing.T) { + w := &Watcher{ + watcher: &fsnotify.Watcher{ + Events: make(chan fsnotify.Event, 2), + Errors: make(chan error, 2), + }, + configPath: "config.yaml", + authDir: "auth", + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + go func() { + w.processEvents(ctx) + close(done) + }() + + w.watcher.Events <- fsnotify.Event{Name: "unrelated.txt", Op: fsnotify.Write} + w.watcher.Errors <- fmt.Errorf("watcher error") + + time.Sleep(20 * time.Millisecond) + close(w.watcher.Events) + close(w.watcher.Errors) + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("processEvents did not exit after channels closed") + } +} + +func TestProcessEventsReturnsWhenErrorsChannelClosed(t *testing.T) { + w := &Watcher{ + watcher: &fsnotify.Watcher{ + Events: nil, + Errors: make(chan error), + }, + } + + close(w.watcher.Errors) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + go func() { + w.processEvents(ctx) + close(done) + }() + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("processEvents did not exit after errors channel closed") + } +} + +func TestHandleEventIgnoresUnrelatedFiles(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + var reloads int32 + w := &Watcher{ + authDir: authDir, + configPath: configPath, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + + w.handleEvent(fsnotify.Event{Name: filepath.Join(tmpDir, "note.txt"), Op: fsnotify.Write}) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected no reloads for unrelated file, got %d", reloads) + } +} + +func TestHandleEventConfigChangeSchedulesReload(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + var reloads int32 + w := &Watcher{ + authDir: authDir, + configPath: configPath, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + + w.handleEvent(fsnotify.Event{Name: configPath, Op: fsnotify.Write}) + + time.Sleep(400 * time.Millisecond) + if atomic.LoadInt32(&reloads) != 1 { + t.Fatalf("expected config change to trigger reload once, got %d", reloads) + } +} + +func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + authFile := filepath.Join(authDir, "a.json") + if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + + var reloads int32 + w := &Watcher{ + authDir: authDir, + configPath: configPath, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + + w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write}) + if atomic.LoadInt32(&reloads) != 1 { + t.Fatalf("expected auth write to trigger reload callback, got %d", reloads) + } +} + +func TestHandleEventRemoveDebounceSkips(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + authFile := filepath.Join(authDir, "remove.json") + + var reloads int32 + w := &Watcher{ + authDir: authDir, + configPath: configPath, + lastAuthHashes: make(map[string]string), + lastRemoveTimes: map[string]time.Time{ + filepath.Clean(authFile): time.Now(), + }, + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + + w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove}) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected remove to be debounced, got %d", reloads) + } +} + +func TestHandleEventAtomicReplaceUnchangedSkips(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + authFile := filepath.Join(authDir, "same.json") + content := []byte(`{"type":"demo"}`) + if err := os.WriteFile(authFile, content, 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + sum := sha256.Sum256(content) + + var reloads int32 + w := &Watcher{ + authDir: authDir, + configPath: configPath, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(sum[:]) + + w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename}) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected unchanged atomic replace to be skipped, got %d", reloads) + } +} + +func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + authFile := filepath.Join(authDir, "change.json") + oldContent := []byte(`{"type":"demo","v":1}`) + newContent := []byte(`{"type":"demo","v":2}`) + if err := os.WriteFile(authFile, newContent, 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + oldSum := sha256.Sum256(oldContent) + + var reloads int32 + w := &Watcher{ + authDir: authDir, + configPath: configPath, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:]) + + w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename}) + if atomic.LoadInt32(&reloads) != 1 { + t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads) + } +} + +func TestHandleEventRemoveUnknownFileIgnored(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + authFile := filepath.Join(authDir, "unknown.json") + + var reloads int32 + w := &Watcher{ + authDir: authDir, + configPath: configPath, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + + w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove}) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected unknown remove to be ignored, got %d", reloads) + } +} + +func TestHandleEventRemoveKnownFileDeletes(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + authFile := filepath.Join(authDir, "known.json") + + var reloads int32 + w := &Watcher{ + authDir: authDir, + configPath: configPath, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash" + + w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove}) + if atomic.LoadInt32(&reloads) != 1 { + t.Fatalf("expected known remove to trigger reload, got %d", reloads) + } + if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok { + t.Fatal("expected known auth hash to be deleted") + } +} + +func TestNormalizeAuthPathAndDebounceCleanup(t *testing.T) { + w := &Watcher{} + if got := w.normalizeAuthPath(" "); got != "" { + t.Fatalf("expected empty normalize result, got %q", got) + } + if got := w.normalizeAuthPath(" a/../b "); got != filepath.Clean("a/../b") { + t.Fatalf("unexpected normalize result: %q", got) + } + + w.clientsMutex.Lock() + w.lastRemoveTimes = make(map[string]time.Time, 140) + old := time.Now().Add(-3 * authRemoveDebounceWindow) + for i := 0; i < 129; i++ { + w.lastRemoveTimes[fmt.Sprintf("old-%d", i)] = old + } + w.clientsMutex.Unlock() + + w.shouldDebounceRemove("new-path", time.Now()) + + w.clientsMutex.Lock() + gotLen := len(w.lastRemoveTimes) + w.clientsMutex.Unlock() + if gotLen >= 129 { + t.Fatalf("expected debounce cleanup to shrink map, got %d", gotLen) + } +} + +func TestRefreshAuthStateDispatchesRuntimeAuths(t *testing.T) { + queue := make(chan AuthUpdate, 8) + w := &Watcher{ + authDir: t.TempDir(), + lastAuthHashes: make(map[string]string), + } + w.SetConfig(&config.Config{AuthDir: w.authDir}) + w.SetAuthUpdateQueue(queue) + defer w.stopDispatch() + + w.clientsMutex.Lock() + w.runtimeAuths = map[string]*coreauth.Auth{ + "nil": nil, + "r1": {ID: "r1", Provider: "runtime"}, + } + w.clientsMutex.Unlock() + + w.refreshAuthState(false) + + select { + case u := <-queue: + if u.Action != AuthUpdateActionAdd || u.ID != "r1" { + t.Fatalf("unexpected auth update: %+v", u) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for runtime auth update") + } +} + +func TestAddOrUpdateClientEdgeCases(t *testing.T) { + tmpDir := t.TempDir() + authDir := tmpDir + authFile := filepath.Join(tmpDir, "edge.json") + if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + emptyFile := filepath.Join(tmpDir, "empty.json") + if err := os.WriteFile(emptyFile, []byte(""), 0o644); err != nil { + t.Fatalf("failed to write empty auth file: %v", err) + } + + var reloads int32 + w := &Watcher{ + authDir: authDir, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + + w.addOrUpdateClient(filepath.Join(tmpDir, "missing.json")) + w.addOrUpdateClient(emptyFile) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected no reloads for missing/empty file, got %d", reloads) + } + + w.addOrUpdateClient(authFile) // config nil -> should not panic or update + if len(w.lastAuthHashes) != 0 { + t.Fatalf("expected no hash entries without config, got %d", len(w.lastAuthHashes)) + } +} + +func TestLoadFileClientsWalkError(t *testing.T) { + tmpDir := t.TempDir() + noAccessDir := filepath.Join(tmpDir, "0noaccess") + if err := os.MkdirAll(noAccessDir, 0o755); err != nil { + t.Fatalf("failed to create noaccess dir: %v", err) + } + if err := os.Chmod(noAccessDir, 0); err != nil { + t.Skipf("chmod not supported: %v", err) + } + defer func() { _ = os.Chmod(noAccessDir, 0o755) }() + + cfg := &config.Config{AuthDir: tmpDir} + w := &Watcher{} + w.SetConfig(cfg) + + count := w.loadFileClients(cfg) + if count != 0 { + t.Fatalf("expected count 0 due to walk error, got %d", count) + } +} + +func TestReloadConfigIfChangedHandlesMissingAndEmpty(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + + w := &Watcher{ + configPath: filepath.Join(tmpDir, "missing.yaml"), + authDir: authDir, + } + w.reloadConfigIfChanged() // missing file -> log + return + + emptyPath := filepath.Join(tmpDir, "empty.yaml") + if err := os.WriteFile(emptyPath, []byte(""), 0o644); err != nil { + t.Fatalf("failed to write empty config: %v", err) + } + w.configPath = emptyPath + w.reloadConfigIfChanged() // empty file -> early return +} + +func TestReloadConfigUsesMirroredAuthDir(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+filepath.Join(tmpDir, "other")+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + w := &Watcher{ + configPath: configPath, + authDir: authDir, + mirroredAuthDir: authDir, + lastAuthHashes: make(map[string]string), + } + w.SetConfig(&config.Config{AuthDir: authDir}) + + if ok := w.reloadConfig(); !ok { + t.Fatal("expected reloadConfig to succeed") + } + + w.clientsMutex.RLock() + defer w.clientsMutex.RUnlock() + if w.config == nil || w.config.AuthDir != authDir { + t.Fatalf("expected AuthDir to be overridden by mirroredAuthDir %s, got %+v", authDir, w.config) + } +} + +func TestReloadConfigFiltersAffectedOAuthProviders(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + + // Ensure SnapshotCoreAuths yields a provider that is NOT affected, so we can assert it survives. + if err := os.WriteFile(filepath.Join(authDir, "provider-b.json"), []byte(`{"type":"provider-b","email":"b@example.com"}`), 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + + oldCfg := &config.Config{ + AuthDir: authDir, + OAuthExcludedModels: map[string][]string{ + "provider-a": {"m1"}, + }, + } + newCfg := &config.Config{ + AuthDir: authDir, + OAuthExcludedModels: map[string][]string{ + "provider-a": {"m2"}, + }, + } + data, err := yaml.Marshal(newCfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err = os.WriteFile(configPath, data, 0o644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + w := &Watcher{ + configPath: configPath, + authDir: authDir, + lastAuthHashes: make(map[string]string), + currentAuths: map[string]*coreauth.Auth{ + "a": {ID: "a", Provider: "provider-a"}, + }, + } + w.SetConfig(oldCfg) + + if ok := w.reloadConfig(); !ok { + t.Fatal("expected reloadConfig to succeed") + } + + w.clientsMutex.RLock() + defer w.clientsMutex.RUnlock() + for _, auth := range w.currentAuths { + if auth != nil && auth.Provider == "provider-a" { + t.Fatal("expected affected provider auth to be filtered") + } + } + foundB := false + for _, auth := range w.currentAuths { + if auth != nil && auth.Provider == "provider-b" { + foundB = true + break + } + } + if !foundB { + t.Fatal("expected unaffected provider auth to remain") + } +} + +func TestStartFailsWhenAuthDirMissing(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+filepath.Join(tmpDir, "missing-auth")+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + authDir := filepath.Join(tmpDir, "missing-auth") + + w, err := NewWatcher(configPath, authDir, nil) + if err != nil { + t.Fatalf("failed to create watcher: %v", err) + } + defer w.Stop() + w.SetConfig(&config.Config{AuthDir: authDir}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := w.Start(ctx); err == nil { + t.Fatal("expected Start to fail for missing auth dir") + } +} + +func TestDispatchRuntimeAuthUpdateReturnsFalseWithoutQueue(t *testing.T) { + w := &Watcher{} + if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionAdd, Auth: &coreauth.Auth{ID: "a"}}); ok { + t.Fatal("expected DispatchRuntimeAuthUpdate to return false when no queue configured") + } + if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionDelete, Auth: &coreauth.Auth{ID: "a"}}); ok { + t.Fatal("expected DispatchRuntimeAuthUpdate delete to return false when no queue configured") + } +} + +func TestNormalizeAuthNil(t *testing.T) { + if normalizeAuth(nil) != nil { + t.Fatal("expected normalizeAuth(nil) to return nil") + } +} + +// stubStore implements coreauth.Store plus watcher-specific persistence helpers. +type stubStore struct { + authDir string + cfgPersisted int32 + authPersisted int32 + lastAuthMessage string + lastAuthPaths []string +} + +func (s *stubStore) List(context.Context) ([]*coreauth.Auth, error) { return nil, nil } +func (s *stubStore) Save(context.Context, *coreauth.Auth) (string, error) { + return "", nil +} +func (s *stubStore) Delete(context.Context, string) error { return nil } +func (s *stubStore) PersistConfig(context.Context) error { + atomic.AddInt32(&s.cfgPersisted, 1) + return nil +} +func (s *stubStore) PersistAuthFiles(_ context.Context, message string, paths ...string) error { + atomic.AddInt32(&s.authPersisted, 1) + s.lastAuthMessage = message + s.lastAuthPaths = paths + return nil +} +func (s *stubStore) AuthDir() string { return s.authDir } + +func TestNewWatcherDetectsPersisterAndAuthDir(t *testing.T) { + tmp := t.TempDir() + store := &stubStore{authDir: tmp} + orig := sdkAuth.GetTokenStore() + sdkAuth.RegisterTokenStore(store) + defer sdkAuth.RegisterTokenStore(orig) + + w, err := NewWatcher("config.yaml", "auth", nil) + if err != nil { + t.Fatalf("NewWatcher failed: %v", err) + } + if w.storePersister == nil { + t.Fatal("expected storePersister to be set from token store") + } + if w.mirroredAuthDir != tmp { + t.Fatalf("expected mirroredAuthDir %s, got %s", tmp, w.mirroredAuthDir) + } +} + +func TestPersistConfigAndAuthAsyncInvokePersister(t *testing.T) { + w := &Watcher{ + storePersister: &stubStore{}, + } + + w.persistConfigAsync() + w.persistAuthAsync("msg", " a ", "", "b ") + + time.Sleep(30 * time.Millisecond) + store := w.storePersister.(*stubStore) + if atomic.LoadInt32(&store.cfgPersisted) != 1 { + t.Fatalf("expected PersistConfig to be called once, got %d", store.cfgPersisted) + } + if atomic.LoadInt32(&store.authPersisted) != 1 { + t.Fatalf("expected PersistAuthFiles to be called once, got %d", store.authPersisted) + } + if store.lastAuthMessage != "msg" { + t.Fatalf("unexpected auth message: %s", store.lastAuthMessage) + } + if len(store.lastAuthPaths) != 2 || store.lastAuthPaths[0] != "a" || store.lastAuthPaths[1] != "b" { + t.Fatalf("unexpected filtered paths: %#v", store.lastAuthPaths) + } +} + +func TestScheduleConfigReloadDebounces(t *testing.T) { + tmp := t.TempDir() + authDir := tmp + cfgPath := tmp + "/config.yaml" + if err := os.WriteFile(cfgPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + var reloads int32 + w := &Watcher{ + configPath: cfgPath, + authDir: authDir, + reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) }, + } + w.SetConfig(&config.Config{AuthDir: authDir}) + + w.scheduleConfigReload() + w.scheduleConfigReload() + + time.Sleep(400 * time.Millisecond) + + if atomic.LoadInt32(&reloads) != 1 { + t.Fatalf("expected single debounced reload, got %d", reloads) + } + if w.lastConfigHash == "" { + t.Fatal("expected lastConfigHash to be set after reload") + } +} + +func TestPrepareAuthUpdatesLockedForceAndDelete(t *testing.T) { + w := &Watcher{ + currentAuths: map[string]*coreauth.Auth{ + "a": {ID: "a", Provider: "p1"}, + }, + authQueue: make(chan AuthUpdate, 4), + } + + updates := w.prepareAuthUpdatesLocked([]*coreauth.Auth{{ID: "a", Provider: "p2"}}, false) + if len(updates) != 1 || updates[0].Action != AuthUpdateActionModify || updates[0].ID != "a" { + t.Fatalf("unexpected modify updates: %+v", updates) + } + + updates = w.prepareAuthUpdatesLocked([]*coreauth.Auth{{ID: "a", Provider: "p2"}}, true) + if len(updates) != 1 || updates[0].Action != AuthUpdateActionModify { + t.Fatalf("expected force modify, got %+v", updates) + } + + updates = w.prepareAuthUpdatesLocked([]*coreauth.Auth{}, false) + if len(updates) != 1 || updates[0].Action != AuthUpdateActionDelete || updates[0].ID != "a" { + t.Fatalf("expected delete for missing auth, got %+v", updates) + } +} + +func TestAuthEqualIgnoresTemporalFields(t *testing.T) { + now := time.Now() + a := &coreauth.Auth{ID: "x", CreatedAt: now} + b := &coreauth.Auth{ID: "x", CreatedAt: now.Add(5 * time.Second)} + if !authEqual(a, b) { + t.Fatal("expected authEqual to ignore temporal differences") + } +} + +func TestDispatchLoopExitsWhenQueueNilAndContextCanceled(t *testing.T) { + w := &Watcher{ + dispatchCond: nil, + pendingUpdates: map[string]AuthUpdate{"k": {ID: "k"}}, + pendingOrder: []string{"k"}, + } + w.dispatchCond = sync.NewCond(&w.dispatchMu) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + w.dispatchLoop(ctx) + close(done) + }() + + time.Sleep(20 * time.Millisecond) + cancel() + w.dispatchMu.Lock() + w.dispatchCond.Broadcast() + w.dispatchMu.Unlock() + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("dispatchLoop did not exit after context cancel") + } +} + +func TestReloadClientsFiltersOAuthProvidersWithoutRescan(t *testing.T) { + tmp := t.TempDir() + w := &Watcher{ + authDir: tmp, + config: &config.Config{AuthDir: tmp}, + currentAuths: map[string]*coreauth.Auth{ + "a": {ID: "a", Provider: "Match"}, + "b": {ID: "b", Provider: "other"}, + }, + lastAuthHashes: map[string]string{"cached": "hash"}, + } + + w.reloadClients(false, []string{"match"}, false) + + w.clientsMutex.RLock() + defer w.clientsMutex.RUnlock() + if _, ok := w.currentAuths["a"]; ok { + t.Fatal("expected filtered provider to be removed") + } + if len(w.lastAuthHashes) != 1 { + t.Fatalf("expected existing hash cache to be retained, got %d", len(w.lastAuthHashes)) + } +} + +func TestScheduleProcessEventsStopsOnContextDone(t *testing.T) { + w := &Watcher{ + watcher: &fsnotify.Watcher{ + Events: make(chan fsnotify.Event, 1), + Errors: make(chan error, 1), + }, + configPath: "config.yaml", + authDir: "auth", + } + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + w.processEvents(ctx) + close(done) + }() + + cancel() + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("processEvents did not exit on context cancel") + } +} + func hexString(data []byte) string { return strings.ToLower(fmt.Sprintf("%x", data)) }