fix: skip built-in tools in tool_reference prefix + refactor to switch

- Collect built-in tool names (those with a "type" field like
  web_search, code_execution) and skip prefixing tool_reference
  blocks that reference them, preventing name mismatch.
- Refactor if-else if chains to switch statements in all three
  prefix functions for idiomatic Go style.
This commit is contained in:
Kirill Turanskiy
2026-02-16 19:37:11 +03:00
parent 603f06a762
commit 24c18614f0
2 changed files with 36 additions and 11 deletions

View File

@@ -753,6 +753,19 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return body return body
} }
// Build a set of built-in tool names (tools with a "type" field)
builtinTools := make(map[string]bool)
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
if name := tool.Get("name").String(); name != "" {
builtinTools[name] = true
}
}
return true
})
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() { if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool { tools.ForEach(func(index, tool gjson.Result) bool {
// Skip built-in tools (web_search, code_execution, etc.) which have // Skip built-in tools (web_search, code_execution, etc.) which have
@@ -785,28 +798,29 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
} }
content.ForEach(func(contentIndex, part gjson.Result) bool { content.ForEach(func(contentIndex, part gjson.Result) bool {
partType := part.Get("type").String() partType := part.Get("type").String()
if partType == "tool_use" { switch partType {
case "tool_use":
name := part.Get("name").String() name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) { if name == "" || strings.HasPrefix(name, prefix) {
return true return true
} }
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int()) path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name) body, _ = sjson.SetBytes(body, path, prefix+name)
} else if partType == "tool_reference" { case "tool_reference":
toolName := part.Get("tool_name").String() toolName := part.Get("tool_name").String()
if toolName == "" || strings.HasPrefix(toolName, prefix) { if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
return true return true
} }
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int()) path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+toolName) body, _ = sjson.SetBytes(body, path, prefix+toolName)
} else if partType == "tool_result" { case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[] // Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content") nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() { if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool { nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" { if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String() nestedToolName := nestedPart.Get("tool_name").String()
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) { if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int()) nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName) body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
} }
@@ -834,21 +848,22 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
} }
content.ForEach(func(index, part gjson.Result) bool { content.ForEach(func(index, part gjson.Result) bool {
partType := part.Get("type").String() partType := part.Get("type").String()
if partType == "tool_use" { switch partType {
case "tool_use":
name := part.Get("name").String() name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) { if !strings.HasPrefix(name, prefix) {
return true return true
} }
path := fmt.Sprintf("content.%d.name", index.Int()) path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix)) body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
} else if partType == "tool_reference" { case "tool_reference":
toolName := part.Get("tool_name").String() toolName := part.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) { if !strings.HasPrefix(toolName, prefix) {
return true return true
} }
path := fmt.Sprintf("content.%d.tool_name", index.Int()) path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix)) body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
} else if partType == "tool_result" { case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[] // Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content") nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() { if nestedContent.Exists() && nestedContent.IsArray() {
@@ -886,7 +901,8 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
var updated []byte var updated []byte
var err error var err error
if blockType == "tool_use" { switch blockType {
case "tool_use":
name := contentBlock.Get("name").String() name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) { if !strings.HasPrefix(name, prefix) {
return line return line
@@ -895,7 +911,7 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
if err != nil { if err != nil {
return line return line
} }
} else if blockType == "tool_reference" { case "tool_reference":
toolName := contentBlock.Get("tool_name").String() toolName := contentBlock.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) { if !strings.HasPrefix(toolName, prefix) {
return line return line
@@ -904,7 +920,7 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
if err != nil { if err != nil {
return line return line
} }
} else { default:
return line return line
} }

View File

@@ -126,3 +126,12 @@ func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T
t.Fatalf("string content should remain unchanged = %q", got) t.Fatalf("string content should remain unchanged = %q", got)
} }
} }
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "web_search" {
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
}
}