Files
CLIProxyAPI/internal/watcher/clients.go
hkfires 41c2385aca refactor(watcher): split watcher.go into focused modules
- Create dispatcher.go for auth update queue management
- Create events.go for fsnotify event handling
- Create config_reload.go for hot-reload logic
- Create clients.go for client lifecycle management
- Simplify watcher.go to core coordinator (~150 lines)
- Maintain 100% API backward compatibility
- All tests passing with 72%+ coverage
2025-12-17 15:53:28 +08:00

271 lines
7.7 KiB
Go

// clients.go implements watcher client lifecycle logic and persistence helpers.
// It reloads clients, handles incremental auth file changes, and persists updates when supported.
package watcher
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// reloadClients rebuilds the watcher's client bookkeeping from the current
// config and notifies the server.
//
// Steps, in order:
//  1. Snapshot w.config under a read lock; log and abort if it is nil.
//  2. If affectedOAuthProviders is non-empty, rebuild w.currentAuths without
//     the entries whose provider matches one of them (per matchProvider) and
//     without nil entries.
//  3. Count the API-key clients declared in the config (BuildAPIKeyClients).
//  4. If rescanAuth is true, count *.json files in the auth directory
//     (loadFileClients) and rebuild w.lastAuthHashes from scratch; otherwise
//     keep the existing hash cache and report its size as the file count.
//  5. Invoke w.reloadCallback(cfg), if set, then
//     w.refreshAuthState(forceAuthRefresh).
//  6. Log a per-source summary of the counts.
//
// Parameters:
//   - rescanAuth: re-walk the auth directory and rebuild the hash cache.
//   - affectedOAuthProviders: providers whose cached auths are dropped; the
//     debug log ties this to oauth-excluded-models changes.
//   - forceAuthRefresh: passed through unchanged to refreshAuthState.
//
// Here the reload callback runs before refreshAuthState, whereas
// addOrUpdateClient and removeClient refresh first and call back afterwards.
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
	log.Debugf("starting full client load process")
	w.clientsMutex.RLock()
	cfg := w.config
	w.clientsMutex.RUnlock()
	if cfg == nil {
		log.Error("config is nil, cannot reload clients")
		return
	}
	if len(affectedOAuthProviders) > 0 {
		w.clientsMutex.Lock()
		if w.currentAuths != nil {
			// Build a fresh map instead of deleting in place. Entries for the
			// affected providers (and nil entries) are left out.
			// NOTE(review): presumably refreshAuthState re-creates the dropped
			// auths using the updated exclusions — confirm.
			filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
			for id, auth := range w.currentAuths {
				if auth == nil {
					continue
				}
				// Provider names are compared trimmed and lower-cased.
				provider := strings.ToLower(strings.TrimSpace(auth.Provider))
				if _, match := matchProvider(provider, affectedOAuthProviders); match {
					continue
				}
				filtered[id] = auth
			}
			w.currentAuths = filtered
			log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
		} else {
			// Nothing cached; this assignment leaves the map nil as before.
			w.currentAuths = nil
		}
		w.clientsMutex.Unlock()
	}
	// BuildAPIKeyClients only counts configured keys; the counts feed the
	// summary log below.
	geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
	totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
	log.Debugf("loaded %d API key clients", totalAPIKeyClients)
	var authFileCount int
	if rescanAuth {
		authFileCount = w.loadFileClients(cfg)
		log.Debugf("loaded %d file-based clients", authFileCount)
	} else {
		// Without a rescan, the hash cache size stands in for the file count.
		// The cache only holds files that were readable and non-empty when
		// they were hashed.
		w.clientsMutex.RLock()
		authFileCount = len(w.lastAuthHashes)
		w.clientsMutex.RUnlock()
		log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount)
	}
	if rescanAuth {
		// Rebuild the hash cache from disk. The whole walk, including the file
		// reads, happens while clientsMutex is held for writing.
		w.clientsMutex.Lock()
		w.lastAuthHashes = make(map[string]string)
		if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
			log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
		} else if resolvedAuthDir != "" {
			// Best-effort: per-entry errors are skipped and the overall Walk
			// error is discarded.
			_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
				if err != nil {
					return nil
				}
				if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
					// Only readable, non-empty files are cached. Keys go through
					// normalizeAuthPath so they match the lookups in
					// addOrUpdateClient and removeClient.
					if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
						sum := sha256.Sum256(data)
						normalizedPath := w.normalizeAuthPath(path)
						w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
					}
				}
				return nil
			})
		}
		w.clientsMutex.Unlock()
	}
	// When rescanAuth is true, authFileCount counts every *.json file found by
	// loadFileClients, including empty or unreadable ones. It can therefore
	// exceed len(w.lastAuthHashes).
	totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
	if w.reloadCallback != nil {
		log.Debugf("triggering server update callback before auth refresh")
		w.reloadCallback(cfg)
	}
	w.refreshAuthState(forceAuthRefresh)
	log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
		totalNewClients,
		authFileCount,
		geminiAPIKeyCount,
		vertexCompatAPIKeyCount,
		claudeAPIKeyCount,
		codexAPIKeyCount,
		openAICompatCount,
	)
}
// addOrUpdateClient handles a created or modified auth file at path.
//
// The file is read and SHA-256 hashed. If the hash equals the one cached in
// lastAuthHashes for the same normalized path, the event is ignored.
// Otherwise the following happens in order:
//   - the new hash is cached;
//   - refreshAuthState(false) runs;
//   - the reload callback (if set) is invoked with the current config;
//   - the file is handed to persistAuthAsync.
//
// Unreadable files, empty files, and events that arrive while no config is
// loaded are logged and otherwise ignored.
func (w *Watcher) addOrUpdateClient(path string) {
	data, errRead := os.ReadFile(path)
	if errRead != nil {
		log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead)
		return
	}
	if len(data) == 0 {
		// NOTE(review): an empty file is presumably a write still in progress.
		// It is neither hashed nor treated as a removal.
		log.Debugf("ignoring empty auth file: %s", filepath.Base(path))
		return
	}
	sum := sha256.Sum256(data)
	curHash := hex.EncodeToString(sum[:])
	normalized := w.normalizeAuthPath(path)

	w.clientsMutex.Lock()
	cfg := w.config
	if cfg == nil {
		log.Error("config is nil, cannot add or update client")
		w.clientsMutex.Unlock()
		return
	}
	if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
		log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
		w.clientsMutex.Unlock()
		return
	}
	// Assigning into a nil map panics. In this file the cache is only
	// allocated by a rescan in reloadClients, so do not rely on that having
	// happened before the first file event.
	if w.lastAuthHashes == nil {
		w.lastAuthHashes = make(map[string]string)
	}
	w.lastAuthHashes[normalized] = curHash
	// Release the lock before calling out, so the callees are free to take
	// clientsMutex themselves.
	w.clientsMutex.Unlock()

	w.refreshAuthState(false)
	if w.reloadCallback != nil {
		log.Debugf("triggering server update callback after add/update")
		w.reloadCallback(cfg)
	}
	w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
}
// removeClient handles deletion of the auth file at path. In order, it:
//   - drops the cached hash for the file;
//   - runs refreshAuthState(false);
//   - invokes the reload callback with the current config;
//   - hands the removal to persistAuthAsync.
//
// reloadClients and addOrUpdateClient never pass a nil config to
// reloadCallback, and neither does this method. If no config is loaded, the
// callback is skipped and logged. The cache update, refresh and persistence
// still take place.
func (w *Watcher) removeClient(path string) {
	normalized := w.normalizeAuthPath(path)

	w.clientsMutex.Lock()
	cfg := w.config
	// delete on a nil map is a no-op, so no allocation guard is needed here.
	delete(w.lastAuthHashes, normalized)
	// Release the lock before calling out, so the callees are free to take
	// clientsMutex themselves.
	w.clientsMutex.Unlock()

	w.refreshAuthState(false)
	if w.reloadCallback != nil {
		if cfg == nil {
			log.Warn("config is nil, skipping server update callback after removal")
		} else {
			log.Debugf("triggering server update callback after removal")
			w.reloadCallback(cfg)
		}
	}
	w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
}
// loadFileClients walks the auth directory configured in cfg (resolved with
// util.ResolveAuthDir). It returns the number of *.json files found anywhere
// beneath it, matching the extension case-insensitively. It only counts
// files; it does not parse them.
//
// Every matching file is counted, including empty or unreadable ones. How
// many are readable and non-empty is reported only in the debug log.
//
// An error on the auth directory itself stops the walk and is logged. An
// error on a nested entry is logged at debug level and skipped, so one bad
// entry does not truncate the count. loadFileClients returns 0 when the
// directory cannot be resolved or is not configured.
func (w *Watcher) loadFileClients(cfg *config.Config) int {
	authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir)
	if errResolveAuthDir != nil {
		log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
		return 0
	}
	if authDir == "" {
		return 0
	}

	authFileCount := 0
	successfulAuthCount := 0
	errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
		if err != nil {
			log.Debugf("error accessing path %s: %v", path, err)
			if path == authDir {
				// The auth directory itself is unusable: stop and report it.
				return err
			}
			// Skip only this entry. Walk does not descend into a directory it
			// failed to read, and info may be nil here, so it is not touched.
			return nil
		}
		if info.IsDir() || !strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
			return nil
		}
		authFileCount++
		log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
		if data, errRead := os.ReadFile(path); errRead == nil && len(data) > 0 {
			successfulAuthCount++
		}
		return nil
	})
	if errWalk != nil {
		log.Errorf("error walking auth directory: %v", errWalk)
	}
	log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
	return authFileCount
}
// BuildAPIKeyClients counts the API-key clients declared in cfg. It returns,
// in order:
//   - the number of Gemini keys;
//   - the number of Vertex-compatible keys;
//   - the number of Claude keys;
//   - the number of Codex keys;
//   - the number of OpenAI-compatibility API key entries, summed across all
//     OpenAI-compatibility providers.
//
// It only counts entries and does not construct any clients. A nil cfg
// yields all zeros.
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
	if cfg == nil {
		return 0, 0, 0, 0, 0
	}
	openAICompatCount := 0
	// Index instead of ranging by value to avoid copying each provider config.
	for i := range cfg.OpenAICompatibility {
		openAICompatCount += len(cfg.OpenAICompatibility[i].APIKeyEntries)
	}
	return len(cfg.GeminiKey), len(cfg.VertexCompatAPIKey), len(cfg.ClaudeKey), len(cfg.CodexKey), openAICompatCount
}
// persistConfigAsync asks the store persister, when one is configured, to
// persist the current config. The call runs on its own goroutine with a
// 30-second timeout, and a failure is only logged. Nothing happens on a nil
// Watcher or when no persister is set.
func (w *Watcher) persistConfigAsync() {
	if w == nil {
		return
	}
	if w.storePersister == nil {
		return
	}

	const persistTimeout = 30 * time.Second
	persist := func() {
		ctx, cancel := context.WithTimeout(context.Background(), persistTimeout)
		defer cancel()
		if errPersist := w.storePersister.PersistConfig(ctx); errPersist != nil {
			log.Errorf("failed to persist config change: %v", errPersist)
		}
	}
	go persist()
}
// persistAuthAsync asks the store persister, when one is configured, to
// persist the given auth file paths together with a human-readable message
// (callers pass e.g. "Sync auth <file>").
//
// Paths are trimmed of surrounding whitespace, and blank ones are dropped. If
// nothing remains, no work is started. Otherwise the persister is called on
// its own goroutine with a 30-second timeout, and a failure is only logged.
// Nothing happens on a nil Watcher or when no persister is set.
func (w *Watcher) persistAuthAsync(message string, paths ...string) {
	if w == nil {
		return
	}
	if w.storePersister == nil {
		return
	}

	targets := make([]string, 0, len(paths))
	for i := range paths {
		candidate := strings.TrimSpace(paths[i])
		if candidate == "" {
			continue
		}
		targets = append(targets, candidate)
	}
	if len(targets) == 0 {
		return
	}

	const persistTimeout = 30 * time.Second
	go func(msg string, files []string) {
		ctx, cancel := context.WithTimeout(context.Background(), persistTimeout)
		defer cancel()
		if errPersist := w.storePersister.PersistAuthFiles(ctx, msg, files...); errPersist != nil {
			log.Errorf("failed to persist auth changes: %v", errPersist)
		}
	}(message, targets)
}