diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 89a366ee..ff045c51 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -753,6 +753,19 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte { return body } + // Build a set of built-in tool names (tools with a "type" field) + builtinTools := make(map[string]bool) + if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() { + tools.ForEach(func(_, tool gjson.Result) bool { + if tool.Get("type").Exists() && tool.Get("type").String() != "" { + if name := tool.Get("name").String(); name != "" { + builtinTools[name] = true + } + } + return true + }) + } + if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() { tools.ForEach(func(index, tool gjson.Result) bool { // Skip built-in tools (web_search, code_execution, etc.) which have @@ -784,15 +797,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte { return true } content.ForEach(func(contentIndex, part gjson.Result) bool { - if part.Get("type").String() != "tool_use" { - return true + partType := part.Get("type").String() + switch partType { + case "tool_use": + name := part.Get("name").String() + if name == "" || strings.HasPrefix(name, prefix) { + return true + } + path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int()) + body, _ = sjson.SetBytes(body, path, prefix+name) + case "tool_reference": + toolName := part.Get("tool_name").String() + if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] { + return true + } + path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int()) + body, _ = sjson.SetBytes(body, path, prefix+toolName) + case "tool_result": + // Handle nested tool_reference blocks inside tool_result.content[] + nestedContent := part.Get("content") + if nestedContent.Exists() && nestedContent.IsArray() { + nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool { + if nestedPart.Get("type").String() == "tool_reference" { + nestedToolName := nestedPart.Get("tool_name").String() + if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] { + nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int()) + body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName) + } + } + return true + }) + } } - name := part.Get("name").String() - if name == "" || strings.HasPrefix(name, prefix) { - return true - } - path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int()) - body, _ = sjson.SetBytes(body, path, prefix+name) return true }) return true @@ -811,15 +847,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte { return body } content.ForEach(func(index, part gjson.Result) bool { - if part.Get("type").String() != "tool_use" { - return true + partType := part.Get("type").String() + switch partType { + case "tool_use": + name := part.Get("name").String() + if !strings.HasPrefix(name, prefix) { + return true + } + path := fmt.Sprintf("content.%d.name", index.Int()) + body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix)) + case "tool_reference": + toolName := part.Get("tool_name").String() + if !strings.HasPrefix(toolName, prefix) { + return true + } + path := fmt.Sprintf("content.%d.tool_name", index.Int()) + body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix)) + case "tool_result": + // Handle nested tool_reference blocks inside tool_result.content[] + nestedContent := part.Get("content") + if nestedContent.Exists() && nestedContent.IsArray() { + nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool { + if nestedPart.Get("type").String() == "tool_reference" { + nestedToolName := nestedPart.Get("tool_name").String() + if strings.HasPrefix(nestedToolName, prefix) { + nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int()) + body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix)) + } + } + return true + }) + } } - name := part.Get("name").String() - if !strings.HasPrefix(name, prefix) { - return true - } - path := fmt.Sprintf("content.%d.name", index.Int()) - body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix)) return true }) return body @@ -834,15 +893,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte { return line } contentBlock := gjson.GetBytes(payload, "content_block") - if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" { + if !contentBlock.Exists() { return line } - name := contentBlock.Get("name").String() - if !strings.HasPrefix(name, prefix) { - return line - } - updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix)) - if err != nil { + + blockType := contentBlock.Get("type").String() + var updated []byte + var err error + + switch blockType { + case "tool_use": + name := contentBlock.Get("name").String() + if !strings.HasPrefix(name, prefix) { + return line + } + updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix)) + if err != nil { + return line + } + case "tool_reference": + toolName := contentBlock.Get("tool_name").String() + if !strings.HasPrefix(toolName, prefix) { + return line + } + updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix)) + if err != nil { + return line + } + default: return line } diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 36fb7ad4..18594146 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) { } } +func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) { + input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`) + out := applyClaudeToolPrefix(input, "proxy_") + + if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" { + t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta") + } + if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" { + t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma") + } +} + func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) { input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`) out := applyClaudeToolPrefix(input, "proxy_") @@ -49,6 +61,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) { } } +func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`) + out := stripClaudeToolPrefixFromResponse(input, "proxy_") + + if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" { + t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha") + } + if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" { + t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo") + } +} + func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) { line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`) out := stripClaudeToolPrefixFromStreamLine(line, "proxy_") @@ -61,3 +85,53 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) { t.Fatalf("content_block.name = %q, want %q", got, "alpha") } } + +func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) { + line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`) + out := stripClaudeToolPrefixFromStreamLine(line, "proxy_") + + payload := bytes.TrimSpace(out) + if bytes.HasPrefix(payload, []byte("data:")) { + payload = bytes.TrimSpace(payload[len("data:"):]) + } + if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" { + t.Fatalf("content_block.tool_name = %q, want %q", got, "beta") + } +} + +func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) { + input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`) + out := applyClaudeToolPrefix(input, "proxy_") + got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String() + if got != "proxy_mcp__nia__manage_resource" { + t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource") + } +} + +func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`) + out := stripClaudeToolPrefixFromResponse(input, "proxy_") + got := gjson.GetBytes(out, "content.0.content.0.tool_name").String() + if got != "mcp__nia__manage_resource" { + t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource") + } +} + +func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) { + // tool_result.content can be a string - should not be processed + input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`) + out := applyClaudeToolPrefix(input, "proxy_") + got := gjson.GetBytes(out, "messages.0.content.0.content").String() + if got != "plain string result" { + t.Fatalf("string content should remain unchanged = %q", got) + } +} + +func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) { + input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`) + out := applyClaudeToolPrefix(input, "proxy_") + got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String() + if got != "web_search" { + t.Fatalf("built-in tool_reference should not be prefixed, got %q", got) + } +}