mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
Merge pull request #874 from MohammadErfan-Jabbari/fix/streaming-finish-reason-tool-calls
fix(antigravity): preserve finish_reason tool_calls across streaming chunks
This commit is contained in:
@@ -22,8 +22,10 @@ import (
|
||||
|
||||
// convertCliResponseToOpenAIChatParams holds parameters for response conversion.
// A single value is carried across successive conversion calls so that state
// observed in an earlier chunk is still available when a later chunk arrives.
type convertCliResponseToOpenAIChatParams struct {
	UnixTimestamp int64
	// FunctionIndex is the index assigned to the next function call; it is
	// incremented each time a function call part is converted.
	FunctionIndex int
	// SawToolCall tracks if any tool call was seen in the entire stream.
	SawToolCall bool
	// UpstreamFinishReason caches the upstream finish reason (upper-cased) so it
	// can be applied on the final chunk.
	UpstreamFinishReason string
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -79,10 +81,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(finishReasonResult.String())
|
||||
}
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
@@ -112,7 +113,6 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
|
||||
// Process the main content part of the response.
|
||||
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
||||
hasFunctionCall := false
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
for i := 0; i < len(partResults); i++ {
|
||||
@@ -148,7 +148,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
hasFunctionCall = true
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
|
||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
||||
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
||||
@@ -195,9 +195,25 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
}
|
||||
}
|
||||
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
// Determine finish_reason only on the final chunk (has both finishReason and usage metadata)
|
||||
params := (*param).(*convertCliResponseToOpenAIChatParams)
|
||||
upstreamFinishReason := params.UpstreamFinishReason
|
||||
sawToolCall := params.SawToolCall
|
||||
|
||||
usageExists := gjson.GetBytes(rawJSON, "response.usageMetadata").Exists()
|
||||
isFinalChunk := upstreamFinishReason != "" && usageExists
|
||||
|
||||
if isFinalChunk {
|
||||
var finishReason string
|
||||
if sawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
} else if upstreamFinishReason == "MAX_TOKENS" {
|
||||
finishReason = "max_tokens"
|
||||
} else {
|
||||
finishReason = "stop"
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
// Chunk 1: Contains functionCall - should set SawToolCall = true
|
||||
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_files","args":{"path":"."}}}]}}]}}`)
|
||||
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
|
||||
// Verify chunk1 has no finish_reason (null)
|
||||
if len(result1) != 1 {
|
||||
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
|
||||
}
|
||||
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
|
||||
}
|
||||
|
||||
// Chunk 2: Contains finishReason STOP + usage (final chunk, no functionCall)
|
||||
// This simulates what the upstream sends AFTER the tool call chunk
|
||||
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":20,"totalTokenCount":30}}}`)
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify chunk2 has finish_reason: "tool_calls" (not "stop")
|
||||
if len(result2) != 1 {
|
||||
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
|
||||
}
|
||||
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
if fr2 != "tool_calls" {
|
||||
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
|
||||
}
|
||||
|
||||
// Verify native_finish_reason is lowercase upstream value
|
||||
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
|
||||
if nfr2 != "stop" {
|
||||
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishReasonStopForNormalText(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
// Chunk 1: Text content only
|
||||
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello world"}]}}]}}`)
|
||||
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
|
||||
// Chunk 2: Final chunk with STOP
|
||||
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}}`)
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "stop" (no tool calls were made)
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "stop" {
|
||||
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishReasonMaxTokens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
// Chunk 1: Text content
|
||||
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
|
||||
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
|
||||
// Chunk 2: Final chunk with MAX_TOKENS
|
||||
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"MAX_TOKENS"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":100,"totalTokenCount":110}}}`)
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "max_tokens"
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "max_tokens" {
|
||||
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
// Chunk 1: Contains functionCall
|
||||
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"test","args":{}}}]}}]}}`)
|
||||
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
|
||||
// Chunk 2: Final chunk with MAX_TOKENS (but we had a tool call, so tool_calls should win)
|
||||
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"MAX_TOKENS"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":100,"totalTokenCount":110}}}`)
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
|
||||
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||
if fr != "tool_calls" {
|
||||
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
// Chunk 1: Text content (no finish reason, no usage)
|
||||
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
|
||||
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||
|
||||
// Verify no finish_reason on intermediate chunk
|
||||
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
|
||||
}
|
||||
|
||||
// Chunk 2: More text (no finish reason, no usage)
|
||||
chunk2 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":" world"}]}}]}}`)
|
||||
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||
|
||||
// Verify no finish_reason on intermediate chunk
|
||||
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
|
||||
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
|
||||
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user