mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 12:20:52 +08:00
Fix Kimi tool-call reasoning_content normalization
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
|||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -94,6 +95,10 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, err = normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
url := kimiauth.KimiAPIBaseURL + "/chat/completions"
|
url := kimiauth.KimiAPIBaseURL + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -189,6 +194,10 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, err = normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
url := kimiauth.KimiAPIBaseURL + "/chat/completions"
|
url := kimiauth.KimiAPIBaseURL + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -291,6 +300,150 @@ func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := body
|
||||||
|
pending := make([]string, 0)
|
||||||
|
patched := 0
|
||||||
|
patchedReasoning := 0
|
||||||
|
ambiguous := 0
|
||||||
|
latestReasoning := ""
|
||||||
|
hasLatestReasoning := false
|
||||||
|
|
||||||
|
removePending := func(id string) {
|
||||||
|
for idx := range pending {
|
||||||
|
if pending[idx] != id {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pending = append(pending[:idx], pending[idx+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := messages.Array()
|
||||||
|
for msgIdx := range msgs {
|
||||||
|
msg := msgs[msgIdx]
|
||||||
|
role := strings.TrimSpace(msg.Get("role").String())
|
||||||
|
switch role {
|
||||||
|
case "assistant":
|
||||||
|
reasoning := msg.Get("reasoning_content")
|
||||||
|
if reasoning.Exists() {
|
||||||
|
reasoningText := reasoning.String()
|
||||||
|
if strings.TrimSpace(reasoningText) != "" {
|
||||||
|
latestReasoning = reasoningText
|
||||||
|
hasLatestReasoning = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := msg.Get("tool_calls")
|
||||||
|
if !toolCalls.Exists() || !toolCalls.IsArray() || len(toolCalls.Array()) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reasoning.Exists() || strings.TrimSpace(reasoning.String()) == "" {
|
||||||
|
reasoningText := fallbackAssistantReasoning(msg, hasLatestReasoning, latestReasoning)
|
||||||
|
path := fmt.Sprintf("messages.%d.reasoning_content", msgIdx)
|
||||||
|
next, err := sjson.SetBytes(out, path, reasoningText)
|
||||||
|
if err != nil {
|
||||||
|
return body, fmt.Errorf("kimi executor: failed to set assistant reasoning_content: %w", err)
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
patchedReasoning++
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range toolCalls.Array() {
|
||||||
|
id := strings.TrimSpace(tc.Get("id").String())
|
||||||
|
if id == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pending = append(pending, id)
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
toolCallID := strings.TrimSpace(msg.Get("tool_call_id").String())
|
||||||
|
if toolCallID == "" {
|
||||||
|
toolCallID = strings.TrimSpace(msg.Get("call_id").String())
|
||||||
|
if toolCallID != "" {
|
||||||
|
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||||
|
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||||
|
if err != nil {
|
||||||
|
return body, fmt.Errorf("kimi executor: failed to set tool_call_id from call_id: %w", err)
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
patched++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCallID == "" {
|
||||||
|
if len(pending) == 1 {
|
||||||
|
toolCallID = pending[0]
|
||||||
|
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||||
|
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||||
|
if err != nil {
|
||||||
|
return body, fmt.Errorf("kimi executor: failed to infer tool_call_id: %w", err)
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
patched++
|
||||||
|
} else if len(pending) > 1 {
|
||||||
|
ambiguous++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCallID != "" {
|
||||||
|
removePending(toolCallID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if patched > 0 || patchedReasoning > 0 {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"patched_tool_messages": patched,
|
||||||
|
"patched_reasoning_messages": patchedReasoning,
|
||||||
|
}).Debug("kimi executor: normalized tool message fields")
|
||||||
|
}
|
||||||
|
if ambiguous > 0 {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"ambiguous_tool_messages": ambiguous,
|
||||||
|
"pending_tool_calls": len(pending),
|
||||||
|
}).Warn("kimi executor: tool messages missing tool_call_id with ambiguous candidates")
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
|
||||||
|
if hasLatest && strings.TrimSpace(latest) != "" {
|
||||||
|
return latest
|
||||||
|
}
|
||||||
|
|
||||||
|
content := msg.Get("content")
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
if text := strings.TrimSpace(content.String()); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content.IsArray() {
|
||||||
|
parts := make([]string, 0, len(content.Array()))
|
||||||
|
for _, item := range content.Array() {
|
||||||
|
text := strings.TrimSpace(item.Get("text").String())
|
||||||
|
if text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts = append(parts, text)
|
||||||
|
}
|
||||||
|
if len(parts) > 0 {
|
||||||
|
return strings.Join(parts, "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "[reasoning unavailable]"
|
||||||
|
}
|
||||||
|
|
||||||
// Refresh refreshes the Kimi token using the refresh token.
|
// Refresh refreshes the Kimi token using the refresh token.
|
||||||
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
log.Debugf("kimi executor: refresh called")
|
log.Debugf("kimi executor: refresh called")
|
||||||
|
|||||||
205
internal/runtime/executor/kimi_executor_test.go
Normal file
205
internal/runtime/executor/kimi_executor_test.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_UsesCallIDFallback(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"list_directory:1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","call_id":"list_directory:1","content":"[]"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||||
|
if got != "list_directory:1" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "list_directory:1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_InferSinglePendingID(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_123","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","content":"file-content"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||||
|
if got != "call_123" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_123")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_AmbiguousMissingIDIsNotInferred(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[
|
||||||
|
{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}},
|
||||||
|
{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}
|
||||||
|
]},
|
||||||
|
{"role":"tool","content":"result-without-id"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.GetBytes(out, "messages.1.tool_call_id").Exists() {
|
||||||
|
t.Fatalf("messages.1.tool_call_id should be absent for ambiguous case, got %q", gjson.GetBytes(out, "messages.1.tool_call_id").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_PreservesExistingToolCallID(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","tool_call_id":"call_1","call_id":"different-id","content":"result"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||||
|
if got != "call_1" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_InheritsPreviousReasoningForAssistantToolCalls(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","content":"plan","reasoning_content":"previous reasoning"},
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.1.reasoning_content").String()
|
||||||
|
if got != "previous reasoning" {
|
||||||
|
t.Fatalf("messages.1.reasoning_content = %q, want %q", got, "previous reasoning")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_InsertsFallbackReasoningWhenMissing(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reasoning := gjson.GetBytes(out, "messages.0.reasoning_content")
|
||||||
|
if !reasoning.Exists() {
|
||||||
|
t.Fatalf("messages.0.reasoning_content should exist")
|
||||||
|
}
|
||||||
|
if reasoning.String() != "[reasoning unavailable]" {
|
||||||
|
t.Fatalf("messages.0.reasoning_content = %q, want %q", reasoning.String(), "[reasoning unavailable]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_UsesContentAsReasoningFallback(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","content":[{"type":"text","text":"first line"},{"type":"text","text":"second line"}],"tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||||
|
if got != "first line\nsecond line" {
|
||||||
|
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "first line\nsecond line")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_ReplacesEmptyReasoningContent(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","content":"assistant summary","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":""}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||||
|
if got != "assistant summary" {
|
||||||
|
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "assistant summary")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_PreservesExistingAssistantReasoning(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"keep me"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||||
|
if got != "keep me" {
|
||||||
|
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "keep me")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"r1"},
|
||||||
|
{"role":"tool","call_id":"call_1","content":"[]"},
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","call_id":"call_2","content":"file"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "call_1" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.3.tool_call_id").String(); got != "call_2" {
|
||||||
|
t.Fatalf("messages.3.tool_call_id = %q, want %q", got, "call_2")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.2.reasoning_content").String(); got != "r1" {
|
||||||
|
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user