diff --git a/config.example.yaml b/config.example.yaml index d44955df..f99ee74f 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -159,6 +159,7 @@ nonstream-keepalive-interval: 0 # sensitive-words: # optional: words to obfuscate with zero-width characters # - "API" # - "proxy" +# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request # Default headers for Claude API requests. Update when Claude Code releases new versions. # These are used as fallbacks when the client does not send its own headers. diff --git a/internal/config/config.go b/internal/config/config.go index 5b18f3df..ed57b993 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -301,6 +301,10 @@ type CloakConfig struct { // SensitiveWords is a list of words to obfuscate with zero-width characters. // This can help bypass certain content filters. SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"` + + // CacheUserID controls whether Claude user_id values are cached per API key. + // When false, a fresh random user_id is generated for every request. + CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"` } // ClaudeKey represents the configuration for a Claude API key, diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 04a1242a..681e7b8d 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -117,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r // Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation) // based on client type and configuration. - body = applyCloaking(ctx, e.cfg, auth, body, baseModel) + body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey) requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) @@ -258,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A // Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation) // based on client type and configuration. - body = applyCloaking(ctx, e.cfg, auth, body, baseModel) + body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey) requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) @@ -982,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string { } // getCloakConfigFromAuth extracts cloak configuration from auth attributes. -// Returns (cloakMode, strictMode, sensitiveWords). -func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) { +// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID). +func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) { if auth == nil || auth.Attributes == nil { - return "auto", false, nil + return "auto", false, nil, false } cloakMode := auth.Attributes["cloak_mode"] @@ -1003,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) { } } - return cloakMode, strictMode, sensitiveWords + cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true") + + return cloakMode, strictMode, sensitiveWords, cacheUserID } // resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig. @@ -1036,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c } // injectFakeUserID generates and injects a fake user ID into the request metadata. -func injectFakeUserID(payload []byte) []byte { +// When useCache is false, a new user ID is generated for every call. +func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte { + generateID := func() string { + if useCache { + return cachedUserID(apiKey) + } + return generateFakeUserID() + } + metadata := gjson.GetBytes(payload, "metadata") if !metadata.Exists() { - payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID()) + payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID()) return payload } existingUserID := gjson.GetBytes(payload, "metadata.user_id").String() if existingUserID == "" || !isValidUserID(existingUserID) { - payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID()) + payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID()) } return payload } @@ -1082,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte { // applyCloaking applies cloaking transformations to the payload based on config and client. // Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation. -func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte { +func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte { clientUserAgent := getClientUserAgent(ctx) // Get cloak config from ClaudeKey configuration @@ -1092,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A var cloakMode string var strictMode bool var sensitiveWords []string + var cacheUserID bool if cloakCfg != nil { cloakMode = cloakCfg.Mode strictMode = cloakCfg.StrictMode sensitiveWords = cloakCfg.SensitiveWords + if cloakCfg.CacheUserID != nil { + cacheUserID = *cloakCfg.CacheUserID + } } // Fallback to auth attributes if no config found if cloakMode == "" { - attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth) + attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth) cloakMode = attrMode if !strictMode { strictMode = attrStrict @@ -1109,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A if len(sensitiveWords) == 0 { sensitiveWords = attrWords } + if cloakCfg == nil || cloakCfg.CacheUserID == nil { + cacheUserID = attrCache + } + } else if cloakCfg == nil || cloakCfg.CacheUserID == nil { + _, _, _, attrCache := getCloakConfigFromAuth(auth) + cacheUserID = attrCache } // Determine if cloaking should be applied @@ -1122,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A } // Inject fake user ID - payload = injectFakeUserID(payload) + payload = injectFakeUserID(payload, apiKey, cacheUserID) // Apply sensitive word obfuscation if len(sensitiveWords) > 0 { diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 017e0913..dd29ed8a 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -2,9 +2,18 @@ package executor import ( "bytes" + "context" + "io" + "net/http" + "net/http/httptest" "testing" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) func TestApplyClaudeToolPrefix(t *testing.T) { @@ -199,6 +208,119 @@ func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) { } } +func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) { + resetUserIDCache() + + var userIDs []string + var requestModels []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + userID := gjson.GetBytes(body, "metadata.user_id").String() + model := gjson.GetBytes(body, "model").String() + userIDs = append(userIDs, userID) + requestModels = append(requestModels, model) + t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String()) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL) + + cacheEnabled := true + executor := NewClaudeExecutor(&config.Config{ + ClaudeKey: []config.ClaudeKey{ + { + APIKey: "key-123", + BaseURL: server.URL, + Cloak: &config.CloakConfig{ + CacheUserID: &cacheEnabled, + }, + }, + }, + }) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"} + for _, model := range models { + t.Logf("Sending request for model: %s", model) + modelPayload, _ := sjson.SetBytes(payload, "model", model) + if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: model, + Payload: modelPayload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }); err != nil { + t.Fatalf("Execute(%s) error: %v", model, err) + } + } + + if len(userIDs) != 2 { + t.Fatalf("expected 2 requests, got %d", len(userIDs)) + } + if userIDs[0] == "" || userIDs[1] == "" { + t.Fatal("expected user_id to be populated") + } + t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0]) + t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1]) + if userIDs[0] != userIDs[1] { + t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1]) + } + if !isValidUserID(userIDs[0]) { + t.Fatalf("user_id %q is not valid", userIDs[0]) + } + t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0]) +} + +func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) { + resetUserIDCache() + + var userIDs []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String()) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + for i := 0; i < 2; i++ { + if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }); err != nil { + t.Fatalf("Execute call %d error: %v", i, err) + } + } + + if len(userIDs) != 2 { + t.Fatalf("expected 2 requests, got %d", len(userIDs)) + } + if userIDs[0] == "" || userIDs[1] == "" { + t.Fatal("expected user_id to be populated") + } + if userIDs[0] == userIDs[1] { + t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0]) + } + if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) { + t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1]) + } +} + func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) { input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`) out := stripClaudeToolPrefixFromResponse(input, "proxy_") diff --git a/internal/runtime/executor/user_id_cache.go b/internal/runtime/executor/user_id_cache.go new file mode 100644 index 00000000..ff8efd9d --- /dev/null +++ b/internal/runtime/executor/user_id_cache.go @@ -0,0 +1,89 @@ +package executor + +import ( + "crypto/sha256" + "encoding/hex" + "sync" + "time" +) + +type userIDCacheEntry struct { + value string + expire time.Time +} + +var ( + userIDCache = make(map[string]userIDCacheEntry) + userIDCacheMu sync.RWMutex + userIDCacheCleanupOnce sync.Once +) + +const ( + userIDTTL = time.Hour + userIDCacheCleanupPeriod = 15 * time.Minute +) + +func startUserIDCacheCleanup() { + go func() { + ticker := time.NewTicker(userIDCacheCleanupPeriod) + defer ticker.Stop() + for range ticker.C { + purgeExpiredUserIDs() + } + }() +} + +func purgeExpiredUserIDs() { + now := time.Now() + userIDCacheMu.Lock() + for key, entry := range userIDCache { + if !entry.expire.After(now) { + delete(userIDCache, key) + } + } + userIDCacheMu.Unlock() +} + +func userIDCacheKey(apiKey string) string { + sum := sha256.Sum256([]byte(apiKey)) + return hex.EncodeToString(sum[:]) +} + +func cachedUserID(apiKey string) string { + if apiKey == "" { + return generateFakeUserID() + } + + userIDCacheCleanupOnce.Do(startUserIDCacheCleanup) + + key := userIDCacheKey(apiKey) + now := time.Now() + + userIDCacheMu.RLock() + entry, ok := userIDCache[key] + valid := ok && entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) + userIDCacheMu.RUnlock() + if valid { + userIDCacheMu.Lock() + entry = userIDCache[key] + if entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) { + entry.expire = now.Add(userIDTTL) + userIDCache[key] = entry + userIDCacheMu.Unlock() + return entry.value + } + userIDCacheMu.Unlock() + } + + newID := generateFakeUserID() + + userIDCacheMu.Lock() + entry, ok = userIDCache[key] + if !ok || entry.value == "" || !entry.expire.After(now) || !isValidUserID(entry.value) { + entry.value = newID + } + entry.expire = now.Add(userIDTTL) + userIDCache[key] = entry + userIDCacheMu.Unlock() + return entry.value +} diff --git a/internal/runtime/executor/user_id_cache_test.go b/internal/runtime/executor/user_id_cache_test.go new file mode 100644 index 00000000..420a3cad --- /dev/null +++ b/internal/runtime/executor/user_id_cache_test.go @@ -0,0 +1,86 @@ +package executor + +import ( + "testing" + "time" +) + +func resetUserIDCache() { + userIDCacheMu.Lock() + userIDCache = make(map[string]userIDCacheEntry) + userIDCacheMu.Unlock() +} + +func TestCachedUserID_ReusesWithinTTL(t *testing.T) { + resetUserIDCache() + + first := cachedUserID("api-key-1") + second := cachedUserID("api-key-1") + + if first == "" { + t.Fatal("expected generated user_id to be non-empty") + } + if first != second { + t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second) + } +} + +func TestCachedUserID_ExpiresAfterTTL(t *testing.T) { + resetUserIDCache() + + expiredID := cachedUserID("api-key-expired") + cacheKey := userIDCacheKey("api-key-expired") + userIDCacheMu.Lock() + userIDCache[cacheKey] = userIDCacheEntry{ + value: expiredID, + expire: time.Now().Add(-time.Minute), + } + userIDCacheMu.Unlock() + + newID := cachedUserID("api-key-expired") + if newID == expiredID { + t.Fatalf("expected expired user_id to be replaced, got %q", newID) + } + if newID == "" { + t.Fatal("expected regenerated user_id to be non-empty") + } +} + +func TestCachedUserID_IsScopedByAPIKey(t *testing.T) { + resetUserIDCache() + + first := cachedUserID("api-key-1") + second := cachedUserID("api-key-2") + + if first == second { + t.Fatalf("expected different API keys to have different user_ids, got %q", first) + } +} + +func TestCachedUserID_RenewsTTLOnHit(t *testing.T) { + resetUserIDCache() + + key := "api-key-renew" + id := cachedUserID(key) + cacheKey := userIDCacheKey(key) + + soon := time.Now() + userIDCacheMu.Lock() + userIDCache[cacheKey] = userIDCacheEntry{ + value: id, + expire: soon.Add(2 * time.Second), + } + userIDCacheMu.Unlock() + + if refreshed := cachedUserID(key); refreshed != id { + t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed) + } + + userIDCacheMu.RLock() + entry := userIDCache[cacheKey] + userIDCacheMu.RUnlock() + + if entry.expire.Sub(soon) < 30*time.Minute { + t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon)) + } +}