diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 7be4f41b..d426326f 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -35,6 +35,8 @@ type ClaudeExecutor struct { cfg *config.Config } +const claudeToolPrefix = "proxy_" + func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} } func (e *ClaudeExecutor) Identifier() string { return "claude" } @@ -81,9 +83,14 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r // Extract betas from body and convert to header var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) + bodyForTranslation := body + bodyForUpstream := body + if isClaudeOAuthToken(apiKey) { + bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix) + } url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream)) if err != nil { return resp, err } @@ -98,7 +105,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r URL: url, Method: http.MethodPost, Headers: httpReq.Header.Clone(), - Body: body, + Body: bodyForUpstream, Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, @@ -152,8 +159,20 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } else { reporter.publish(ctx, parseClaudeUsage(data)) } + if isClaudeOAuthToken(apiKey) { + data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix) + } var param any - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + out := sdktranslator.TranslateNonStream( + ctx, + to, + from, + req.Model, + bytes.Clone(opts.OriginalRequest), + bodyForTranslation, + data, + ¶m, + ) resp = cliproxyexecutor.Response{Payload: []byte(out)} return resp, nil } @@ -193,9 +212,14 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A // Extract betas from body and convert to header var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) + bodyForTranslation := body + bodyForUpstream := body + if isClaudeOAuthToken(apiKey) { + bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix) + } url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream)) if err != nil { return nil, err } @@ -210,7 +234,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A URL: url, Method: http.MethodPost, Headers: httpReq.Header.Clone(), - Body: body, + Body: bodyForUpstream, Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, @@ -263,6 +287,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if detail, ok := parseClaudeStreamUsage(line); ok { reporter.publish(ctx, detail) } + if isClaudeOAuthToken(apiKey) { + line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix) + } // Forward the line as-is to preserve SSE format cloned := make([]byte, len(line)+1) copy(cloned, line) @@ -287,7 +314,19 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if detail, ok := parseClaudeStreamUsage(line); ok { reporter.publish(ctx, detail) } - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + if isClaudeOAuthToken(apiKey) { + line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix) + } + chunks := sdktranslator.TranslateStream( + ctx, + to, + from, + req.Model, + bytes.Clone(opts.OriginalRequest), + bodyForTranslation, + bytes.Clone(line), + ¶m, + ) for i := range chunks { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} } @@ -326,6 +365,9 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut // Extract betas from body and convert to header (for count_tokens too) var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) + if isClaudeOAuthToken(apiKey) { + body = applyClaudeToolPrefix(body, claudeToolPrefix) + } url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) @@ -770,3 +812,107 @@ func checkSystemInstructions(payload []byte) []byte { } return payload } + +func isClaudeOAuthToken(apiKey string) bool { + return strings.Contains(apiKey, "sk-ant-oat") +} + +func applyClaudeToolPrefix(body []byte, prefix string) []byte { + if prefix == "" { + return body + } + + if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() { + tools.ForEach(func(index, tool gjson.Result) bool { + name := tool.Get("name").String() + if name == "" || strings.HasPrefix(name, prefix) { + return true + } + path := fmt.Sprintf("tools.%d.name", index.Int()) + body, _ = sjson.SetBytes(body, path, prefix+name) + return true + }) + } + + if gjson.GetBytes(body, "tool_choice.type").String() == "tool" { + name := gjson.GetBytes(body, "tool_choice.name").String() + if name != "" && !strings.HasPrefix(name, prefix) { + body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name) + } + } + + if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() { + messages.ForEach(func(msgIndex, msg gjson.Result) bool { + content := msg.Get("content") + if !content.Exists() || !content.IsArray() { + return true + } + content.ForEach(func(contentIndex, part gjson.Result) bool { + if part.Get("type").String() != "tool_use" { + return true + } + name := part.Get("name").String() + if name == "" || strings.HasPrefix(name, prefix) { + return true + } + path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int()) + body, _ = sjson.SetBytes(body, path, prefix+name) + return true + }) + return true + }) + } + + return body +} + +func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte { + if prefix == "" { + return body + } + content := gjson.GetBytes(body, "content") + if !content.Exists() || !content.IsArray() { + return body + } + content.ForEach(func(index, part gjson.Result) bool { + if part.Get("type").String() != "tool_use" { + return true + } + name := part.Get("name").String() + if !strings.HasPrefix(name, prefix) { + return true + } + path := fmt.Sprintf("content.%d.name", index.Int()) + body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix)) + return true + }) + return body +} + +func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte { + if prefix == "" { + return line + } + payload := jsonPayload(line) + if len(payload) == 0 || !gjson.ValidBytes(payload) { + return line + } + contentBlock := gjson.GetBytes(payload, "content_block") + if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" { + return line + } + name := contentBlock.Get("name").String() + if !strings.HasPrefix(name, prefix) { + return line + } + updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix)) + if err != nil { + return line + } + + trimmed := bytes.TrimSpace(line) + if bytes.HasPrefix(trimmed, []byte("data:")) { + return append([]byte("data: "), updated...) + } + return updated +} diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go new file mode 100644 index 00000000..05f5b60c --- /dev/null +++ b/internal/runtime/executor/claude_executor_test.go @@ -0,0 +1,51 @@ +package executor + +import ( + "bytes" + "testing" + + "github.com/tidwall/gjson" +) + +func TestApplyClaudeToolPrefix(t *testing.T) { + input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`) + out := applyClaudeToolPrefix(input, "proxy_") + + if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_alpha" { + t.Fatalf("tools.0.name = %q, want %q", got, "proxy_alpha") + } + if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_bravo" { + t.Fatalf("tools.1.name = %q, want %q", got, "proxy_bravo") + } + if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "proxy_charlie" { + t.Fatalf("tool_choice.name = %q, want %q", got, "proxy_charlie") + } + if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_delta" { + t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_delta") + } +} + +func TestStripClaudeToolPrefixFromResponse(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`) + out := stripClaudeToolPrefixFromResponse(input, "proxy_") + + if got := gjson.GetBytes(out, "content.0.name").String(); got != "alpha" { + t.Fatalf("content.0.name = %q, want %q", got, "alpha") + } + if got := gjson.GetBytes(out, "content.1.name").String(); got != "bravo" { + t.Fatalf("content.1.name = %q, want %q", got, "bravo") + } +} + +func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) { + line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`) + out := stripClaudeToolPrefixFromStreamLine(line, "proxy_") + + payload := bytes.TrimSpace(out) + if bytes.HasPrefix(payload, []byte("data:")) { + payload = bytes.TrimSpace(payload[len("data:"):]) + } + if got := gjson.GetBytes(payload, "content_block.name").String(); got != "alpha" { + t.Fatalf("content_block.name = %q, want %q", got, "alpha") + } +}