Mirror of https://github.com/router-for-me/CLIProxyAPI.git (synced 2026-02-03 04:50:52 +08:00).
Merge pull request #874 from MohammadErfan-Jabbari/fix/streaming-finish-reason-tool-calls
fix(antigravity): preserve finish_reason tool_calls across streaming chunks
This commit is contained in:
@@ -24,6 +24,8 @@ import (
|
|||||||
// convertCliResponseToOpenAIChatParams carries per-stream conversion state that
// must persist across streaming chunks when translating CLI (Antigravity)
// responses into OpenAI chat-completion chunks.
type convertCliResponseToOpenAIChatParams struct {
	UnixTimestamp        int64  // Creation timestamp stamped onto emitted chunks
	FunctionIndex        int    // Next tool_calls index to assign within this stream
	SawToolCall          bool   // Tracks if any tool call was seen in the entire stream
	UpstreamFinishReason string // Caches the upstream finish reason for final chunk
}
|
||||||
|
|
||||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
@@ -79,10 +81,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set the finish reason.
|
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
(*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(finishReasonResult.String())
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
@@ -112,7 +113,6 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
// Process the main content part of the response.
|
// Process the main content part of the response.
|
||||||
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
||||||
hasFunctionCall := false
|
|
||||||
if partsResult.IsArray() {
|
if partsResult.IsArray() {
|
||||||
partResults := partsResult.Array()
|
partResults := partsResult.Array()
|
||||||
for i := 0; i < len(partResults); i++ {
|
for i := 0; i < len(partResults); i++ {
|
||||||
@@ -148,7 +148,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
} else if functionCallResult.Exists() {
|
} else if functionCallResult.Exists() {
|
||||||
// Handle function call content.
|
// Handle function call content.
|
||||||
hasFunctionCall = true
|
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
|
||||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||||
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
||||||
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
||||||
@@ -195,9 +195,25 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasFunctionCall {
|
// Determine finish_reason only on the final chunk (has both finishReason and usage metadata)
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
params := (*param).(*convertCliResponseToOpenAIChatParams)
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
upstreamFinishReason := params.UpstreamFinishReason
|
||||||
|
sawToolCall := params.SawToolCall
|
||||||
|
|
||||||
|
usageExists := gjson.GetBytes(rawJSON, "response.usageMetadata").Exists()
|
||||||
|
isFinalChunk := upstreamFinishReason != "" && usageExists
|
||||||
|
|
||||||
|
if isFinalChunk {
|
||||||
|
var finishReason string
|
||||||
|
if sawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
} else if upstreamFinishReason == "MAX_TOKENS" {
|
||||||
|
finishReason = "max_tokens"
|
||||||
|
} else {
|
||||||
|
finishReason = "stop"
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||||
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{template}
|
return []string{template}
|
||||||
|
|||||||
@@ -0,0 +1,128 @@
|
|||||||
|
package chat_completions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Contains functionCall - should set SawToolCall = true
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_files","args":{"path":"."}}}]}}]}}`)
|
||||||
|
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Verify chunk1 has no finish_reason (null)
|
||||||
|
if len(result1) != 1 {
|
||||||
|
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
|
||||||
|
}
|
||||||
|
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||||
|
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||||
|
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chunk 2: Contains finishReason STOP + usage (final chunk, no functionCall)
|
||||||
|
// This simulates what the upstream sends AFTER the tool call chunk
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":20,"totalTokenCount":30}}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify chunk2 has finish_reason: "tool_calls" (not "stop")
|
||||||
|
if len(result2) != 1 {
|
||||||
|
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
|
||||||
|
}
|
||||||
|
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||||
|
if fr2 != "tool_calls" {
|
||||||
|
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify native_finish_reason is lowercase upstream value
|
||||||
|
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
|
||||||
|
if nfr2 != "stop" {
|
||||||
|
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinishReasonStopForNormalText(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Text content only
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello world"}]}}]}}`)
|
||||||
|
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Chunk 2: Final chunk with STOP
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify finish_reason is "stop" (no tool calls were made)
|
||||||
|
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||||
|
if fr != "stop" {
|
||||||
|
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinishReasonMaxTokens(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Text content
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
|
||||||
|
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Chunk 2: Final chunk with MAX_TOKENS
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"MAX_TOKENS"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":100,"totalTokenCount":110}}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify finish_reason is "max_tokens"
|
||||||
|
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||||
|
if fr != "max_tokens" {
|
||||||
|
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Contains functionCall
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"test","args":{}}}]}}]}}`)
|
||||||
|
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Chunk 2: Final chunk with MAX_TOKENS (but we had a tool call, so tool_calls should win)
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"MAX_TOKENS"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":100,"totalTokenCount":110}}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
|
||||||
|
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||||
|
if fr != "tool_calls" {
|
||||||
|
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Text content (no finish reason, no usage)
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
|
||||||
|
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Verify no finish_reason on intermediate chunk
|
||||||
|
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||||
|
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||||
|
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chunk 2: More text (no finish reason, no usage)
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":" world"}]}}]}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify no finish_reason on intermediate chunk
|
||||||
|
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
|
||||||
|
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
|
||||||
|
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user