diff --git a/examples/custom-provider/main.go b/examples/custom-provider/main.go index 2f530d7c..7c611f9e 100644 --- a/examples/custom-provider/main.go +++ b/examples/custom-provider/main.go @@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, return clipexec.Response{}, errors.New("count tokens not implemented") } -func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) { +func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) { ch := make(chan clipexec.StreamChunk, 1) go func() { defer close(ch) ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")} }() - return ch, nil + return &clipexec.StreamResult{Chunks: ch}, nil } func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) { diff --git a/examples/http-request/main.go b/examples/http-request/main.go index 4daee547..a667a9ca 100644 --- a/examples/http-request/main.go +++ b/examples/http-request/main.go @@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c return clipexec.Response{}, errors.New("echo executor: Execute not implemented") } -func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) { +func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) { return nil, errors.New("echo executor: ExecuteStream not implemented") } diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 6e33472e..b1e23860 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) var param any out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m) - resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))} + resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()} return resp, nil } // ExecuteStream performs a streaming request to the AI Studio API. -func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth return nil, statusErr{code: firstEvent.Status, msg: body.String()} } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func(first wsrelay.StreamEvent) { defer close(out) var param any @@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth } } }(firstEvent) - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil } // CountTokens counts tokens for the given request using the AI Studio API. diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 24765740..9d395a9c 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -232,7 +232,7 @@ attemptLoop: reporter.publish(ctx, parseAntigravityUsage(bodyBytes)) var param any converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(converted)} + resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()} reporter.ensurePublished(ctx) return resp, nil } @@ -436,7 +436,7 @@ attemptLoop: reporter.publish(ctx, parseAntigravityUsage(resp.Payload)) var param any converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(converted)} + resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()} reporter.ensurePublished(ctx) return resp, nil @@ -645,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte { } // ExecuteStream performs a streaming request to the Antigravity API. -func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -775,7 +775,6 @@ attemptLoop: } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func(resp *http.Response) { defer close(out) defer func() { @@ -820,7 +819,7 @@ attemptLoop: reporter.ensurePublished(ctx) } }(httpResp) - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } switch { @@ -968,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { count := gjson.GetBytes(bodyBytes, "totalTokens").Int() translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes) - return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil } lastStatus = httpResp.StatusCode diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 89a366ee..e2c62c06 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -222,11 +222,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r data, ¶m, ) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } -func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -329,7 +329,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -398,7 +397,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { @@ -487,7 +486,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut appendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "input_tokens").Int() out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil } func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 728e7cb7..80a941fb 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"} @@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A reporter.ensurePublished(ctx) var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } -func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"} } @@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 3e218c0f..cb3ffb59 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth reporter.publish(ctx, parseGeminiCLIUsage(data)) var param any out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } @@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth } // ExecuteStream performs a streaming request to the Gemini CLI API. -func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func(resp *http.Response, reqBody []byte, attemptModel string) { defer close(out) defer func() { @@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } }(httpResp, append([]byte(nil), payload...), attemptModel) - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } if len(lastBody) > 0 { @@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. if resp.StatusCode >= 200 && resp.StatusCode < 300 { count := gjson.GetBytes(data, "totalTokens").Int() translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil } lastStatus = resp.StatusCode lastBody = append([]byte(nil), data...) diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 9e868df8..7c25b893 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r reporter.publish(ctx, parseGeminiUsage(data)) var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } // ExecuteStream performs a streaming request to the Gemini API. -func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } // CountTokens counts tokens for the given request using the Gemini API. @@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut count := gjson.GetBytes(data, "totalTokens").Int() translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil } // Refresh refreshes the authentication credentials (no-op for Gemini API key). diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 5eceac31..7ad1c618 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A } // ExecuteStream performs a streaming request to the Vertex AI API. -func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au to := sdktranslator.FromString("gemini") var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } @@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip reporter.publish(ctx, parseGeminiUsage(data)) var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } // executeStreamWithServiceAccount handles streaming authentication using service account credentials. -func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) { baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) @@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } // executeStreamWithAPIKey handles streaming authentication using API key credentials. -func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) { baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) @@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } // countTokensWithServiceAccount counts tokens using service account credentials. @@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context appendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "totalTokens").Int() out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil } // countTokensWithAPIKey handles token counting using API key credentials. @@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * appendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "totalTokens").Int() out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil } // vertexCreds extracts project, location and raw service account JSON from auth metadata. diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 30c37726..65a0b8f8 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re // Note: TranslateNonStream uses req.Model (original with suffix) to preserve // the original model name in the response for client compatibility. out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } // ExecuteStream performs a streaming chat completion request. -func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au reporter.ensurePublished(ctx) }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/internal/runtime/executor/kimi_executor.go b/internal/runtime/executor/kimi_executor.go index 3276bf17..d5e3702f 100644 --- a/internal/runtime/executor/kimi_executor.go +++ b/internal/runtime/executor/kimi_executor.go @@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // Note: TranslateNonStream uses req.Model (original with suffix) to preserve // the original model name in the response for client compatibility. out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } // ExecuteStream performs a streaming chat completion request to Kimi. -func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { from := opts.SourceFormat if from.String() == "claude" { auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL @@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } // CountTokens estimates token count for Kimi requests. diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index b5796e44..d28b3625 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A // Translate response back to source format when needed var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } -func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) @@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy // Ensure we record the request if no usage chunk was ever seen reporter.ensurePublished(ctx) }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 69e1f7fa..bcc4a057 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -150,11 +150,11 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // Note: TranslateNonStream uses req.Model (original with suffix) to preserve // the original model name in the response for client compatibility. out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } -func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -236,7 +236,6 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -268,7 +267,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 22e10fa5..074ffc0d 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -112,12 +112,13 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) { modelName := gjson.GetBytes(rawJSON, "model").String() - resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -165,7 +166,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO modelName := gjson.GetBytes(rawJSON, "model").String() - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) @@ -194,6 +195,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO } } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -225,7 +227,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ // This allows proper cleanup and cancellation of ongoing requests cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") @@ -257,6 +259,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ if !ok { // Stream closed without data? Send DONE or just headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) flusher.Flush() cliCancel(nil) return @@ -264,6 +267,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ // Success! Set headers now. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write the first chunk if len(chunk) > 0 { diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go index 07cedc55..b5fd4943 100644 --- a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ b/sdk/api/handlers/gemini/gemini-cli_handlers.go @@ -159,7 +159,8 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context modelName := modelResult.String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan) return } @@ -172,12 +173,13 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ modelName := modelResult.String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index a5eb337d..e51ad19b 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -188,7 +188,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName } cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") @@ -223,6 +223,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName if alt == "" { setSSEHeaders() } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) flusher.Flush() cliCancel(nil) return @@ -232,6 +233,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName if alt == "" { setSSEHeaders() } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write first chunk if alt == "" { @@ -262,12 +264,13 @@ func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, r c.Header("Content-Type", "application/json") alt := h.GetAlt(c) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -286,13 +289,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin alt := h.GetAlt(c) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 4ad2efb0..b0f2b2b1 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -370,10 +370,10 @@ func appendAPIResponse(c *gin.Context, data []byte) { // ExecuteWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. -func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { +func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) { providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { - return nil, errMsg + return nil, nil, errMsg } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel @@ -406,17 +406,17 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType addon = hdr.Clone() } } - return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} + return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } - return resp.Payload, nil + return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil } // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. -func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { +func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) { providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { - return nil, errMsg + return nil, nil, errMsg } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel @@ -449,20 +449,21 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle addon = hdr.Clone() } } - return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} + return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } - return resp.Payload, nil + return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil } // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. // This path is the only supported execution route. -func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) { +// The returned http.Header carries upstream response headers captured before streaming begins. +func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) { providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { errChan := make(chan *interfaces.ErrorMessage, 1) errChan <- errMsg close(errChan) - return nil, errChan + return nil, nil, errChan } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel @@ -481,7 +482,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl SourceFormat: sdktranslator.FromString(handlerType), } opts.Metadata = reqMeta - chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { errChan := make(chan *interfaces.ErrorMessage, 1) status := http.StatusInternalServerError @@ -498,8 +499,11 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl } errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} close(errChan) - return nil, errChan + return nil, nil, errChan } + // Capture upstream headers from the initial connection synchronously before the goroutine starts. + upstreamHeaders := FilterUpstreamHeaders(streamResult.Headers) + chunks := streamResult.Chunks dataChan := make(chan []byte) errChan := make(chan *interfaces.ErrorMessage, 1) go func() { @@ -573,9 +577,9 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl if !sentPayload { if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) { bootstrapRetries++ - retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if retryErr == nil { - chunks = retryChunks + chunks = retryResult.Chunks continue outer } streamErr = retryErr @@ -606,7 +610,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl } } }() - return dataChan, errChan + return dataChan, upstreamHeaders, errChan } func statusFromError(err error) int { diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index 7814ff1b..92da6b7c 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -23,7 +23,7 @@ func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreex return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} } -func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { +func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { e.mu.Lock() e.calls++ call := e.calls @@ -40,12 +40,12 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, }, } close(ch) - return ch, nil + return &coreexecutor.StreamResult{Chunks: ch}, nil } ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} close(ch) - return ch, nil + return &coreexecutor.StreamResult{Chunks: ch}, nil } func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { @@ -81,7 +81,7 @@ func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} } -func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { +func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { e.mu.Lock() e.calls++ e.mu.Unlock() @@ -97,7 +97,7 @@ func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreaut }, } close(ch) - return ch, nil + return &coreexecutor.StreamResult{Chunks: ch}, nil } func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { @@ -159,7 +159,7 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { BootstrapRetries: 1, }, }, manager) - dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } @@ -220,7 +220,7 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) { BootstrapRetries: 1, }, }, manager) - dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } diff --git a/sdk/api/handlers/header_filter.go b/sdk/api/handlers/header_filter.go new file mode 100644 index 00000000..e2fdf8a7 --- /dev/null +++ b/sdk/api/handlers/header_filter.go @@ -0,0 +1,58 @@ +package handlers + +import "net/http" + +// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT +// be forwarded by proxies, plus security-sensitive headers that should not leak. +var hopByHopHeaders = map[string]struct{}{ + // RFC 7230 hop-by-hop + "Connection": {}, + "Keep-Alive": {}, + "Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Te": {}, + "Trailer": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, + // Security-sensitive + "Set-Cookie": {}, + // CPA-managed (set by handlers, not upstream) + "Content-Length": {}, + "Content-Encoding": {}, +} + +// FilterUpstreamHeaders returns a copy of src with hop-by-hop and security-sensitive +// headers removed. Returns nil if src is nil or empty after filtering. +func FilterUpstreamHeaders(src http.Header) http.Header { + if src == nil { + return nil + } + dst := make(http.Header) + for key, values := range src { + if _, blocked := hopByHopHeaders[http.CanonicalHeaderKey(key)]; blocked { + continue + } + dst[key] = values + } + if len(dst) == 0 { + return nil + } + return dst +} + +// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer. +// Headers already set by CPA (e.g., Content-Type) are NOT overwritten. +func WriteUpstreamHeaders(dst http.Header, src http.Header) { + if src == nil { + return + } + for key, values := range src { + // Don't overwrite headers already set by CPA handlers + if dst.Get(key) != "" { + continue + } + for _, v := range values { + dst.Add(key, v) + } + } +} diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index 09471ce1..56bef990 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -425,12 +425,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON [] modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -457,7 +458,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") @@ -490,6 +491,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt if !ok { // Stream closed without data? Send DONE or just headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() cliCancel(nil) @@ -498,6 +500,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt // Success! Commit to streaming headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) flusher.Flush() @@ -525,13 +528,14 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, modelName := gjson.GetBytes(chatCompletionsJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) completionsResp := convertChatCompletionsResponseToCompletions(resp) _, _ = c.Writer.Write(completionsResp) cliCancel() @@ -562,7 +566,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra modelName := gjson.GetBytes(chatCompletionsJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") @@ -593,6 +597,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra case chunk, ok := <-dataChan: if !ok { setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() cliCancel(nil) @@ -601,6 +606,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra // Success! Set headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write the first chunk converted := convertChatCompletionsStreamChunkToCompletions(chunk) diff --git a/sdk/api/handlers/openai/openai_responses_compact_test.go b/sdk/api/handlers/openai/openai_responses_compact_test.go index a62a9682..dcfcc99a 100644 --- a/sdk/api/handlers/openai/openai_responses_compact_test.go +++ b/sdk/api/handlers/openai/openai_responses_compact_test.go @@ -31,7 +31,7 @@ func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Aut return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil } -func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { +func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { return nil, errors.New("not implemented") } diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index 4b611af3..1cd7e04f 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -124,13 +124,14 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) { modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact") + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact") stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -149,13 +150,14 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -183,7 +185,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ // New core execution path modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") @@ -216,6 +218,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ if !ok { // Stream closed without data? Send headers and done. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write([]byte("\n")) flusher.Flush() cliCancel(nil) @@ -224,6 +227,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ // Success! Set headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write first chunk logic (matching forwardResponsesStream) if bytes.HasPrefix(chunk, []byte("event:")) { diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 2c3e9f48..4d1cb732 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -30,8 +30,9 @@ type ProviderExecutor interface { Identifier() string // Execute handles non-streaming execution and returns the provider response payload. Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) - // ExecuteStream handles streaming execution and returns a channel of provider chunks. - ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) + // ExecuteStream handles streaming execution and returns a StreamResult containing + // upstream headers and a channel of provider chunks. + ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) // Refresh attempts to refresh provider credentials and returns the updated auth state. Refresh(ctx context.Context, auth *Auth) (*Auth, error) // CountTokens returns the token count for the given request. @@ -533,7 +534,7 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip // ExecuteStream performs a streaming execution using the configured selector and executor. // It supports multiple providers for the same model and round-robins the starting provider per model. -func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { normalized := m.normalizeProviders(providers) if len(normalized) == 0 { return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} @@ -543,9 +544,9 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli var lastErr error for attempt := 0; ; attempt++ { - chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts) + result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts) if errStream == nil { - return chunks, nil + return result, nil } lastErr = errStream wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait) @@ -672,7 +673,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, } } -func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { if len(providers) == 0 { return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} } @@ -702,7 +703,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execReq.Model = rewriteModelForAuth(routeModel, auth) execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) + streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx @@ -750,8 +751,11 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string if !failed { m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) } - }(execCtx, auth.Clone(), provider, chunks) - return out, nil + }(execCtx, auth.Clone(), provider, streamResult.Chunks) + return &cliproxyexecutor.StreamResult{ + Headers: streamResult.Headers, + Chunks: out, + }, nil } } diff --git a/sdk/cliproxy/executor/types.go b/sdk/cliproxy/executor/types.go index 8c11bbc4..04b81e83 100644 --- a/sdk/cliproxy/executor/types.go +++ b/sdk/cliproxy/executor/types.go @@ -46,6 +46,8 @@ type Response struct { Payload []byte // Metadata exposes optional structured data for translators. Metadata map[string]any + // Headers carries upstream HTTP response headers for passthrough to clients. + Headers http.Header } // StreamChunk represents a single streaming payload unit emitted by provider executors. @@ -56,6 +58,15 @@ type StreamChunk struct { Err error } +// StreamResult wraps the streaming response, providing both the chunk channel +// and the upstream HTTP response headers captured before streaming begins. +type StreamResult struct { + // Headers carries upstream HTTP response headers from the initial connection. + Headers http.Header + // Chunks is the channel of streaming payload units. + Chunks <-chan StreamChunk +} + // StatusError represents an error that carries an HTTP-like status code. // Provider executors should implement this when possible to enable // better auth state updates on failures (e.g., 401/402/429).