From ef8e94e99228763ff14403ae02ae4b3ededd6f11 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 16 Dec 2025 21:45:33 +0800 Subject: [PATCH] refactor(watcher): extract config diff helpers Break out config diffing, hashing, and OpenAI compatibility utilities into a dedicated diff package, update watcher to consume them, and add comprehensive tests for diff logic and watcher behavior. --- internal/watcher/diff/config_diff.go | 259 ++++++++ internal/watcher/diff/config_diff_test.go | 465 ++++++++++++++ internal/watcher/diff/model_hash.go | 62 ++ internal/watcher/diff/model_hash_test.go | 104 ++++ internal/watcher/diff/oauth_excluded.go | 151 +++++ internal/watcher/diff/oauth_excluded_test.go | 109 ++++ internal/watcher/diff/openai_compat.go | 124 ++++ internal/watcher/diff/openai_compat_test.go | 113 ++++ internal/watcher/watcher.go | 522 +--------------- internal/watcher/watcher_test.go | 609 +++++++++++++++++++ 10 files changed, 2004 insertions(+), 514 deletions(-) create mode 100644 internal/watcher/diff/config_diff.go create mode 100644 internal/watcher/diff/config_diff_test.go create mode 100644 internal/watcher/diff/model_hash.go create mode 100644 internal/watcher/diff/model_hash_test.go create mode 100644 internal/watcher/diff/oauth_excluded.go create mode 100644 internal/watcher/diff/oauth_excluded_test.go create mode 100644 internal/watcher/diff/openai_compat.go create mode 100644 internal/watcher/diff/openai_compat_test.go create mode 100644 internal/watcher/watcher_test.go diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go new file mode 100644 index 00000000..092001fd --- /dev/null +++ b/internal/watcher/diff/config_diff.go @@ -0,0 +1,259 @@ +package diff + +import ( + "fmt" + "reflect" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +// BuildConfigChangeDetails computes a redacted, human-readable list of config changes. +// Secrets are never printed; only structural or non-sensitive fields are surfaced. +func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { + changes := make([]string, 0, 16) + if oldCfg == nil || newCfg == nil { + return changes + } + + // Simple scalars + if oldCfg.Port != newCfg.Port { + changes = append(changes, fmt.Sprintf("port: %d -> %d", oldCfg.Port, newCfg.Port)) + } + if oldCfg.AuthDir != newCfg.AuthDir { + changes = append(changes, fmt.Sprintf("auth-dir: %s -> %s", oldCfg.AuthDir, newCfg.AuthDir)) + } + if oldCfg.Debug != newCfg.Debug { + changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug)) + } + if oldCfg.LoggingToFile != newCfg.LoggingToFile { + changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile)) + } + if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled { + changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled)) + } + if oldCfg.DisableCooling != newCfg.DisableCooling { + changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling)) + } + if oldCfg.RequestLog != newCfg.RequestLog { + changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog)) + } + if oldCfg.RequestRetry != newCfg.RequestRetry { + changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry)) + } + if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval { + changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval)) + } + if oldCfg.ProxyURL != newCfg.ProxyURL { + changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL)) + } + if oldCfg.WebsocketAuth != newCfg.WebsocketAuth { + changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth)) + } + + // Quota-exceeded behavior + if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject { + changes = append(changes, fmt.Sprintf("quota-exceeded.switch-project: %t -> %t", oldCfg.QuotaExceeded.SwitchProject, newCfg.QuotaExceeded.SwitchProject)) + } + if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel { + changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel)) + } + + // API keys (redacted) and counts + if len(oldCfg.APIKeys) != len(newCfg.APIKeys) { + changes = append(changes, fmt.Sprintf("api-keys count: %d -> %d", len(oldCfg.APIKeys), len(newCfg.APIKeys))) + } else if !reflect.DeepEqual(trimStrings(oldCfg.APIKeys), trimStrings(newCfg.APIKeys)) { + changes = append(changes, "api-keys: values updated (count unchanged, redacted)") + } + if len(oldCfg.GeminiKey) != len(newCfg.GeminiKey) { + changes = append(changes, fmt.Sprintf("gemini-api-key count: %d -> %d", len(oldCfg.GeminiKey), len(newCfg.GeminiKey))) + } else { + for i := range oldCfg.GeminiKey { + o := oldCfg.GeminiKey[i] + n := newCfg.GeminiKey[i] + if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) { + changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) + } + if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { + changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) + } + if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { + changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i)) + } + if !equalStringMap(o.Headers, n.Headers) { + changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i)) + } + oldExcluded := SummarizeExcludedModels(o.ExcludedModels) + newExcluded := SummarizeExcludedModels(n.ExcludedModels) + if oldExcluded.hash != newExcluded.hash { + changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) + } + } + } + + // Claude keys (do not print key material) + if len(oldCfg.ClaudeKey) != len(newCfg.ClaudeKey) { + changes = append(changes, fmt.Sprintf("claude-api-key count: %d -> %d", len(oldCfg.ClaudeKey), len(newCfg.ClaudeKey))) + } else { + for i := range oldCfg.ClaudeKey { + o := oldCfg.ClaudeKey[i] + n := newCfg.ClaudeKey[i] + if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) { + changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) + } + if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { + changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) + } + if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { + changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i)) + } + if !equalStringMap(o.Headers, n.Headers) { + changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i)) + } + oldExcluded := SummarizeExcludedModels(o.ExcludedModels) + newExcluded := SummarizeExcludedModels(n.ExcludedModels) + if oldExcluded.hash != newExcluded.hash { + changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) + } + } + } + + // Codex keys (do not print key material) + if len(oldCfg.CodexKey) != len(newCfg.CodexKey) { + changes = append(changes, fmt.Sprintf("codex-api-key count: %d -> %d", len(oldCfg.CodexKey), len(newCfg.CodexKey))) + } else { + for i := range oldCfg.CodexKey { + o := oldCfg.CodexKey[i] + n := newCfg.CodexKey[i] + if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) { + changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) + } + if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { + changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) + } + if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { + changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i)) + } + if !equalStringMap(o.Headers, n.Headers) { + changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i)) + } + oldExcluded := SummarizeExcludedModels(o.ExcludedModels) + newExcluded := SummarizeExcludedModels(n.ExcludedModels) + if oldExcluded.hash != newExcluded.hash { + changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) + } + } + } + + // AmpCode settings (redacted where needed) + oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL) + newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL) + if oldAmpURL != newAmpURL { + changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL)) + } + oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey) + newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey) + switch { + case oldAmpKey == "" && newAmpKey != "": + changes = append(changes, "ampcode.upstream-api-key: added") + case oldAmpKey != "" && newAmpKey == "": + changes = append(changes, "ampcode.upstream-api-key: removed") + case oldAmpKey != newAmpKey: + changes = append(changes, "ampcode.upstream-api-key: updated") + } + if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost { + changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost)) + } + oldMappings := SummarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings) + newMappings := SummarizeAmpModelMappings(newCfg.AmpCode.ModelMappings) + if oldMappings.hash != newMappings.hash { + changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count)) + } + if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings { + changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings)) + } + + if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { + changes = append(changes, entries...) + } + + // Remote management (never print the key) + if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote { + changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote)) + } + if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel { + changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel)) + } + oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository) + newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository) + if oldPanelRepo != newPanelRepo { + changes = append(changes, fmt.Sprintf("remote-management.panel-github-repository: %s -> %s", oldPanelRepo, newPanelRepo)) + } + if oldCfg.RemoteManagement.SecretKey != newCfg.RemoteManagement.SecretKey { + switch { + case oldCfg.RemoteManagement.SecretKey == "" && newCfg.RemoteManagement.SecretKey != "": + changes = append(changes, "remote-management.secret-key: created") + case oldCfg.RemoteManagement.SecretKey != "" && newCfg.RemoteManagement.SecretKey == "": + changes = append(changes, "remote-management.secret-key: deleted") + default: + changes = append(changes, "remote-management.secret-key: updated") + } + } + + // OpenAI compatibility providers (summarized) + if compat := DiffOpenAICompatibility(oldCfg.OpenAICompatibility, newCfg.OpenAICompatibility); len(compat) > 0 { + changes = append(changes, "openai-compatibility:") + for _, c := range compat { + changes = append(changes, " "+c) + } + } + + // Vertex-compatible API keys + if len(oldCfg.VertexCompatAPIKey) != len(newCfg.VertexCompatAPIKey) { + changes = append(changes, fmt.Sprintf("vertex-api-key count: %d -> %d", len(oldCfg.VertexCompatAPIKey), len(newCfg.VertexCompatAPIKey))) + } else { + for i := range oldCfg.VertexCompatAPIKey { + o := oldCfg.VertexCompatAPIKey[i] + n := newCfg.VertexCompatAPIKey[i] + if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) { + changes = append(changes, fmt.Sprintf("vertex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) + } + if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { + changes = append(changes, fmt.Sprintf("vertex[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) + } + if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { + changes = append(changes, fmt.Sprintf("vertex[%d].api-key: updated", i)) + } + oldModels := SummarizeVertexModels(o.Models) + newModels := SummarizeVertexModels(n.Models) + if oldModels.hash != newModels.hash { + changes = append(changes, fmt.Sprintf("vertex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count)) + } + if !equalStringMap(o.Headers, n.Headers) { + changes = append(changes, fmt.Sprintf("vertex[%d].headers: updated", i)) + } + } + } + + return changes +} + +func trimStrings(in []string) []string { + out := make([]string, len(in)) + for i := range in { + out[i] = strings.TrimSpace(in[i]) + } + return out +} + +func equalStringMap(a, b map[string]string) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if b[k] != v { + return false + } + } + return true +} diff --git a/internal/watcher/diff/config_diff_test.go b/internal/watcher/diff/config_diff_test.go new file mode 100644 index 00000000..f952b695 --- /dev/null +++ b/internal/watcher/diff/config_diff_test.go @@ -0,0 +1,465 @@ +package diff + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestBuildConfigChangeDetails(t *testing.T) { + oldCfg := &config.Config{ + Port: 8080, + AuthDir: "/tmp/auth-old", + GeminiKey: []config.GeminiKey{ + {APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model"}}, + }, + AmpCode: config.AmpCode{ + UpstreamURL: "http://old-upstream", + ModelMappings: []config.AmpModelMapping{{From: "from-old", To: "to-old"}}, + RestrictManagementToLocalhost: false, + }, + RemoteManagement: config.RemoteManagement{ + AllowRemote: false, + SecretKey: "old", + DisableControlPanel: false, + PanelGitHubRepository: "repo-old", + }, + OAuthExcludedModels: map[string][]string{ + "providerA": {"m1"}, + }, + OpenAICompatibility: []config.OpenAICompatibility{ + { + Name: "compat-a", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "k1"}, + }, + Models: []config.OpenAICompatibilityModel{{Name: "m1"}}, + }, + }, + } + + newCfg := &config.Config{ + Port: 9090, + AuthDir: "/tmp/auth-new", + GeminiKey: []config.GeminiKey{ + {APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model", "extra"}}, + }, + AmpCode: config.AmpCode{ + UpstreamURL: "http://new-upstream", + RestrictManagementToLocalhost: true, + ModelMappings: []config.AmpModelMapping{ + {From: "from-old", To: "to-old"}, + {From: "from-new", To: "to-new"}, + }, + }, + RemoteManagement: config.RemoteManagement{ + AllowRemote: true, + SecretKey: "new", + DisableControlPanel: true, + PanelGitHubRepository: "repo-new", + }, + OAuthExcludedModels: map[string][]string{ + "providerA": {"m1", "m2"}, + "providerB": {"x"}, + }, + OpenAICompatibility: []config.OpenAICompatibility{ + { + Name: "compat-a", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "k1"}, + }, + Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}}, + }, + { + Name: "compat-b", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "k2"}, + }, + }, + }, + } + + details := BuildConfigChangeDetails(oldCfg, newCfg) + + expectContains(t, details, "port: 8080 -> 9090") + expectContains(t, details, "auth-dir: /tmp/auth-old -> /tmp/auth-new") + expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)") + expectContains(t, details, "ampcode.upstream-url: http://old-upstream -> http://new-upstream") + expectContains(t, details, "ampcode.model-mappings: updated (1 -> 2 entries)") + expectContains(t, details, "remote-management.allow-remote: false -> true") + expectContains(t, details, "remote-management.secret-key: updated") + expectContains(t, details, "oauth-excluded-models[providera]: updated (1 -> 2 entries)") + expectContains(t, details, "oauth-excluded-models[providerb]: added (1 entries)") + expectContains(t, details, "openai-compatibility:") + expectContains(t, details, " provider added: compat-b (api-keys=1, models=0)") + expectContains(t, details, " provider updated: compat-a (models 1 -> 2)") +} + +func TestBuildConfigChangeDetails_NoChanges(t *testing.T) { + cfg := &config.Config{ + Port: 8080, + } + if details := BuildConfigChangeDetails(cfg, cfg); len(details) != 0 { + t.Fatalf("expected no change entries, got %v", details) + } +} + +func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing.T) { + oldCfg := &config.Config{ + GeminiKey: []config.GeminiKey{ + {APIKey: "g1", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}}, + }, + VertexCompatAPIKey: []config.VertexCompatKey{ + {APIKey: "v1", BaseURL: "http://v-old", Models: []config.VertexCompatModel{{Name: "m1"}}}, + }, + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}}, + ForceModelMappings: false, + }, + } + newCfg := &config.Config{ + GeminiKey: []config.GeminiKey{ + {APIKey: "g1", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"a", "b"}}, + }, + VertexCompatAPIKey: []config.VertexCompatKey{ + {APIKey: "v1", BaseURL: "http://v-new", Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}}, + }, + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}}, + ForceModelMappings: true, + }, + } + + details := BuildConfigChangeDetails(oldCfg, newCfg) + expectContains(t, details, "gemini[0].headers: updated") + expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)") + expectContains(t, details, "ampcode.model-mappings: updated (1 -> 1 entries)") + expectContains(t, details, "ampcode.force-model-mappings: false -> true") +} + +func TestBuildConfigChangeDetails_NilSafe(t *testing.T) { + if details := BuildConfigChangeDetails(nil, &config.Config{}); len(details) != 0 { + t.Fatalf("expected empty change list when old nil, got %v", details) + } + if details := BuildConfigChangeDetails(&config.Config{}, nil); len(details) != 0 { + t.Fatalf("expected empty change list when new nil, got %v", details) + } +} + +func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) { + oldCfg := &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ + APIKeys: []string{"a"}, + }, + AmpCode: config.AmpCode{ + UpstreamAPIKey: "", + }, + RemoteManagement: config.RemoteManagement{ + SecretKey: "", + }, + } + newCfg := &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ + APIKeys: []string{"a", "b", "c"}, + }, + AmpCode: config.AmpCode{ + UpstreamAPIKey: "new-key", + }, + RemoteManagement: config.RemoteManagement{ + SecretKey: "new-secret", + }, + } + + details := BuildConfigChangeDetails(oldCfg, newCfg) + expectContains(t, details, "api-keys count: 1 -> 3") + expectContains(t, details, "ampcode.upstream-api-key: added") + expectContains(t, details, "remote-management.secret-key: created") +} + +func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) { + oldCfg := &config.Config{ + Port: 1000, + AuthDir: "/old", + Debug: false, + LoggingToFile: false, + UsageStatisticsEnabled: false, + DisableCooling: false, + RequestRetry: 1, + MaxRetryInterval: 1, + WebsocketAuth: false, + QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false}, + ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}}, + CodexKey: []config.CodexKey{{APIKey: "x1"}}, + AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false}, + RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"}, + SDKConfig: sdkconfig.SDKConfig{ + RequestLog: false, + ProxyURL: "http://old-proxy", + APIKeys: []string{"key-1"}, + }, + } + newCfg := &config.Config{ + Port: 2000, + AuthDir: "/new", + Debug: true, + LoggingToFile: true, + UsageStatisticsEnabled: true, + DisableCooling: true, + RequestRetry: 2, + MaxRetryInterval: 3, + WebsocketAuth: true, + QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true}, + ClaudeKey: []config.ClaudeKey{ + {APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}}, + {APIKey: "c2"}, + }, + CodexKey: []config.CodexKey{ + {APIKey: "x1", BaseURL: "http://x", ProxyURL: "http://px", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"b"}}, + {APIKey: "x2"}, + }, + AmpCode: config.AmpCode{ + UpstreamAPIKey: "", + RestrictManagementToLocalhost: true, + ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}}, + }, + RemoteManagement: config.RemoteManagement{ + DisableControlPanel: true, + PanelGitHubRepository: "new/repo", + SecretKey: "", + }, + SDKConfig: sdkconfig.SDKConfig{ + RequestLog: true, + ProxyURL: "http://new-proxy", + APIKeys: []string{" key-1 ", "key-2"}, + }, + } + + details := BuildConfigChangeDetails(oldCfg, newCfg) + expectContains(t, details, "debug: false -> true") + expectContains(t, details, "logging-to-file: false -> true") + expectContains(t, details, "usage-statistics-enabled: false -> true") + expectContains(t, details, "disable-cooling: false -> true") + expectContains(t, details, "request-log: false -> true") + expectContains(t, details, "request-retry: 1 -> 2") + expectContains(t, details, "max-retry-interval: 1 -> 3") + expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy") + expectContains(t, details, "ws-auth: false -> true") + expectContains(t, details, "quota-exceeded.switch-project: false -> true") + expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true") + expectContains(t, details, "api-keys count: 1 -> 2") + expectContains(t, details, "claude-api-key count: 1 -> 2") + expectContains(t, details, "codex-api-key count: 1 -> 2") + expectContains(t, details, "ampcode.restrict-management-to-localhost: false -> true") + expectContains(t, details, "ampcode.upstream-api-key: removed") + expectContains(t, details, "remote-management.disable-control-panel: false -> true") + expectContains(t, details, "remote-management.panel-github-repository: old/repo -> new/repo") + expectContains(t, details, "remote-management.secret-key: deleted") +} + +func TestBuildConfigChangeDetails_AllBranches(t *testing.T) { + oldCfg := &config.Config{ + Port: 1, + AuthDir: "/a", + Debug: false, + LoggingToFile: false, + UsageStatisticsEnabled: false, + DisableCooling: false, + RequestRetry: 1, + MaxRetryInterval: 1, + WebsocketAuth: false, + QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false}, + GeminiKey: []config.GeminiKey{ + {APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}}, + }, + ClaudeKey: []config.ClaudeKey{ + {APIKey: "c-old", BaseURL: "http://c-old", ProxyURL: "http://cp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}}, + }, + CodexKey: []config.CodexKey{ + {APIKey: "x-old", BaseURL: "http://x-old", ProxyURL: "http://xp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}}, + }, + VertexCompatAPIKey: []config.VertexCompatKey{ + {APIKey: "v-old", BaseURL: "http://v-old", ProxyURL: "http://vp-old", Headers: map[string]string{"H": "1"}, Models: []config.VertexCompatModel{{Name: "m1"}}}, + }, + AmpCode: config.AmpCode{ + UpstreamURL: "http://amp-old", + UpstreamAPIKey: "old-key", + RestrictManagementToLocalhost: false, + ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}}, + ForceModelMappings: false, + }, + RemoteManagement: config.RemoteManagement{ + AllowRemote: false, + DisableControlPanel: false, + PanelGitHubRepository: "old/repo", + SecretKey: "old", + }, + SDKConfig: sdkconfig.SDKConfig{ + RequestLog: false, + ProxyURL: "http://old-proxy", + APIKeys: []string{" keyA "}, + }, + OAuthExcludedModels: map[string][]string{"p1": {"a"}}, + OpenAICompatibility: []config.OpenAICompatibility{ + { + Name: "prov-old", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "k1"}, + }, + Models: []config.OpenAICompatibilityModel{{Name: "m1"}}, + }, + }, + } + newCfg := &config.Config{ + Port: 2, + AuthDir: "/b", + Debug: true, + LoggingToFile: true, + UsageStatisticsEnabled: true, + DisableCooling: true, + RequestRetry: 2, + MaxRetryInterval: 3, + WebsocketAuth: true, + QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true}, + GeminiKey: []config.GeminiKey{ + {APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}}, + }, + ClaudeKey: []config.ClaudeKey{ + {APIKey: "c-new", BaseURL: "http://c-new", ProxyURL: "http://cp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}}, + }, + CodexKey: []config.CodexKey{ + {APIKey: "x-new", BaseURL: "http://x-new", ProxyURL: "http://xp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}}, + }, + VertexCompatAPIKey: []config.VertexCompatKey{ + {APIKey: "v-new", BaseURL: "http://v-new", ProxyURL: "http://vp-new", Headers: map[string]string{"H": "2"}, Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}}, + }, + AmpCode: config.AmpCode{ + UpstreamURL: "http://amp-new", + UpstreamAPIKey: "", + RestrictManagementToLocalhost: true, + ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}}, + ForceModelMappings: true, + }, + RemoteManagement: config.RemoteManagement{ + AllowRemote: true, + DisableControlPanel: true, + PanelGitHubRepository: "new/repo", + SecretKey: "", + }, + SDKConfig: sdkconfig.SDKConfig{ + RequestLog: true, + ProxyURL: "http://new-proxy", + APIKeys: []string{"keyB"}, + }, + OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}}, + OpenAICompatibility: []config.OpenAICompatibility{ + { + Name: "prov-old", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "k1"}, + {APIKey: "k2"}, + }, + Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}}, + }, + { + Name: "prov-new", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k3"}}, + }, + }, + } + + changes := BuildConfigChangeDetails(oldCfg, newCfg) + expectContains(t, changes, "port: 1 -> 2") + expectContains(t, changes, "auth-dir: /a -> /b") + expectContains(t, changes, "debug: false -> true") + expectContains(t, changes, "logging-to-file: false -> true") + expectContains(t, changes, "usage-statistics-enabled: false -> true") + expectContains(t, changes, "disable-cooling: false -> true") + expectContains(t, changes, "request-retry: 1 -> 2") + expectContains(t, changes, "max-retry-interval: 1 -> 3") + expectContains(t, changes, "proxy-url: http://old-proxy -> http://new-proxy") + expectContains(t, changes, "ws-auth: false -> true") + expectContains(t, changes, "quota-exceeded.switch-project: false -> true") + expectContains(t, changes, "quota-exceeded.switch-preview-model: false -> true") + expectContains(t, changes, "api-keys: values updated (count unchanged, redacted)") + expectContains(t, changes, "gemini[0].base-url: http://g-old -> http://g-new") + expectContains(t, changes, "gemini[0].proxy-url: http://gp-old -> http://gp-new") + expectContains(t, changes, "gemini[0].api-key: updated") + expectContains(t, changes, "gemini[0].headers: updated") + expectContains(t, changes, "gemini[0].excluded-models: updated (0 -> 2 entries)") + expectContains(t, changes, "claude[0].base-url: http://c-old -> http://c-new") + expectContains(t, changes, "claude[0].proxy-url: http://cp-old -> http://cp-new") + expectContains(t, changes, "claude[0].api-key: updated") + expectContains(t, changes, "claude[0].headers: updated") + expectContains(t, changes, "claude[0].excluded-models: updated (1 -> 2 entries)") + expectContains(t, changes, "codex[0].base-url: http://x-old -> http://x-new") + expectContains(t, changes, "codex[0].proxy-url: http://xp-old -> http://xp-new") + expectContains(t, changes, "codex[0].api-key: updated") + expectContains(t, changes, "codex[0].headers: updated") + expectContains(t, changes, "codex[0].excluded-models: updated (1 -> 2 entries)") + expectContains(t, changes, "vertex[0].base-url: http://v-old -> http://v-new") + expectContains(t, changes, "vertex[0].proxy-url: http://vp-old -> http://vp-new") + expectContains(t, changes, "vertex[0].api-key: updated") + expectContains(t, changes, "vertex[0].models: updated (1 -> 2 entries)") + expectContains(t, changes, "vertex[0].headers: updated") + expectContains(t, changes, "ampcode.upstream-url: http://amp-old -> http://amp-new") + expectContains(t, changes, "ampcode.upstream-api-key: removed") + expectContains(t, changes, "ampcode.restrict-management-to-localhost: false -> true") + expectContains(t, changes, "ampcode.model-mappings: updated (1 -> 1 entries)") + expectContains(t, changes, "ampcode.force-model-mappings: false -> true") + expectContains(t, changes, "oauth-excluded-models[p1]: updated (1 -> 2 entries)") + expectContains(t, changes, "oauth-excluded-models[p2]: added (1 entries)") + expectContains(t, changes, "remote-management.allow-remote: false -> true") + expectContains(t, changes, "remote-management.disable-control-panel: false -> true") + expectContains(t, changes, "remote-management.panel-github-repository: old/repo -> new/repo") + expectContains(t, changes, "remote-management.secret-key: deleted") + expectContains(t, changes, "openai-compatibility:") +} + +func TestBuildConfigChangeDetails_SecretAndUpstreamUpdates(t *testing.T) { + oldCfg := &config.Config{ + AmpCode: config.AmpCode{ + UpstreamAPIKey: "old", + }, + RemoteManagement: config.RemoteManagement{ + SecretKey: "old", + }, + } + newCfg := &config.Config{ + AmpCode: config.AmpCode{ + UpstreamAPIKey: "new", + }, + RemoteManagement: config.RemoteManagement{ + SecretKey: "new", + }, + } + + changes := BuildConfigChangeDetails(oldCfg, newCfg) + expectContains(t, changes, "ampcode.upstream-api-key: updated") + expectContains(t, changes, "remote-management.secret-key: updated") +} + +func TestBuildConfigChangeDetails_CountBranches(t *testing.T) { + oldCfg := &config.Config{} + newCfg := &config.Config{ + GeminiKey: []config.GeminiKey{{APIKey: "g"}}, + ClaudeKey: []config.ClaudeKey{{APIKey: "c"}}, + CodexKey: []config.CodexKey{{APIKey: "x"}}, + VertexCompatAPIKey: []config.VertexCompatKey{ + {APIKey: "v", BaseURL: "http://v"}, + }, + } + + changes := BuildConfigChangeDetails(oldCfg, newCfg) + expectContains(t, changes, "gemini-api-key count: 0 -> 1") + expectContains(t, changes, "claude-api-key count: 0 -> 1") + expectContains(t, changes, "codex-api-key count: 0 -> 1") + expectContains(t, changes, "vertex-api-key count: 0 -> 1") +} + +func TestTrimStrings(t *testing.T) { + out := trimStrings([]string{" a ", "b", " c"}) + if len(out) != 3 || out[0] != "a" || out[1] != "b" || out[2] != "c" { + t.Fatalf("unexpected trimmed strings: %v", out) + } +} diff --git a/internal/watcher/diff/model_hash.go b/internal/watcher/diff/model_hash.go new file mode 100644 index 00000000..796b09cf --- /dev/null +++ b/internal/watcher/diff/model_hash.go @@ -0,0 +1,62 @@ +package diff + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "sort" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +// ComputeOpenAICompatModelsHash returns a stable hash for OpenAI-compat models. +// Used to detect model list changes during hot reload. +func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string { + if len(models) == 0 { + return "" + } + data, _ := json.Marshal(models) + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +// ComputeVertexCompatModelsHash returns a stable hash for Vertex-compatible models. +func ComputeVertexCompatModelsHash(models []config.VertexCompatModel) string { + if len(models) == 0 { + return "" + } + data, _ := json.Marshal(models) + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +// ComputeClaudeModelsHash returns a stable hash for Claude model aliases. +func ComputeClaudeModelsHash(models []config.ClaudeModel) string { + if len(models) == 0 { + return "" + } + data, _ := json.Marshal(models) + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +// ComputeExcludedModelsHash returns a normalized hash for excluded model lists. +func ComputeExcludedModelsHash(excluded []string) string { + if len(excluded) == 0 { + return "" + } + normalized := make([]string, 0, len(excluded)) + for _, entry := range excluded { + if trimmed := strings.TrimSpace(entry); trimmed != "" { + normalized = append(normalized, strings.ToLower(trimmed)) + } + } + if len(normalized) == 0 { + return "" + } + sort.Strings(normalized) + data, _ := json.Marshal(normalized) + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/watcher/diff/model_hash_test.go b/internal/watcher/diff/model_hash_test.go new file mode 100644 index 00000000..a91f97ab --- /dev/null +++ b/internal/watcher/diff/model_hash_test.go @@ -0,0 +1,104 @@ +package diff + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) { + models := []config.OpenAICompatibilityModel{ + {Name: "gpt-4", Alias: "gpt4"}, + {Name: "gpt-3.5-turbo"}, + } + hash1 := ComputeOpenAICompatModelsHash(models) + hash2 := ComputeOpenAICompatModelsHash(models) + if hash1 == "" { + t.Fatal("hash should not be empty") + } + if hash1 != hash2 { + t.Fatalf("hash should be deterministic, got %s vs %s", hash1, hash2) + } + changed := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: "gpt-4"}, {Name: "gpt-4.1"}}) + if hash1 == changed { + t.Fatal("hash should change when model list changes") + } +} + +func TestComputeVertexCompatModelsHash_DifferentInputs(t *testing.T) { + models := []config.VertexCompatModel{{Name: "gemini-pro", Alias: "pro"}} + hash1 := ComputeVertexCompatModelsHash(models) + hash2 := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: "gemini-1.5-pro", Alias: "pro"}}) + if hash1 == "" || hash2 == "" { + t.Fatal("hashes should not be empty for non-empty models") + } + if hash1 == hash2 { + t.Fatal("hash should differ when model content differs") + } +} + +func TestComputeClaudeModelsHash_Empty(t *testing.T) { + if got := ComputeClaudeModelsHash(nil); got != "" { + t.Fatalf("expected empty hash for nil models, got %q", got) + } + if got := ComputeClaudeModelsHash([]config.ClaudeModel{}); got != "" { + t.Fatalf("expected empty hash for empty slice, got %q", got) + } +} + +func TestComputeExcludedModelsHash_Normalizes(t *testing.T) { + hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"}) + hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"}) + if hash1 == "" || hash2 == "" { + t.Fatal("hash should not be empty for non-empty input") + } + if hash1 != hash2 { + t.Fatalf("hash should be order/space insensitive for same multiset, got %s vs %s", hash1, hash2) + } + hash3 := ComputeExcludedModelsHash([]string{"c"}) + if hash1 == hash3 { + t.Fatal("hash should differ for different normalized sets") + } +} + +func TestComputeOpenAICompatModelsHash_Empty(t *testing.T) { + if got := ComputeOpenAICompatModelsHash(nil); got != "" { + t.Fatalf("expected empty hash for nil input, got %q", got) + } + if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{}); got != "" { + t.Fatalf("expected empty hash for empty slice, got %q", got) + } +} + +func TestComputeVertexCompatModelsHash_Empty(t *testing.T) { + if got := ComputeVertexCompatModelsHash(nil); got != "" { + t.Fatalf("expected empty hash for nil input, got %q", got) + } + if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{}); got != "" { + t.Fatalf("expected empty hash for empty slice, got %q", got) + } +} + +func TestComputeExcludedModelsHash_Empty(t *testing.T) { + if got := ComputeExcludedModelsHash(nil); got != "" { + t.Fatalf("expected empty hash for nil input, got %q", got) + } + if got := ComputeExcludedModelsHash([]string{}); got != "" { + t.Fatalf("expected empty hash for empty slice, got %q", got) + } + if got := ComputeExcludedModelsHash([]string{" ", ""}); got != "" { + t.Fatalf("expected empty hash for whitespace-only entries, got %q", got) + } +} + +func TestComputeClaudeModelsHash_Deterministic(t *testing.T) { + models := []config.ClaudeModel{{Name: "a", Alias: "A"}, {Name: "b"}} + h1 := ComputeClaudeModelsHash(models) + h2 := ComputeClaudeModelsHash(models) + if h1 == "" || h1 != h2 { + t.Fatalf("expected deterministic hash, got %s / %s", h1, h2) + } + if h3 := ComputeClaudeModelsHash([]config.ClaudeModel{{Name: "a"}}); h3 == h1 { + t.Fatalf("expected different hash when models change, got %s", h3) + } +} diff --git a/internal/watcher/diff/oauth_excluded.go b/internal/watcher/diff/oauth_excluded.go new file mode 100644 index 00000000..7fdfa9cb --- /dev/null +++ b/internal/watcher/diff/oauth_excluded.go @@ -0,0 +1,151 @@ +package diff + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +type excludedModelsSummary struct { + hash string + count int +} + +// SummarizeExcludedModels normalizes and hashes an excluded-model list. +func SummarizeExcludedModels(list []string) excludedModelsSummary { + if len(list) == 0 { + return excludedModelsSummary{} + } + seen := make(map[string]struct{}, len(list)) + normalized := make([]string, 0, len(list)) + for _, entry := range list { + if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" { + if _, exists := seen[trimmed]; exists { + continue + } + seen[trimmed] = struct{}{} + normalized = append(normalized, trimmed) + } + } + sort.Strings(normalized) + return excludedModelsSummary{ + hash: ComputeExcludedModelsHash(normalized), + count: len(normalized), + } +} + +// SummarizeOAuthExcludedModels summarizes OAuth excluded models per provider. +func SummarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary { + if len(entries) == 0 { + return nil + } + out := make(map[string]excludedModelsSummary, len(entries)) + for k, v := range entries { + key := strings.ToLower(strings.TrimSpace(k)) + if key == "" { + continue + } + out[key] = SummarizeExcludedModels(v) + } + return out +} + +// DiffOAuthExcludedModelChanges compares OAuth excluded models maps. +func DiffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) { + oldSummary := SummarizeOAuthExcludedModels(oldMap) + newSummary := SummarizeOAuthExcludedModels(newMap) + keys := make(map[string]struct{}, len(oldSummary)+len(newSummary)) + for k := range oldSummary { + keys[k] = struct{}{} + } + for k := range newSummary { + keys[k] = struct{}{} + } + changes := make([]string, 0, len(keys)) + affected := make([]string, 0, len(keys)) + for key := range keys { + oldInfo, okOld := oldSummary[key] + newInfo, okNew := newSummary[key] + switch { + case okOld && !okNew: + changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key)) + affected = append(affected, key) + case !okOld && okNew: + changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count)) + affected = append(affected, key) + case okOld && okNew && oldInfo.hash != newInfo.hash: + changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count)) + affected = append(affected, key) + } + } + sort.Strings(changes) + sort.Strings(affected) + return changes, affected +} + +type ampModelMappingsSummary struct { + hash string + count int +} + +// SummarizeAmpModelMappings hashes Amp model mappings for change detection. +func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) ampModelMappingsSummary { + if len(mappings) == 0 { + return ampModelMappingsSummary{} + } + entries := make([]string, 0, len(mappings)) + for _, mapping := range mappings { + from := strings.TrimSpace(mapping.From) + to := strings.TrimSpace(mapping.To) + if from == "" && to == "" { + continue + } + entries = append(entries, from+"->"+to) + } + if len(entries) == 0 { + return ampModelMappingsSummary{} + } + sort.Strings(entries) + sum := sha256.Sum256([]byte(strings.Join(entries, "|"))) + return ampModelMappingsSummary{ + hash: hex.EncodeToString(sum[:]), + count: len(entries), + } +} + +type vertexModelsSummary struct { + hash string + count int +} + +// SummarizeVertexModels hashes vertex-compatible models for change detection. +func SummarizeVertexModels(models []config.VertexCompatModel) vertexModelsSummary { + if len(models) == 0 { + return vertexModelsSummary{} + } + names := make([]string, 0, len(models)) + for _, m := range models { + name := strings.TrimSpace(m.Name) + alias := strings.TrimSpace(m.Alias) + if name == "" && alias == "" { + continue + } + if alias != "" { + name = alias + } + names = append(names, name) + } + if len(names) == 0 { + return vertexModelsSummary{} + } + sort.Strings(names) + sum := sha256.Sum256([]byte(strings.Join(names, "|"))) + return vertexModelsSummary{ + hash: hex.EncodeToString(sum[:]), + count: len(names), + } +} diff --git a/internal/watcher/diff/oauth_excluded_test.go b/internal/watcher/diff/oauth_excluded_test.go new file mode 100644 index 00000000..f5ad3913 --- /dev/null +++ b/internal/watcher/diff/oauth_excluded_test.go @@ -0,0 +1,109 @@ +package diff + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func TestSummarizeExcludedModels_NormalizesAndDedupes(t *testing.T) { + summary := SummarizeExcludedModels([]string{"A", " a ", "B", "b"}) + if summary.count != 2 { + t.Fatalf("expected 2 unique entries, got %d", summary.count) + } + if summary.hash == "" { + t.Fatal("expected non-empty hash") + } + if empty := SummarizeExcludedModels(nil); empty.count != 0 || empty.hash != "" { + t.Fatalf("expected empty summary for nil input, got %+v", empty) + } +} + +func TestDiffOAuthExcludedModelChanges(t *testing.T) { + oldMap := map[string][]string{ + "ProviderA": {"model-1", "model-2"}, + "providerB": {"x"}, + } + newMap := map[string][]string{ + "providerA": {"model-1", "model-3"}, + "providerC": {"y"}, + } + + changes, affected := DiffOAuthExcludedModelChanges(oldMap, newMap) + expectContains(t, changes, "oauth-excluded-models[providera]: updated (2 -> 2 entries)") + expectContains(t, changes, "oauth-excluded-models[providerb]: removed") + expectContains(t, changes, "oauth-excluded-models[providerc]: added (1 entries)") + + if len(affected) != 3 { + t.Fatalf("expected 3 affected providers, got %d", len(affected)) + } +} + +func TestSummarizeAmpModelMappings(t *testing.T) { + summary := SummarizeAmpModelMappings([]config.AmpModelMapping{ + {From: "a", To: "A"}, + {From: "b", To: "B"}, + {From: " ", To: " "}, // ignored + }) + if summary.count != 2 { + t.Fatalf("expected 2 entries, got %d", summary.count) + } + if summary.hash == "" { + t.Fatal("expected non-empty hash") + } + if empty := SummarizeAmpModelMappings(nil); empty.count != 0 || empty.hash != "" { + t.Fatalf("expected empty summary for nil input, got %+v", empty) + } + if blank := SummarizeAmpModelMappings([]config.AmpModelMapping{{From: " ", To: " "}}); blank.count != 0 || blank.hash != "" { + t.Fatalf("expected blank mappings ignored, got %+v", blank) + } +} + +func TestSummarizeOAuthExcludedModels_NormalizesKeys(t *testing.T) { + out := SummarizeOAuthExcludedModels(map[string][]string{ + "ProvA": {"X"}, + "": {"ignored"}, + }) + if len(out) != 1 { + t.Fatalf("expected only non-empty key summary, got %d", len(out)) + } + if _, ok := out["prova"]; !ok { + t.Fatalf("expected normalized key 'prova', got keys %v", out) + } + if out["prova"].count != 1 || out["prova"].hash == "" { + t.Fatalf("unexpected summary %+v", out["prova"]) + } + if outEmpty := SummarizeOAuthExcludedModels(nil); outEmpty != nil { + t.Fatalf("expected nil map for nil input, got %v", outEmpty) + } +} + +func TestSummarizeVertexModels(t *testing.T) { + summary := SummarizeVertexModels([]config.VertexCompatModel{ + {Name: "m1"}, + {Name: " ", Alias: "alias"}, + {}, // ignored + }) + if summary.count != 2 { + t.Fatalf("expected 2 vertex models, got %d", summary.count) + } + if summary.hash == "" { + t.Fatal("expected non-empty hash") + } + if empty := SummarizeVertexModels(nil); empty.count != 0 || empty.hash != "" { + t.Fatalf("expected empty summary for nil input, got %+v", empty) + } + if blank := SummarizeVertexModels([]config.VertexCompatModel{{Name: " "}}); blank.count != 0 || blank.hash != "" { + t.Fatalf("expected blank model ignored, got %+v", blank) + } +} + +func expectContains(t *testing.T, list []string, target string) { + t.Helper() + for _, entry := range list { + if entry == target { + return + } + } + t.Fatalf("expected list to contain %q, got %#v", target, list) +} diff --git a/internal/watcher/diff/openai_compat.go b/internal/watcher/diff/openai_compat.go new file mode 100644 index 00000000..dee47802 --- /dev/null +++ b/internal/watcher/diff/openai_compat.go @@ -0,0 +1,124 @@ +package diff + +import ( + "fmt" + "sort" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +// DiffOpenAICompatibility produces human-readable change descriptions. +func DiffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string { + changes := make([]string, 0) + oldMap := make(map[string]config.OpenAICompatibility, len(oldList)) + oldLabels := make(map[string]string, len(oldList)) + for idx, entry := range oldList { + key, label := openAICompatKey(entry, idx) + oldMap[key] = entry + oldLabels[key] = label + } + newMap := make(map[string]config.OpenAICompatibility, len(newList)) + newLabels := make(map[string]string, len(newList)) + for idx, entry := range newList { + key, label := openAICompatKey(entry, idx) + newMap[key] = entry + newLabels[key] = label + } + keySet := make(map[string]struct{}, len(oldMap)+len(newMap)) + for key := range oldMap { + keySet[key] = struct{}{} + } + for key := range newMap { + keySet[key] = struct{}{} + } + orderedKeys := make([]string, 0, len(keySet)) + for key := range keySet { + orderedKeys = append(orderedKeys, key) + } + sort.Strings(orderedKeys) + for _, key := range orderedKeys { + oldEntry, oldOk := oldMap[key] + newEntry, newOk := newMap[key] + label := oldLabels[key] + if label == "" { + label = newLabels[key] + } + switch { + case !oldOk: + changes = append(changes, fmt.Sprintf("provider added: %s (api-keys=%d, models=%d)", label, countAPIKeys(newEntry), countOpenAIModels(newEntry.Models))) + case !newOk: + changes = append(changes, fmt.Sprintf("provider removed: %s (api-keys=%d, models=%d)", label, countAPIKeys(oldEntry), countOpenAIModels(oldEntry.Models))) + default: + if detail := describeOpenAICompatibilityUpdate(oldEntry, newEntry); detail != "" { + changes = append(changes, fmt.Sprintf("provider updated: %s %s", label, detail)) + } + } + } + return changes +} + +func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibility) string { + oldKeyCount := countAPIKeys(oldEntry) + newKeyCount := countAPIKeys(newEntry) + oldModelCount := countOpenAIModels(oldEntry.Models) + newModelCount := countOpenAIModels(newEntry.Models) + details := make([]string, 0, 3) + if oldKeyCount != newKeyCount { + details = append(details, fmt.Sprintf("api-keys %d -> %d", oldKeyCount, newKeyCount)) + } + if oldModelCount != newModelCount { + details = append(details, fmt.Sprintf("models %d -> %d", oldModelCount, newModelCount)) + } + if !equalStringMap(oldEntry.Headers, newEntry.Headers) { + details = append(details, "headers updated") + } + if len(details) == 0 { + return "" + } + return "(" + strings.Join(details, ", ") + ")" +} + +func countAPIKeys(entry config.OpenAICompatibility) int { + count := 0 + for _, keyEntry := range entry.APIKeyEntries { + if strings.TrimSpace(keyEntry.APIKey) != "" { + count++ + } + } + return count +} + +func countOpenAIModels(models []config.OpenAICompatibilityModel) int { + count := 0 + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + count++ + } + return count +} + +func openAICompatKey(entry config.OpenAICompatibility, index int) (string, string) { + name := strings.TrimSpace(entry.Name) + if name != "" { + return "name:" + name, name + } + base := strings.TrimSpace(entry.BaseURL) + if base != "" { + return "base:" + base, base + } + for _, model := range entry.Models { + alias := strings.TrimSpace(model.Alias) + if alias == "" { + alias = strings.TrimSpace(model.Name) + } + if alias != "" { + return "alias:" + alias, alias + } + } + return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1) +} diff --git a/internal/watcher/diff/openai_compat_test.go b/internal/watcher/diff/openai_compat_test.go new file mode 100644 index 00000000..c0ec9090 --- /dev/null +++ b/internal/watcher/diff/openai_compat_test.go @@ -0,0 +1,113 @@ +package diff + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func TestDiffOpenAICompatibility(t *testing.T) { + oldList := []config.OpenAICompatibility{ + { + Name: "provider-a", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "key-a"}, + }, + Models: []config.OpenAICompatibilityModel{ + {Name: "m1"}, + }, + }, + } + newList := []config.OpenAICompatibility{ + { + Name: "provider-a", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "key-a"}, + {APIKey: "key-b"}, + }, + Models: []config.OpenAICompatibilityModel{ + {Name: "m1"}, + {Name: "m2"}, + }, + Headers: map[string]string{"X-Test": "1"}, + }, + { + Name: "provider-b", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-b"}}, + }, + } + + changes := DiffOpenAICompatibility(oldList, newList) + expectContains(t, changes, "provider added: provider-b (api-keys=1, models=0)") + expectContains(t, changes, "provider updated: provider-a (api-keys 1 -> 2, models 1 -> 2, headers updated)") +} + +func TestDiffOpenAICompatibility_RemovedAndUnchanged(t *testing.T) { + oldList := []config.OpenAICompatibility{ + { + Name: "provider-a", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}}, + Models: []config.OpenAICompatibilityModel{{Name: "m1"}}, + }, + } + newList := []config.OpenAICompatibility{ + { + Name: "provider-a", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}}, + Models: []config.OpenAICompatibilityModel{{Name: "m1"}}, + }, + } + if changes := DiffOpenAICompatibility(oldList, newList); len(changes) != 0 { + t.Fatalf("expected no changes, got %v", changes) + } + + newList = nil + changes := DiffOpenAICompatibility(oldList, newList) + expectContains(t, changes, "provider removed: provider-a (api-keys=1, models=1)") +} + +func TestOpenAICompatKeyFallbacks(t *testing.T) { + entry := config.OpenAICompatibility{ + BaseURL: "http://base", + Models: []config.OpenAICompatibilityModel{{Alias: "alias-only"}}, + } + key, label := openAICompatKey(entry, 0) + if key != "base:http://base" || label != "http://base" { + t.Fatalf("expected base key, got %s/%s", key, label) + } + + entry.BaseURL = "" + key, label = openAICompatKey(entry, 1) + if key != "alias:alias-only" || label != "alias-only" { + t.Fatalf("expected alias fallback, got %s/%s", key, label) + } + + entry.Models = nil + key, label = openAICompatKey(entry, 2) + if key != "index:2" || label != "entry-3" { + t.Fatalf("expected index fallback, got %s/%s", key, label) + } +} + +func TestCountOpenAIModelsSkipsBlanks(t *testing.T) { + models := []config.OpenAICompatibilityModel{ + {Name: "m1"}, + {Name: ""}, + {Alias: ""}, + {Name: " "}, + {Alias: "a1"}, + } + if got := countOpenAIModels(models); got != 2 { + t.Fatalf("expected 2 counted models, got %d", got) + } +} + +func TestOpenAICompatKeyUsesModelNameWhenAliasEmpty(t *testing.T) { + entry := config.OpenAICompatibility{ + Models: []config.OpenAICompatibilityModel{{Name: "model-name"}}, + } + key, label := openAICompatKey(entry, 5) + if key != "alias:model-name" || label != "model-name" { + t.Fatalf("expected model-name fallback, got %s/%s", key, label) + } +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 43a3a3dc..8f03bf5b 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -23,6 +23,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" + "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" "gopkg.in/yaml.v3" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" @@ -485,170 +486,6 @@ func normalizeAuth(a *coreauth.Auth) *coreauth.Auth { return clone } -// computeOpenAICompatModelsHash returns a stable hash for the compatibility models so that -// changes to the model list trigger auth updates during hot reload. -func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string { - if len(models) == 0 { - return "" - } - data, err := json.Marshal(models) - if err != nil || len(data) == 0 { - return "" - } - sum := sha256.Sum256(data) - return hex.EncodeToString(sum[:]) -} - -func computeVertexCompatModelsHash(models []config.VertexCompatModel) string { - if len(models) == 0 { - return "" - } - data, err := json.Marshal(models) - if err != nil || len(data) == 0 { - return "" - } - sum := sha256.Sum256(data) - return hex.EncodeToString(sum[:]) -} - -// computeClaudeModelsHash returns a stable hash for Claude model aliases. -func computeClaudeModelsHash(models []config.ClaudeModel) string { - if len(models) == 0 { - return "" - } - data, err := json.Marshal(models) - if err != nil || len(data) == 0 { - return "" - } - sum := sha256.Sum256(data) - return hex.EncodeToString(sum[:]) -} - -func computeExcludedModelsHash(excluded []string) string { - if len(excluded) == 0 { - return "" - } - normalized := make([]string, 0, len(excluded)) - for _, entry := range excluded { - if trimmed := strings.TrimSpace(entry); trimmed != "" { - normalized = append(normalized, strings.ToLower(trimmed)) - } - } - if len(normalized) == 0 { - return "" - } - sort.Strings(normalized) - data, err := json.Marshal(normalized) - if err != nil || len(data) == 0 { - return "" - } - sum := sha256.Sum256(data) - return hex.EncodeToString(sum[:]) -} - -type excludedModelsSummary struct { - hash string - count int -} - -func summarizeExcludedModels(list []string) excludedModelsSummary { - if len(list) == 0 { - return excludedModelsSummary{} - } - seen := make(map[string]struct{}, len(list)) - normalized := make([]string, 0, len(list)) - for _, entry := range list { - if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" { - if _, exists := seen[trimmed]; exists { - continue - } - seen[trimmed] = struct{}{} - normalized = append(normalized, trimmed) - } - } - sort.Strings(normalized) - return excludedModelsSummary{ - hash: computeExcludedModelsHash(normalized), - count: len(normalized), - } -} - -type ampModelMappingsSummary struct { - hash string - count int -} - -func summarizeAmpModelMappings(mappings []config.AmpModelMapping) ampModelMappingsSummary { - if len(mappings) == 0 { - return ampModelMappingsSummary{} - } - entries := make([]string, 0, len(mappings)) - for _, mapping := range mappings { - from := strings.TrimSpace(mapping.From) - to := strings.TrimSpace(mapping.To) - if from == "" && to == "" { - continue - } - entries = append(entries, from+"->"+to) - } - if len(entries) == 0 { - return ampModelMappingsSummary{} - } - sort.Strings(entries) - sum := sha256.Sum256([]byte(strings.Join(entries, "|"))) - return ampModelMappingsSummary{ - hash: hex.EncodeToString(sum[:]), - count: len(entries), - } -} - -func summarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary { - if len(entries) == 0 { - return nil - } - out := make(map[string]excludedModelsSummary, len(entries)) - for k, v := range entries { - key := strings.ToLower(strings.TrimSpace(k)) - if key == "" { - continue - } - out[key] = summarizeExcludedModels(v) - } - return out -} - -func diffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) { - oldSummary := summarizeOAuthExcludedModels(oldMap) - newSummary := summarizeOAuthExcludedModels(newMap) - keys := make(map[string]struct{}, len(oldSummary)+len(newSummary)) - for k := range oldSummary { - keys[k] = struct{}{} - } - for k := range newSummary { - keys[k] = struct{}{} - } - changes := make([]string, 0, len(keys)) - affected := make([]string, 0, len(keys)) - for key := range keys { - oldInfo, okOld := oldSummary[key] - newInfo, okNew := newSummary[key] - switch { - case okOld && !okNew: - changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key)) - affected = append(affected, key) - case !okOld && okNew: - changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count)) - affected = append(affected, key) - case okOld && okNew && oldInfo.hash != newInfo.hash: - changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count)) - affected = append(affected, key) - } - } - sort.Strings(changes) - sort.Strings(affected) - return changes, affected -} - func applyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) { if auth == nil || cfg == nil { return @@ -677,7 +514,7 @@ func applyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey combined = append(combined, k) } sort.Strings(combined) - hash := computeExcludedModelsHash(combined) + hash := diff.ComputeExcludedModelsHash(combined) if auth.Attributes == nil { auth.Attributes = make(map[string]string) } @@ -924,7 +761,7 @@ func (w *Watcher) reloadConfig() bool { var affectedOAuthProviders []string if oldConfig != nil { - _, affectedOAuthProviders = diffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels) + _, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels) } // Always apply the current log level based on the latest config. @@ -937,7 +774,7 @@ func (w *Watcher) reloadConfig() bool { // Log configuration changes in debug mode, only when there are material diffs if oldConfig != nil { - details := buildConfigChangeDetails(oldConfig, newConfig) + details := diff.BuildConfigChangeDetails(oldConfig, newConfig) if len(details) > 0 { log.Debugf("config changes detected:") for _, d := range details { @@ -1188,7 +1025,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if base != "" { attrs["base_url"] = base } - if hash := computeClaudeModelsHash(ck.Models); hash != "" { + if hash := diff.ComputeClaudeModelsHash(ck.Models); hash != "" { attrs["models_hash"] = hash } addConfigHeadersToAttrs(ck.Headers, attrs) @@ -1261,7 +1098,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if key != "" { attrs["api_key"] = key } - if hash := computeOpenAICompatModelsHash(compat.Models); hash != "" { + if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" { attrs["models_hash"] = hash } addConfigHeadersToAttrs(compat.Headers, attrs) @@ -1287,7 +1124,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { "compat_name": compat.Name, "provider_key": providerName, } - if hash := computeOpenAICompatModelsHash(compat.Models); hash != "" { + if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" { attrs["models_hash"] = hash } addConfigHeadersToAttrs(compat.Headers, attrs) @@ -1323,7 +1160,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if key != "" { attrs["api_key"] = key } - if hash := computeVertexCompatModelsHash(compat.Models); hash != "" { + if hash := diff.ComputeVertexCompatModelsHash(compat.Models); hash != "" { attrs["models_hash"] = hash } addConfigHeadersToAttrs(compat.Headers, attrs) @@ -1586,329 +1423,6 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) { return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount } -func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string { - changes := make([]string, 0) - oldMap := make(map[string]config.OpenAICompatibility, len(oldList)) - oldLabels := make(map[string]string, len(oldList)) - for idx, entry := range oldList { - key, label := openAICompatKey(entry, idx) - oldMap[key] = entry - oldLabels[key] = label - } - newMap := make(map[string]config.OpenAICompatibility, len(newList)) - newLabels := make(map[string]string, len(newList)) - for idx, entry := range newList { - key, label := openAICompatKey(entry, idx) - newMap[key] = entry - newLabels[key] = label - } - keySet := make(map[string]struct{}, len(oldMap)+len(newMap)) - for key := range oldMap { - keySet[key] = struct{}{} - } - for key := range newMap { - keySet[key] = struct{}{} - } - orderedKeys := make([]string, 0, len(keySet)) - for key := range keySet { - orderedKeys = append(orderedKeys, key) - } - sort.Strings(orderedKeys) - for _, key := range orderedKeys { - oldEntry, oldOk := oldMap[key] - newEntry, newOk := newMap[key] - label := oldLabels[key] - if label == "" { - label = newLabels[key] - } - switch { - case !oldOk: - changes = append(changes, fmt.Sprintf("provider added: %s (api-keys=%d, models=%d)", label, countAPIKeys(newEntry), countOpenAIModels(newEntry.Models))) - case !newOk: - changes = append(changes, fmt.Sprintf("provider removed: %s (api-keys=%d, models=%d)", label, countAPIKeys(oldEntry), countOpenAIModels(oldEntry.Models))) - default: - if detail := describeOpenAICompatibilityUpdate(oldEntry, newEntry); detail != "" { - changes = append(changes, fmt.Sprintf("provider updated: %s %s", label, detail)) - } - } - } - return changes -} - -func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibility) string { - oldKeyCount := countAPIKeys(oldEntry) - newKeyCount := countAPIKeys(newEntry) - oldModelCount := countOpenAIModels(oldEntry.Models) - newModelCount := countOpenAIModels(newEntry.Models) - details := make([]string, 0, 3) - if oldKeyCount != newKeyCount { - details = append(details, fmt.Sprintf("api-keys %d -> %d", oldKeyCount, newKeyCount)) - } - if oldModelCount != newModelCount { - details = append(details, fmt.Sprintf("models %d -> %d", oldModelCount, newModelCount)) - } - if !equalStringMap(oldEntry.Headers, newEntry.Headers) { - details = append(details, "headers updated") - } - if len(details) == 0 { - return "" - } - return "(" + strings.Join(details, ", ") + ")" -} - -func countAPIKeys(entry config.OpenAICompatibility) int { - count := 0 - for _, keyEntry := range entry.APIKeyEntries { - if strings.TrimSpace(keyEntry.APIKey) != "" { - count++ - } - } - return count -} - -func countOpenAIModels(models []config.OpenAICompatibilityModel) int { - count := 0 - for _, model := range models { - name := strings.TrimSpace(model.Name) - alias := strings.TrimSpace(model.Alias) - if name == "" && alias == "" { - continue - } - count++ - } - return count -} - -func openAICompatKey(entry config.OpenAICompatibility, index int) (string, string) { - name := strings.TrimSpace(entry.Name) - if name != "" { - return "name:" + name, name - } - base := strings.TrimSpace(entry.BaseURL) - if base != "" { - return "base:" + base, base - } - for _, model := range entry.Models { - alias := strings.TrimSpace(model.Alias) - if alias == "" { - alias = strings.TrimSpace(model.Name) - } - if alias != "" { - return "alias:" + alias, alias - } - } - return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1) -} - -// buildConfigChangeDetails computes a redacted, human-readable list of config changes. -// It avoids printing secrets (like API keys) and focuses on structural or non-sensitive fields. -func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { - changes := make([]string, 0, 16) - if oldCfg == nil || newCfg == nil { - return changes - } - - // Simple scalars - if oldCfg.Port != newCfg.Port { - changes = append(changes, fmt.Sprintf("port: %d -> %d", oldCfg.Port, newCfg.Port)) - } - if oldCfg.AuthDir != newCfg.AuthDir { - changes = append(changes, fmt.Sprintf("auth-dir: %s -> %s", oldCfg.AuthDir, newCfg.AuthDir)) - } - if oldCfg.Debug != newCfg.Debug { - changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug)) - } - if oldCfg.LoggingToFile != newCfg.LoggingToFile { - changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile)) - } - if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled { - changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled)) - } - if oldCfg.DisableCooling != newCfg.DisableCooling { - changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling)) - } - if oldCfg.RequestLog != newCfg.RequestLog { - changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog)) - } - if oldCfg.RequestRetry != newCfg.RequestRetry { - changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry)) - } - if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval { - changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval)) - } - if oldCfg.ProxyURL != newCfg.ProxyURL { - changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL)) - } - if oldCfg.WebsocketAuth != newCfg.WebsocketAuth { - changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth)) - } - - // Quota-exceeded behavior - if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject { - changes = append(changes, fmt.Sprintf("quota-exceeded.switch-project: %t -> %t", oldCfg.QuotaExceeded.SwitchProject, newCfg.QuotaExceeded.SwitchProject)) - } - if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel { - changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel)) - } - - // API keys (redacted) and counts - if len(oldCfg.APIKeys) != len(newCfg.APIKeys) { - changes = append(changes, fmt.Sprintf("api-keys count: %d -> %d", len(oldCfg.APIKeys), len(newCfg.APIKeys))) - } else if !reflect.DeepEqual(trimStrings(oldCfg.APIKeys), trimStrings(newCfg.APIKeys)) { - changes = append(changes, "api-keys: values updated (count unchanged, redacted)") - } - if len(oldCfg.GeminiKey) != len(newCfg.GeminiKey) { - changes = append(changes, fmt.Sprintf("gemini-api-key count: %d -> %d", len(oldCfg.GeminiKey), len(newCfg.GeminiKey))) - } else { - for i := range oldCfg.GeminiKey { - if i >= len(newCfg.GeminiKey) { - break - } - o := oldCfg.GeminiKey[i] - n := newCfg.GeminiKey[i] - if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) { - changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) - } - if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { - changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) - } - if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { - changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i)) - } - if !equalStringMap(o.Headers, n.Headers) { - changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i)) - } - oldExcluded := summarizeExcludedModels(o.ExcludedModels) - newExcluded := summarizeExcludedModels(n.ExcludedModels) - if oldExcluded.hash != newExcluded.hash { - changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) - } - } - } - - // Claude keys (do not print key material) - if len(oldCfg.ClaudeKey) != len(newCfg.ClaudeKey) { - changes = append(changes, fmt.Sprintf("claude-api-key count: %d -> %d", len(oldCfg.ClaudeKey), len(newCfg.ClaudeKey))) - } else { - for i := range oldCfg.ClaudeKey { - if i >= len(newCfg.ClaudeKey) { - break - } - o := oldCfg.ClaudeKey[i] - n := newCfg.ClaudeKey[i] - if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) { - changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) - } - if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { - changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) - } - if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { - changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i)) - } - if !equalStringMap(o.Headers, n.Headers) { - changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i)) - } - oldExcluded := summarizeExcludedModels(o.ExcludedModels) - newExcluded := summarizeExcludedModels(n.ExcludedModels) - if oldExcluded.hash != newExcluded.hash { - changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) - } - } - } - - // Codex keys (do not print key material) - if len(oldCfg.CodexKey) != len(newCfg.CodexKey) { - changes = append(changes, fmt.Sprintf("codex-api-key count: %d -> %d", len(oldCfg.CodexKey), len(newCfg.CodexKey))) - } else { - for i := range oldCfg.CodexKey { - if i >= len(newCfg.CodexKey) { - break - } - o := oldCfg.CodexKey[i] - n := newCfg.CodexKey[i] - if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) { - changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) - } - if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { - changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) - } - if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { - changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i)) - } - if !equalStringMap(o.Headers, n.Headers) { - changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i)) - } - oldExcluded := summarizeExcludedModels(o.ExcludedModels) - newExcluded := summarizeExcludedModels(n.ExcludedModels) - if oldExcluded.hash != newExcluded.hash { - changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) - } - } - } - - // AmpCode settings (redacted where needed) - oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL) - newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL) - if oldAmpURL != newAmpURL { - changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL)) - } - oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey) - newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey) - switch { - case oldAmpKey == "" && newAmpKey != "": - changes = append(changes, "ampcode.upstream-api-key: added") - case oldAmpKey != "" && newAmpKey == "": - changes = append(changes, "ampcode.upstream-api-key: removed") - case oldAmpKey != newAmpKey: - changes = append(changes, "ampcode.upstream-api-key: updated") - } - if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost { - changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost)) - } - oldMappings := summarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings) - newMappings := summarizeAmpModelMappings(newCfg.AmpCode.ModelMappings) - if oldMappings.hash != newMappings.hash { - changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count)) - } - - if entries, _ := diffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { - changes = append(changes, entries...) - } - - // Remote management (never print the key) - if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote { - changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote)) - } - if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel { - changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel)) - } - oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository) - newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository) - if oldPanelRepo != newPanelRepo { - changes = append(changes, fmt.Sprintf("remote-management.panel-github-repository: %s -> %s", oldPanelRepo, newPanelRepo)) - } - if oldCfg.RemoteManagement.SecretKey != newCfg.RemoteManagement.SecretKey { - switch { - case oldCfg.RemoteManagement.SecretKey == "" && newCfg.RemoteManagement.SecretKey != "": - changes = append(changes, "remote-management.secret-key: created") - case oldCfg.RemoteManagement.SecretKey != "" && newCfg.RemoteManagement.SecretKey == "": - changes = append(changes, "remote-management.secret-key: deleted") - default: - changes = append(changes, "remote-management.secret-key: updated") - } - } - - // OpenAI compatibility providers (summarized) - if compat := diffOpenAICompatibility(oldCfg.OpenAICompatibility, newCfg.OpenAICompatibility); len(compat) > 0 { - changes = append(changes, "openai-compatibility:") - for _, c := range compat { - changes = append(changes, " "+c) - } - } - - return changes -} - func addConfigHeadersToAttrs(headers map[string]string, attrs map[string]string) { if len(headers) == 0 || attrs == nil { return @@ -1922,23 +1436,3 @@ func addConfigHeadersToAttrs(headers map[string]string, attrs map[string]string) attrs["header:"+key] = val } } - -func trimStrings(in []string) []string { - out := make([]string, len(in)) - for i := range in { - out[i] = strings.TrimSpace(in[i]) - } - return out -} - -func equalStringMap(a, b map[string]string) bool { - if len(a) != len(b) { - return false - } - for k, v := range a { - if b[k] != v { - return false - } - } - return true -} diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go new file mode 100644 index 00000000..71fb1852 --- /dev/null +++ b/internal/watcher/watcher_test.go @@ -0,0 +1,609 @@ +package watcher + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "github.com/fsnotify/fsnotify" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "gopkg.in/yaml.v3" +) + +func TestApplyAuthExcludedModelsMeta_APIKey(t *testing.T) { + auth := &coreauth.Auth{Attributes: map[string]string{}} + cfg := &config.Config{} + perKey := []string{" Model-1 ", "model-2"} + + applyAuthExcludedModelsMeta(auth, cfg, perKey, "apikey") + + expected := diff.ComputeExcludedModelsHash([]string{"model-1", "model-2"}) + if got := auth.Attributes["excluded_models_hash"]; got != expected { + t.Fatalf("expected hash %s, got %s", expected, got) + } + if got := auth.Attributes["auth_kind"]; got != "apikey" { + t.Fatalf("expected auth_kind=apikey, got %s", got) + } +} + +func TestApplyAuthExcludedModelsMeta_OAuthProvider(t *testing.T) { + auth := &coreauth.Auth{ + Provider: "TestProv", + Attributes: map[string]string{}, + } + cfg := &config.Config{ + OAuthExcludedModels: map[string][]string{ + "testprov": {"A", "b"}, + }, + } + + applyAuthExcludedModelsMeta(auth, cfg, nil, "oauth") + + expected := diff.ComputeExcludedModelsHash([]string{"a", "b"}) + if got := auth.Attributes["excluded_models_hash"]; got != expected { + t.Fatalf("expected hash %s, got %s", expected, got) + } + if got := auth.Attributes["auth_kind"]; got != "oauth" { + t.Fatalf("expected auth_kind=oauth, got %s", got) + } +} + +func TestBuildAPIKeyClientsCounts(t *testing.T) { + cfg := &config.Config{ + GeminiKey: []config.GeminiKey{{APIKey: "g1"}, {APIKey: "g2"}}, + VertexCompatAPIKey: []config.VertexCompatKey{ + {APIKey: "v1"}, + }, + ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}}, + CodexKey: []config.CodexKey{{APIKey: "x1"}, {APIKey: "x2"}}, + OpenAICompatibility: []config.OpenAICompatibility{ + {APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "o1"}, {APIKey: "o2"}}}, + }, + } + + gemini, vertex, claude, codex, compat := BuildAPIKeyClients(cfg) + if gemini != 2 || vertex != 1 || claude != 1 || codex != 2 || compat != 2 { + t.Fatalf("unexpected counts: %d %d %d %d %d", gemini, vertex, claude, codex, compat) + } +} + +func TestNormalizeAuthStripsTemporalFields(t *testing.T) { + now := time.Now() + auth := &coreauth.Auth{ + CreatedAt: now, + UpdatedAt: now, + LastRefreshedAt: now, + NextRefreshAfter: now, + Quota: coreauth.QuotaState{ + NextRecoverAt: now, + }, + Runtime: map[string]any{"k": "v"}, + } + + normalized := normalizeAuth(auth) + if !normalized.CreatedAt.IsZero() || !normalized.UpdatedAt.IsZero() || !normalized.LastRefreshedAt.IsZero() || !normalized.NextRefreshAfter.IsZero() { + t.Fatal("expected time fields to be zeroed") + } + if normalized.Runtime != nil { + t.Fatal("expected runtime to be nil") + } + if !normalized.Quota.NextRecoverAt.IsZero() { + t.Fatal("expected quota.NextRecoverAt to be zeroed") + } +} + +func TestMatchProvider(t *testing.T) { + if _, ok := matchProvider("OpenAI", []string{"openai", "claude"}); !ok { + t.Fatal("expected match to succeed ignoring case") + } + if _, ok := matchProvider("missing", []string{"openai"}); ok { + t.Fatal("expected match to fail for unknown provider") + } +} + +func TestSnapshotCoreAuths_ConfigAndAuthFiles(t *testing.T) { + authDir := t.TempDir() + metadata := map[string]any{ + "type": "gemini", + "email": "user@example.com", + "project_id": "proj-a, proj-b", + "proxy_url": "https://proxy", + } + authFile := filepath.Join(authDir, "gemini.json") + data, err := json.Marshal(metadata) + if err != nil { + t.Fatalf("failed to marshal metadata: %v", err) + } + if err = os.WriteFile(authFile, data, 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + + cfg := &config.Config{ + AuthDir: authDir, + GeminiKey: []config.GeminiKey{ + { + APIKey: "g-key", + BaseURL: "https://gemini", + ExcludedModels: []string{"Model-A", "model-b"}, + Headers: map[string]string{"X-Req": "1"}, + }, + }, + OAuthExcludedModels: map[string][]string{ + "gemini-cli": {"Foo", "bar"}, + }, + } + + w := &Watcher{authDir: authDir} + w.SetConfig(cfg) + + auths := w.SnapshotCoreAuths() + if len(auths) != 4 { + t.Fatalf("expected 4 auth entries (1 config + 1 primary + 2 virtual), got %d", len(auths)) + } + + var geminiAPIKeyAuth *coreauth.Auth + var geminiPrimary *coreauth.Auth + virtuals := make([]*coreauth.Auth, 0) + for _, a := range auths { + switch { + case a.Provider == "gemini" && a.Attributes["api_key"] == "g-key": + geminiAPIKeyAuth = a + case a.Attributes["gemini_virtual_primary"] == "true": + geminiPrimary = a + case strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) != "": + virtuals = append(virtuals, a) + } + } + if geminiAPIKeyAuth == nil { + t.Fatal("expected synthesized Gemini API key auth") + } + expectedAPIKeyHash := diff.ComputeExcludedModelsHash([]string{"Model-A", "model-b"}) + if geminiAPIKeyAuth.Attributes["excluded_models_hash"] != expectedAPIKeyHash { + t.Fatalf("expected API key excluded hash %s, got %s", expectedAPIKeyHash, geminiAPIKeyAuth.Attributes["excluded_models_hash"]) + } + if geminiAPIKeyAuth.Attributes["auth_kind"] != "apikey" { + t.Fatalf("expected auth_kind=apikey, got %s", geminiAPIKeyAuth.Attributes["auth_kind"]) + } + + if geminiPrimary == nil { + t.Fatal("expected primary gemini-cli auth from file") + } + if !geminiPrimary.Disabled || geminiPrimary.Status != coreauth.StatusDisabled { + t.Fatal("expected primary gemini-cli auth to be disabled when virtual auths are synthesized") + } + expectedOAuthHash := diff.ComputeExcludedModelsHash([]string{"Foo", "bar"}) + if geminiPrimary.Attributes["excluded_models_hash"] != expectedOAuthHash { + t.Fatalf("expected OAuth excluded hash %s, got %s", expectedOAuthHash, geminiPrimary.Attributes["excluded_models_hash"]) + } + if geminiPrimary.Attributes["auth_kind"] != "oauth" { + t.Fatalf("expected auth_kind=oauth, got %s", geminiPrimary.Attributes["auth_kind"]) + } + + if len(virtuals) != 2 { + t.Fatalf("expected 2 virtual auths, got %d", len(virtuals)) + } + for _, v := range virtuals { + if v.Attributes["gemini_virtual_parent"] != geminiPrimary.ID { + t.Fatalf("virtual auth missing parent link to %s", geminiPrimary.ID) + } + if v.Attributes["excluded_models_hash"] != expectedOAuthHash { + t.Fatalf("expected virtual excluded hash %s, got %s", expectedOAuthHash, v.Attributes["excluded_models_hash"]) + } + if v.Status != coreauth.StatusActive { + t.Fatalf("expected virtual auth to be active, got %s", v.Status) + } + } +} + +func TestReloadConfigIfChanged_TriggersOnChangeAndSkipsUnchanged(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + + configPath := filepath.Join(tmpDir, "config.yaml") + writeConfig := func(port int, allowRemote bool) { + cfg := &config.Config{ + Port: port, + AuthDir: authDir, + RemoteManagement: config.RemoteManagement{ + AllowRemote: allowRemote, + }, + } + data, err := yaml.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + if err = os.WriteFile(configPath, data, 0o644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + } + + writeConfig(8080, false) + + reloads := 0 + w := &Watcher{ + configPath: configPath, + authDir: authDir, + reloadCallback: func(*config.Config) { reloads++ }, + } + + w.reloadConfigIfChanged() + if reloads != 1 { + t.Fatalf("expected first reload to trigger callback once, got %d", reloads) + } + + // Same content should be skipped by hash check. + w.reloadConfigIfChanged() + if reloads != 1 { + t.Fatalf("expected unchanged config to be skipped, callback count %d", reloads) + } + + writeConfig(9090, true) + w.reloadConfigIfChanged() + if reloads != 2 { + t.Fatalf("expected changed config to trigger reload, callback count %d", reloads) + } + w.clientsMutex.RLock() + defer w.clientsMutex.RUnlock() + if w.config == nil || w.config.Port != 9090 || !w.config.RemoteManagement.AllowRemote { + t.Fatalf("expected config to be updated after reload, got %+v", w.config) + } +} + +func TestStartAndStopSuccess(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir), 0o644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + + var reloads int32 + w, err := NewWatcher(configPath, authDir, func(*config.Config) { + atomic.AddInt32(&reloads, 1) + }) + if err != nil { + t.Fatalf("failed to create watcher: %v", err) + } + w.SetConfig(&config.Config{AuthDir: authDir}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := w.Start(ctx); err != nil { + t.Fatalf("expected Start to succeed: %v", err) + } + cancel() + if err := w.Stop(); err != nil { + t.Fatalf("expected Stop to succeed: %v", err) + } + if got := atomic.LoadInt32(&reloads); got != 1 { + t.Fatalf("expected one reload callback, got %d", got) + } +} + +func TestStartFailsWhenConfigMissing(t *testing.T) { + tmpDir := t.TempDir() + authDir := filepath.Join(tmpDir, "auth") + if err := os.MkdirAll(authDir, 0o755); err != nil { + t.Fatalf("failed to create auth dir: %v", err) + } + configPath := filepath.Join(tmpDir, "missing-config.yaml") + + w, err := NewWatcher(configPath, authDir, nil) + if err != nil { + t.Fatalf("failed to create watcher: %v", err) + } + defer w.Stop() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := w.Start(ctx); err == nil { + t.Fatal("expected Start to fail for missing config file") + } +} + +func TestDispatchRuntimeAuthUpdateEnqueuesAndUpdatesState(t *testing.T) { + queue := make(chan AuthUpdate, 4) + w := &Watcher{} + w.SetAuthUpdateQueue(queue) + defer w.stopDispatch() + + auth := &coreauth.Auth{ID: "auth-1", Provider: "test"} + if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionAdd, Auth: auth}); !ok { + t.Fatal("expected DispatchRuntimeAuthUpdate to enqueue") + } + + select { + case update := <-queue: + if update.Action != AuthUpdateActionAdd || update.Auth.ID != "auth-1" { + t.Fatalf("unexpected update: %+v", update) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for auth update") + } + + if ok := w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionDelete, ID: "auth-1"}); !ok { + t.Fatal("expected delete update to enqueue") + } + select { + case update := <-queue: + if update.Action != AuthUpdateActionDelete || update.ID != "auth-1" { + t.Fatalf("unexpected delete update: %+v", update) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for delete update") + } + w.clientsMutex.RLock() + if _, exists := w.runtimeAuths["auth-1"]; exists { + w.clientsMutex.RUnlock() + t.Fatal("expected runtime auth to be cleared after delete") + } + w.clientsMutex.RUnlock() +} + +func TestAddOrUpdateClientSkipsUnchanged(t *testing.T) { + tmpDir := t.TempDir() + authFile := filepath.Join(tmpDir, "sample.json") + if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil { + t.Fatalf("failed to create auth file: %v", err) + } + data, _ := os.ReadFile(authFile) + sum := sha256.Sum256(data) + + var reloads int32 + w := &Watcher{ + authDir: tmpDir, + lastAuthHashes: map[string]string{ + filepath.Clean(authFile): hexString(sum[:]), + }, + reloadCallback: func(*config.Config) { + atomic.AddInt32(&reloads, 1) + }, + } + w.SetConfig(&config.Config{AuthDir: tmpDir}) + + w.addOrUpdateClient(authFile) + if got := atomic.LoadInt32(&reloads); got != 0 { + t.Fatalf("expected no reload for unchanged file, got %d", got) + } +} + +func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) { + tmpDir := t.TempDir() + authFile := filepath.Join(tmpDir, "sample.json") + if err := os.WriteFile(authFile, []byte(`{"type":"demo","api_key":"k"}`), 0o644); err != nil { + t.Fatalf("failed to create auth file: %v", err) + } + + var reloads int32 + w := &Watcher{ + authDir: tmpDir, + lastAuthHashes: make(map[string]string), + reloadCallback: func(*config.Config) { + atomic.AddInt32(&reloads, 1) + }, + } + w.SetConfig(&config.Config{AuthDir: tmpDir}) + + w.addOrUpdateClient(authFile) + + if got := atomic.LoadInt32(&reloads); got != 1 { + t.Fatalf("expected reload callback once, got %d", got) + } + normalized := filepath.Clean(authFile) + if _, ok := w.lastAuthHashes[normalized]; !ok { + t.Fatalf("expected hash to be stored for %s", normalized) + } +} + +func TestRemoveClientRemovesHash(t *testing.T) { + tmpDir := t.TempDir() + authFile := filepath.Join(tmpDir, "sample.json") + var reloads int32 + + w := &Watcher{ + authDir: tmpDir, + lastAuthHashes: map[string]string{ + filepath.Clean(authFile): "hash", + }, + reloadCallback: func(*config.Config) { + atomic.AddInt32(&reloads, 1) + }, + } + w.SetConfig(&config.Config{AuthDir: tmpDir}) + + w.removeClient(authFile) + if _, ok := w.lastAuthHashes[filepath.Clean(authFile)]; ok { + t.Fatal("expected hash to be removed after deletion") + } + if got := atomic.LoadInt32(&reloads); got != 1 { + t.Fatalf("expected reload callback once, got %d", got) + } +} + +func TestShouldDebounceRemove(t *testing.T) { + w := &Watcher{} + path := filepath.Clean("test.json") + + if w.shouldDebounceRemove(path, time.Now()) { + t.Fatal("first call should not debounce") + } + if !w.shouldDebounceRemove(path, time.Now()) { + t.Fatal("second call within window should debounce") + } + + w.clientsMutex.Lock() + w.lastRemoveTimes = map[string]time.Time{path: time.Now().Add(-2 * authRemoveDebounceWindow)} + w.clientsMutex.Unlock() + + if w.shouldDebounceRemove(path, time.Now()) { + t.Fatal("call after window should not debounce") + } +} + +func TestAuthFileUnchangedUsesHash(t *testing.T) { + tmpDir := t.TempDir() + authFile := filepath.Join(tmpDir, "sample.json") + content := []byte(`{"type":"demo"}`) + if err := os.WriteFile(authFile, content, 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + + w := &Watcher{lastAuthHashes: make(map[string]string)} + unchanged, err := w.authFileUnchanged(authFile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if unchanged { + t.Fatal("expected first check to report changed") + } + + sum := sha256.Sum256(content) + w.lastAuthHashes[filepath.Clean(authFile)] = hexString(sum[:]) + + unchanged, err = w.authFileUnchanged(authFile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !unchanged { + t.Fatal("expected hash match to report unchanged") + } +} + +func TestReloadClientsCachesAuthHashes(t *testing.T) { + tmpDir := t.TempDir() + authFile := filepath.Join(tmpDir, "one.json") + if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + w := &Watcher{ + authDir: tmpDir, + config: &config.Config{AuthDir: tmpDir}, + } + + w.reloadClients(true, nil) + + w.clientsMutex.RLock() + defer w.clientsMutex.RUnlock() + if len(w.lastAuthHashes) != 1 { + t.Fatalf("expected hash cache for one auth file, got %d", len(w.lastAuthHashes)) + } +} + +func TestReloadClientsLogsConfigDiffs(t *testing.T) { + tmpDir := t.TempDir() + oldCfg := &config.Config{AuthDir: tmpDir, Port: 1, Debug: false} + newCfg := &config.Config{AuthDir: tmpDir, Port: 2, Debug: true} + + w := &Watcher{ + authDir: tmpDir, + config: oldCfg, + } + w.SetConfig(oldCfg) + w.oldConfigYaml, _ = yaml.Marshal(oldCfg) + + w.clientsMutex.Lock() + w.config = newCfg + w.clientsMutex.Unlock() + + w.reloadClients(false, nil) +} + +func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) { + w := &Watcher{} + queue := make(chan AuthUpdate, 1) + w.SetAuthUpdateQueue(queue) + if w.dispatchCond == nil || w.dispatchCancel == nil { + t.Fatal("expected dispatch to be initialized") + } + w.SetAuthUpdateQueue(nil) + if w.dispatchCancel != nil { + t.Fatal("expected dispatch cancel to be cleared when queue nil") + } +} + +func TestStopConfigReloadTimerSafeWhenNil(t *testing.T) { + w := &Watcher{} + w.stopConfigReloadTimer() + w.configReloadMu.Lock() + w.configReloadTimer = time.AfterFunc(10*time.Millisecond, func() {}) + w.configReloadMu.Unlock() + time.Sleep(1 * time.Millisecond) + w.stopConfigReloadTimer() +} + +func TestHandleEventRemovesAuthFile(t *testing.T) { + tmpDir := t.TempDir() + authFile := filepath.Join(tmpDir, "remove.json") + if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + if err := os.Remove(authFile); err != nil { + t.Fatalf("failed to remove auth file pre-check: %v", err) + } + + var reloads int32 + w := &Watcher{ + authDir: tmpDir, + config: &config.Config{AuthDir: tmpDir}, + lastAuthHashes: map[string]string{ + filepath.Clean(authFile): "hash", + }, + reloadCallback: func(*config.Config) { + atomic.AddInt32(&reloads, 1) + }, + } + w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove}) + + if atomic.LoadInt32(&reloads) != 1 { + t.Fatalf("expected reload callback once, got %d", reloads) + } + if _, ok := w.lastAuthHashes[filepath.Clean(authFile)]; ok { + t.Fatal("expected hash entry to be removed") + } +} + +func TestDispatchAuthUpdatesFlushesQueue(t *testing.T) { + queue := make(chan AuthUpdate, 4) + w := &Watcher{} + w.SetAuthUpdateQueue(queue) + defer w.stopDispatch() + + w.dispatchAuthUpdates([]AuthUpdate{ + {Action: AuthUpdateActionAdd, ID: "a"}, + {Action: AuthUpdateActionModify, ID: "b"}, + }) + + got := make([]AuthUpdate, 0, 2) + for i := 0; i < 2; i++ { + select { + case u := <-queue: + got = append(got, u) + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for update %d", i) + } + } + if len(got) != 2 || got[0].ID != "a" || got[1].ID != "b" { + t.Fatalf("unexpected updates order/content: %+v", got) + } +} + +func hexString(data []byte) string { + return strings.ToLower(fmt.Sprintf("%x", data)) +}