mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Fix Claude OAuth tool name mapping
Prefix tool names with proxy_ for Claude OAuth requests and strip the prefix from streaming and non-streaming responses to restore client-facing names. Updates the Claude executor to: - add prefixing for tools, tool_choice, and tool_use messages when using OAuth tokens - strip the prefix from tool_use events in SSE and non-streaming payloads - add focused unit tests for prefix/strip helpers
This commit is contained in:
@@ -35,6 +35,8 @@ type ClaudeExecutor struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const claudeToolPrefix = "proxy_"
|
||||||
|
|
||||||
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
||||||
|
|
||||||
func (e *ClaudeExecutor) Identifier() string { return "claude" }
|
func (e *ClaudeExecutor) Identifier() string { return "claude" }
|
||||||
@@ -81,6 +83,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
|
if isClaudeOAuthToken(apiKey) {
|
||||||
|
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -152,6 +157,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
} else {
|
} else {
|
||||||
reporter.publish(ctx, parseClaudeUsage(data))
|
reporter.publish(ctx, parseClaudeUsage(data))
|
||||||
}
|
}
|
||||||
|
if isClaudeOAuthToken(apiKey) {
|
||||||
|
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||||
|
}
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
@@ -193,6 +201,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
|
if isClaudeOAuthToken(apiKey) {
|
||||||
|
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -263,6 +274,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.publish(ctx, detail)
|
||||||
}
|
}
|
||||||
|
if isClaudeOAuthToken(apiKey) {
|
||||||
|
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||||
|
}
|
||||||
// Forward the line as-is to preserve SSE format
|
// Forward the line as-is to preserve SSE format
|
||||||
cloned := make([]byte, len(line)+1)
|
cloned := make([]byte, len(line)+1)
|
||||||
copy(cloned, line)
|
copy(cloned, line)
|
||||||
@@ -287,6 +301,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||||
reporter.publish(ctx, detail)
|
reporter.publish(ctx, detail)
|
||||||
}
|
}
|
||||||
|
if isClaudeOAuthToken(apiKey) {
|
||||||
|
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||||
|
}
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||||
@@ -326,6 +343,9 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
// Extract betas from body and convert to header (for count_tokens too)
|
// Extract betas from body and convert to header (for count_tokens too)
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
|
if isClaudeOAuthToken(apiKey) {
|
||||||
|
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
|
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -770,3 +790,107 @@ func checkSystemInstructions(payload []byte) []byte {
|
|||||||
}
|
}
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isClaudeOAuthToken(apiKey string) bool {
|
||||||
|
return strings.Contains(apiKey, "sk-ant-oat")
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||||
|
if prefix == "" {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||||
|
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||||
|
name := tool.Get("name").String()
|
||||||
|
if name == "" || strings.HasPrefix(name, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("tools.%d.name", index.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
|
||||||
|
name := gjson.GetBytes(body, "tool_choice.name").String()
|
||||||
|
if name != "" && !strings.HasPrefix(name, prefix) {
|
||||||
|
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||||
|
messages.ForEach(func(msgIndex, msg gjson.Result) bool {
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
content.ForEach(func(contentIndex, part gjson.Result) bool {
|
||||||
|
if part.Get("type").String() != "tool_use" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
name := part.Get("name").String()
|
||||||
|
if name == "" || strings.HasPrefix(name, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
|
||||||
|
if prefix == "" {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
content := gjson.GetBytes(body, "content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
content.ForEach(func(index, part gjson.Result) bool {
|
||||||
|
if part.Get("type").String() != "tool_use" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
name := part.Get("name").String()
|
||||||
|
if !strings.HasPrefix(name, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||||
|
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
||||||
|
if prefix == "" {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
payload := jsonPayload(line)
|
||||||
|
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
contentBlock := gjson.GetBytes(payload, "content_block")
|
||||||
|
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
name := contentBlock.Get("name").String()
|
||||||
|
if !strings.HasPrefix(name, prefix) {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||||
|
if err != nil {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := bytes.TrimSpace(line)
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||||
|
return append([]byte("data: "), updated...)
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|||||||
51
internal/runtime/executor/claude_executor_test.go
Normal file
51
internal/runtime/executor/claude_executor_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||||
|
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_alpha" {
|
||||||
|
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_alpha")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_bravo" {
|
||||||
|
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_bravo")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "proxy_charlie" {
|
||||||
|
t.Fatalf("tool_choice.name = %q, want %q", got, "proxy_charlie")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_delta" {
|
||||||
|
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_delta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||||
|
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||||
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "content.0.name").String(); got != "alpha" {
|
||||||
|
t.Fatalf("content.0.name = %q, want %q", got, "alpha")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "content.1.name").String(); got != "bravo" {
|
||||||
|
t.Fatalf("content.1.name = %q, want %q", got, "bravo")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||||
|
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||||
|
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||||
|
|
||||||
|
payload := bytes.TrimSpace(out)
|
||||||
|
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||||
|
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(payload, "content_block.name").String(); got != "alpha" {
|
||||||
|
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user