From b0c5d9640aee3bba741d95334a8a5b41024ff052 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Tue, 16 Dec 2025 22:39:19 +0800 Subject: [PATCH] refactor(diff): improve security and stability of config change detection Introduce formatProxyURL helper to sanitize proxy addresses before logging, stripping credentials and path components while preserving host information. Rework model hash computation to sort and deduplicate name/alias pairs with case normalization, ensuring consistent output regardless of input ordering. Add signature-based identification for anonymous OpenAI-compatible provider entries to maintain stable keys across configuration reloads. Replace direct stdout prints with structured logger calls for file change notifications. --- internal/watcher/diff/config_diff.go | 39 +++++++++-- internal/watcher/diff/config_diff_test.go | 24 +++++++ internal/watcher/diff/model_hash.go | 76 ++++++++++++++++----- internal/watcher/diff/model_hash_test.go | 55 +++++++++++++++ internal/watcher/diff/openai_compat.go | 61 ++++++++++++++++- internal/watcher/diff/openai_compat_test.go | 74 ++++++++++++++++++++ internal/watcher/watcher.go | 8 +-- internal/watcher/watcher_test.go | 2 +- 8 files changed, 310 insertions(+), 29 deletions(-) diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index 092001fd..2722b94d 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -2,6 +2,7 @@ package diff import ( "fmt" + "net/url" "reflect" "strings" @@ -45,7 +46,7 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval)) } if oldCfg.ProxyURL != newCfg.ProxyURL { - changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL)) + changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", formatProxyURL(oldCfg.ProxyURL), formatProxyURL(newCfg.ProxyURL))) } if oldCfg.WebsocketAuth != newCfg.WebsocketAuth { changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth)) @@ -75,7 +76,7 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) } if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { - changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) + changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL))) } if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i)) @@ -102,7 +103,7 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) } if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { - changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) + changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL))) } if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i)) @@ -129,7 +130,7 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) } if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { - changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) + changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL))) } if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i)) @@ -219,7 +220,7 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { changes = append(changes, fmt.Sprintf("vertex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL))) } if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) { - changes = append(changes, fmt.Sprintf("vertex[%d].proxy-url: %s -> %s", i, strings.TrimSpace(o.ProxyURL), strings.TrimSpace(n.ProxyURL))) + changes = append(changes, fmt.Sprintf("vertex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL))) } if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { changes = append(changes, fmt.Sprintf("vertex[%d].api-key: updated", i)) @@ -257,3 +258,31 @@ func equalStringMap(a, b map[string]string) bool { } return true } + +func formatProxyURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil { + return "" + } + host := strings.TrimSpace(parsed.Host) + scheme := strings.TrimSpace(parsed.Scheme) + if host == "" { + // Allow host:port style without scheme. + parsed2, err2 := url.Parse("http://" + trimmed) + if err2 == nil { + host = strings.TrimSpace(parsed2.Host) + } + scheme = "" + } + if host == "" { + return "" + } + if scheme == "" { + return host + } + return scheme + "://" + host +} diff --git a/internal/watcher/diff/config_diff_test.go b/internal/watcher/diff/config_diff_test.go index f952b695..fab762f5 100644 --- a/internal/watcher/diff/config_diff_test.go +++ b/internal/watcher/diff/config_diff_test.go @@ -416,6 +416,30 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) { expectContains(t, changes, "openai-compatibility:") } +func TestFormatProxyURL(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "empty", in: "", want: ""}, + {name: "invalid", in: "http://[::1", want: ""}, + {name: "fullURLRedactsUserinfoAndPath", in: "http://user:pass@example.com:8080/path?x=1#frag", want: "http://example.com:8080"}, + {name: "socks5RedactsUserinfoAndPath", in: "socks5://user:pass@192.168.1.1:1080/path?x=1", want: "socks5://192.168.1.1:1080"}, + {name: "socks5HostPort", in: "socks5://proxy.example.com:1080/", want: "socks5://proxy.example.com:1080"}, + {name: "hostPortNoScheme", in: "example.com:1234/path?x=1", want: "example.com:1234"}, + {name: "relativePathRedacted", in: "/just/path", want: ""}, + {name: "schemeAndHost", in: "https://example.com", want: "https://example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatProxyURL(tt.in); got != tt.want { + t.Fatalf("expected %q, got %q", tt.want, got) + } + }) + } +} + func TestBuildConfigChangeDetails_SecretAndUpstreamUpdates(t *testing.T) { oldCfg := &config.Config{ AmpCode: config.AmpCode{ diff --git a/internal/watcher/diff/model_hash.go b/internal/watcher/diff/model_hash.go index 796b09cf..a8b1aba6 100644 --- a/internal/watcher/diff/model_hash.go +++ b/internal/watcher/diff/model_hash.go @@ -13,32 +13,47 @@ import ( // ComputeOpenAICompatModelsHash returns a stable hash for OpenAI-compat models. // Used to detect model list changes during hot reload. func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string { - if len(models) == 0 { - return "" - } - data, _ := json.Marshal(models) - sum := sha256.Sum256(data) - return hex.EncodeToString(sum[:]) + keys := normalizeModelPairs(func(out func(key string)) { + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + out(strings.ToLower(name) + "|" + strings.ToLower(alias)) + } + }) + return hashJoined(keys) } // ComputeVertexCompatModelsHash returns a stable hash for Vertex-compatible models. func ComputeVertexCompatModelsHash(models []config.VertexCompatModel) string { - if len(models) == 0 { - return "" - } - data, _ := json.Marshal(models) - sum := sha256.Sum256(data) - return hex.EncodeToString(sum[:]) + keys := normalizeModelPairs(func(out func(key string)) { + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + out(strings.ToLower(name) + "|" + strings.ToLower(alias)) + } + }) + return hashJoined(keys) } // ComputeClaudeModelsHash returns a stable hash for Claude model aliases. func ComputeClaudeModelsHash(models []config.ClaudeModel) string { - if len(models) == 0 { - return "" - } - data, _ := json.Marshal(models) - sum := sha256.Sum256(data) - return hex.EncodeToString(sum[:]) + keys := normalizeModelPairs(func(out func(key string)) { + for _, model := range models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + out(strings.ToLower(name) + "|" + strings.ToLower(alias)) + } + }) + return hashJoined(keys) } // ComputeExcludedModelsHash returns a normalized hash for excluded model lists. @@ -60,3 +75,28 @@ func ComputeExcludedModelsHash(excluded []string) string { sum := sha256.Sum256(data) return hex.EncodeToString(sum[:]) } + +func normalizeModelPairs(collect func(out func(key string))) []string { + seen := make(map[string]struct{}) + keys := make([]string, 0) + collect(func(key string) { + if _, exists := seen[key]; exists { + return + } + seen[key] = struct{}{} + keys = append(keys, key) + }) + if len(keys) == 0 { + return nil + } + sort.Strings(keys) + return keys +} + +func hashJoined(keys []string) string { + if len(keys) == 0 { + return "" + } + sum := sha256.Sum256([]byte(strings.Join(keys, "\n"))) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/watcher/diff/model_hash_test.go b/internal/watcher/diff/model_hash_test.go index a91f97ab..a7046080 100644 --- a/internal/watcher/diff/model_hash_test.go +++ b/internal/watcher/diff/model_hash_test.go @@ -25,6 +25,27 @@ func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) { } } +func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) { + a := []config.OpenAICompatibilityModel{ + {Name: "gpt-4", Alias: "gpt4"}, + {Name: " "}, + {Name: "GPT-4", Alias: "GPT4"}, + {Alias: "a1"}, + } + b := []config.OpenAICompatibilityModel{ + {Alias: "A1"}, + {Name: "gpt-4", Alias: "gpt4"}, + } + h1 := ComputeOpenAICompatModelsHash(a) + h2 := ComputeOpenAICompatModelsHash(b) + if h1 == "" || h2 == "" { + t.Fatal("expected non-empty hashes for non-empty model sets") + } + if h1 != h2 { + t.Fatalf("expected normalized hashes to match, got %s / %s", h1, h2) + } +} + func TestComputeVertexCompatModelsHash_DifferentInputs(t *testing.T) { models := []config.VertexCompatModel{{Name: "gemini-pro", Alias: "pro"}} hash1 := ComputeVertexCompatModelsHash(models) @@ -37,6 +58,20 @@ func TestComputeVertexCompatModelsHash_DifferentInputs(t *testing.T) { } } +func TestComputeVertexCompatModelsHash_IgnoresBlankAndOrder(t *testing.T) { + a := []config.VertexCompatModel{ + {Name: "m1", Alias: "a1"}, + {Name: " "}, + {Name: "M1", Alias: "A1"}, + } + b := []config.VertexCompatModel{ + {Name: "m1", Alias: "a1"}, + } + if h1, h2 := ComputeVertexCompatModelsHash(a), ComputeVertexCompatModelsHash(b); h1 == "" || h1 != h2 { + t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2) + } +} + func TestComputeClaudeModelsHash_Empty(t *testing.T) { if got := ComputeClaudeModelsHash(nil); got != "" { t.Fatalf("expected empty hash for nil models, got %q", got) @@ -46,6 +81,20 @@ func TestComputeClaudeModelsHash_Empty(t *testing.T) { } } +func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) { + a := []config.ClaudeModel{ + {Name: "m1", Alias: "a1"}, + {Name: " "}, + {Name: "M1", Alias: "A1"}, + } + b := []config.ClaudeModel{ + {Name: "m1", Alias: "a1"}, + } + if h1, h2 := ComputeClaudeModelsHash(a), ComputeClaudeModelsHash(b); h1 == "" || h1 != h2 { + t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2) + } +} + func TestComputeExcludedModelsHash_Normalizes(t *testing.T) { hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"}) hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"}) @@ -68,6 +117,9 @@ func TestComputeOpenAICompatModelsHash_Empty(t *testing.T) { if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{}); got != "" { t.Fatalf("expected empty hash for empty slice, got %q", got) } + if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: " "}, {Alias: ""}}); got != "" { + t.Fatalf("expected empty hash for blank models, got %q", got) + } } func TestComputeVertexCompatModelsHash_Empty(t *testing.T) { @@ -77,6 +129,9 @@ func TestComputeVertexCompatModelsHash_Empty(t *testing.T) { if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{}); got != "" { t.Fatalf("expected empty hash for empty slice, got %q", got) } + if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: " "}}); got != "" { + t.Fatalf("expected empty hash for blank models, got %q", got) + } } func TestComputeExcludedModelsHash_Empty(t *testing.T) { diff --git a/internal/watcher/diff/openai_compat.go b/internal/watcher/diff/openai_compat.go index dee47802..6b01aed2 100644 --- a/internal/watcher/diff/openai_compat.go +++ b/internal/watcher/diff/openai_compat.go @@ -1,6 +1,8 @@ package diff import ( + "crypto/sha256" + "encoding/hex" "fmt" "sort" "strings" @@ -120,5 +122,62 @@ func openAICompatKey(entry config.OpenAICompatibility, index int) (string, strin return "alias:" + alias, alias } } - return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1) + sig := openAICompatSignature(entry) + if sig == "" { + return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1) + } + short := sig + if len(short) > 8 { + short = short[:8] + } + return "sig:" + sig, "compat-" + short +} + +func openAICompatSignature(entry config.OpenAICompatibility) string { + var parts []string + + if v := strings.TrimSpace(entry.Name); v != "" { + parts = append(parts, "name="+strings.ToLower(v)) + } + if v := strings.TrimSpace(entry.BaseURL); v != "" { + parts = append(parts, "base="+v) + } + + models := make([]string, 0, len(entry.Models)) + for _, model := range entry.Models { + name := strings.TrimSpace(model.Name) + alias := strings.TrimSpace(model.Alias) + if name == "" && alias == "" { + continue + } + models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias)) + } + if len(models) > 0 { + sort.Strings(models) + parts = append(parts, "models="+strings.Join(models, ",")) + } + + if len(entry.Headers) > 0 { + keys := make([]string, 0, len(entry.Headers)) + for k := range entry.Headers { + if trimmed := strings.TrimSpace(k); trimmed != "" { + keys = append(keys, strings.ToLower(trimmed)) + } + } + if len(keys) > 0 { + sort.Strings(keys) + parts = append(parts, "headers="+strings.Join(keys, ",")) + } + } + + // Intentionally exclude API key material; only count non-empty entries. + if count := countAPIKeys(entry); count > 0 { + parts = append(parts, fmt.Sprintf("api_keys=%d", count)) + } + + if len(parts) == 0 { + return "" + } + sum := sha256.Sum256([]byte(strings.Join(parts, "|"))) + return hex.EncodeToString(sum[:]) } diff --git a/internal/watcher/diff/openai_compat_test.go b/internal/watcher/diff/openai_compat_test.go index c0ec9090..db33db14 100644 --- a/internal/watcher/diff/openai_compat_test.go +++ b/internal/watcher/diff/openai_compat_test.go @@ -1,6 +1,7 @@ package diff import ( + "strings" "testing" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -89,6 +90,79 @@ func TestOpenAICompatKeyFallbacks(t *testing.T) { } } +func TestOpenAICompatKey_UsesName(t *testing.T) { + entry := config.OpenAICompatibility{Name: "My-Provider"} + key, label := openAICompatKey(entry, 0) + if key != "name:My-Provider" || label != "My-Provider" { + t.Fatalf("expected name key, got %s/%s", key, label) + } +} + +func TestOpenAICompatKey_SignatureFallbackWhenOnlyAPIKeys(t *testing.T) { + entry := config.OpenAICompatibility{ + APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}, {APIKey: "k2"}}, + } + key, label := openAICompatKey(entry, 0) + if !strings.HasPrefix(key, "sig:") || !strings.HasPrefix(label, "compat-") { + t.Fatalf("expected signature key, got %s/%s", key, label) + } +} + +func TestOpenAICompatSignature_EmptyReturnsEmpty(t *testing.T) { + if got := openAICompatSignature(config.OpenAICompatibility{}); got != "" { + t.Fatalf("expected empty signature, got %q", got) + } +} + +func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) { + a := config.OpenAICompatibility{ + Name: " Provider ", + BaseURL: "http://base", + Models: []config.OpenAICompatibilityModel{ + {Name: "m1"}, + {Name: " "}, + {Alias: "A1"}, + }, + Headers: map[string]string{ + "X-Test": "1", + " ": "ignored", + }, + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "k1"}, + {APIKey: " "}, + }, + } + b := config.OpenAICompatibility{ + Name: "provider", + BaseURL: "http://base", + Models: []config.OpenAICompatibilityModel{ + {Alias: "a1"}, + {Name: "m1"}, + }, + Headers: map[string]string{ + "x-test": "2", + }, + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "k2"}, + }, + } + + sigA := openAICompatSignature(a) + sigB := openAICompatSignature(b) + if sigA == "" || sigB == "" { + t.Fatalf("expected non-empty signatures, got %q / %q", sigA, sigB) + } + if sigA != sigB { + t.Fatalf("expected normalized signatures to match, got %s / %s", sigA, sigB) + } + + c := b + c.Models = append(c.Models, config.OpenAICompatibilityModel{Name: "m2"}) + if sigC := openAICompatSignature(c); sigC == sigB { + t.Fatalf("expected signature to change when models change, got %s", sigC) + } +} + func TestCountOpenAIModelsSkipsBlanks(t *testing.T) { models := []config.OpenAICompatibilityModel{ {Name: "m1"}, diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 8f03bf5b..bb682840 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -658,7 +658,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) { log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name)) return } - fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name)) + log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) w.addOrUpdateClient(event.Name) return } @@ -666,7 +666,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) { log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name)) return } - fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name)) + log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) w.removeClient(event.Name) return } @@ -675,7 +675,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) { log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name)) return } - fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name)) + log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name)) w.addOrUpdateClient(event.Name) } } @@ -715,7 +715,7 @@ func (w *Watcher) reloadConfigIfChanged() { log.Debugf("config file content unchanged (hash match), skipping reload") return } - fmt.Printf("config file changed, reloading: %s\n", w.configPath) + log.Infof("config file changed, reloading: %s", w.configPath) if w.reloadConfig() { finalHash := newHash if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 { diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index 71fb1852..37c81196 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/json" "fmt" - "github.com/fsnotify/fsnotify" "os" "path/filepath" "strings" @@ -13,6 +12,7 @@ import ( "testing" "time" + "github.com/fsnotify/fsnotify" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"