diff --git a/config.example.yaml b/config.example.yaml index 92619493..d44955df 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -68,6 +68,10 @@ proxy-url: "" # When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name). force-model-prefix: false +# When true, forward filtered upstream response headers to downstream clients. +# Default is false (disabled). +passthrough-headers: false + # Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504. request-retry: 3 diff --git a/examples/custom-provider/main.go b/examples/custom-provider/main.go index 2f530d7c..7c611f9e 100644 --- a/examples/custom-provider/main.go +++ b/examples/custom-provider/main.go @@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, return clipexec.Response{}, errors.New("count tokens not implemented") } -func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) { +func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) { ch := make(chan clipexec.StreamChunk, 1) go func() { defer close(ch) ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")} }() - return ch, nil + return &clipexec.StreamResult{Chunks: ch}, nil } func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) { diff --git a/examples/http-request/main.go b/examples/http-request/main.go index 4daee547..a667a9ca 100644 --- a/examples/http-request/main.go +++ b/examples/http-request/main.go @@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c return clipexec.Response{}, errors.New("echo executor: Execute not implemented") } -func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) { +func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) { return nil, errors.New("echo executor: ExecuteStream not implemented") } diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index 5c3990a6..9d99c924 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -20,6 +20,10 @@ type SDKConfig struct { // APIKeys is a list of keys for authenticating clients to this proxy server. APIKeys []string `yaml:"api-keys" json:"api-keys"` + // PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients. + // Default is false (disabled). + PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"` + // Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries). Streaming StreamingConfig `yaml:"streaming" json:"streaming"` diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 6e33472e..b1e23860 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) var param any out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m) - resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))} + resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()} return resp, nil } // ExecuteStream performs a streaming request to the AI Studio API. -func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth return nil, statusErr{code: firstEvent.Status, msg: body.String()} } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func(first wsrelay.StreamEvent) { defer close(out) var param any @@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth } } }(firstEvent) - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil } // CountTokens counts tokens for the given request using the AI Studio API. diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 24765740..9d395a9c 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -232,7 +232,7 @@ attemptLoop: reporter.publish(ctx, parseAntigravityUsage(bodyBytes)) var param any converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(converted)} + resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()} reporter.ensurePublished(ctx) return resp, nil } @@ -436,7 +436,7 @@ attemptLoop: reporter.publish(ctx, parseAntigravityUsage(resp.Payload)) var param any converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(converted)} + resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()} reporter.ensurePublished(ctx) return resp, nil @@ -645,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte { } // ExecuteStream performs a streaming request to the Antigravity API. -func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -775,7 +775,6 @@ attemptLoop: } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func(resp *http.Response) { defer close(out) defer func() { @@ -820,7 +819,7 @@ attemptLoop: reporter.ensurePublished(ctx) } }(httpResp) - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } switch { @@ -968,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { count := gjson.GetBytes(bodyBytes, "totalTokens").Int() translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes) - return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil } lastStatus = httpResp.StatusCode diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 43cb69ec..04a1242a 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -223,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r data, ¶m, ) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } -func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -330,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -399,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { @@ -488,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut appendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "input_tokens").Int() out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil } func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 6cfc0c24..01de8f97 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"} @@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A reporter.ensurePublished(ctx) var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } -func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"} } @@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 38ffad77..7c887221 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -363,7 +363,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } } -func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model) if ctx == nil { ctx = context.Background() @@ -436,7 +436,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr }) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + var upstreamHeaders http.Header if respHS != nil { + upstreamHeaders = respHS.Header.Clone() recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone()) } if errDial != nil { @@ -516,7 +518,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr markCodexWebsocketCreateSent(sess, conn, wsReqBody) out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { terminateReason := "completed" var terminateErr error @@ -627,7 +628,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil } func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { @@ -1343,7 +1344,7 @@ func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth return e.httpExec.Execute(ctx, auth, req, opts) } -func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { if e == nil || e.httpExec == nil || e.wsExec == nil { return nil, fmt.Errorf("codex auto executor: executor is nil") } diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 3e218c0f..cb3ffb59 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth reporter.publish(ctx, parseGeminiCLIUsage(data)) var param any out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } @@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth } // ExecuteStream performs a streaming request to the Gemini CLI API. -func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func(resp *http.Response, reqBody []byte, attemptModel string) { defer close(out) defer func() { @@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } }(httpResp, append([]byte(nil), payload...), attemptModel) - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } if len(lastBody) > 0 { @@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. if resp.StatusCode >= 200 && resp.StatusCode < 300 { count := gjson.GetBytes(data, "totalTokens").Int() translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil } lastStatus = resp.StatusCode lastBody = append([]byte(nil), data...) diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 9e868df8..7c25b893 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r reporter.publish(ctx, parseGeminiUsage(data)) var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } // ExecuteStream performs a streaming request to the Gemini API. -func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } // CountTokens counts tokens for the given request using the Gemini API. @@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut count := gjson.GetBytes(data, "totalTokens").Int() translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil } // Refresh refreshes the authentication credentials (no-op for Gemini API key). diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 5eceac31..7ad1c618 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A } // ExecuteStream performs a streaming request to the Vertex AI API. -func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au to := sdktranslator.FromString("gemini") var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } @@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip reporter.publish(ctx, parseGeminiUsage(data)) var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } // executeStreamWithServiceAccount handles streaming authentication using service account credentials. -func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) { baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) @@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } // executeStreamWithAPIKey handles streaming authentication using API key credentials. -func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) { baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) @@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } // countTokensWithServiceAccount counts tokens using service account credentials. @@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context appendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "totalTokens").Int() out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil } // countTokensWithAPIKey handles token counting using API key credentials. @@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * appendAPIResponseChunk(ctx, e.cfg, data) count := gjson.GetBytes(data, "totalTokens").Int() out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil } // vertexCreds extracts project, location and raw service account JSON from auth metadata. diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 30c37726..65a0b8f8 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re // Note: TranslateNonStream uses req.Model (original with suffix) to preserve // the original model name in the response for client compatibility. out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } // ExecuteStream performs a streaming chat completion request. -func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au reporter.ensurePublished(ctx) }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/internal/runtime/executor/kimi_executor.go b/internal/runtime/executor/kimi_executor.go index 3276bf17..d5e3702f 100644 --- a/internal/runtime/executor/kimi_executor.go +++ b/internal/runtime/executor/kimi_executor.go @@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // Note: TranslateNonStream uses req.Model (original with suffix) to preserve // the original model name in the response for client compatibility. out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } // ExecuteStream performs a streaming chat completion request to Kimi. -func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { from := opts.SourceFormat if from.String() == "claude" { auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL @@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } // CountTokens estimates token count for Kimi requests. diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index b5796e44..d28b3625 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A // Translate response back to source format when needed var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } -func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) @@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy // Ensure we record the request if no usage chunk was ever seen reporter.ensurePublished(ctx) }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 69e1f7fa..bcc4a057 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -150,11 +150,11 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // Note: TranslateNonStream uses req.Model (original with suffix) to preserve // the original model name in the response for client compatibility. out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) - resp = cliproxyexecutor.Response{Payload: []byte(out)} + resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} return resp, nil } -func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } @@ -236,7 +236,6 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return nil, err } out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { defer close(out) defer func() { @@ -268,7 +267,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut out <- cliproxyexecutor.StreamChunk{Err: errScan} } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 22e10fa5..074ffc0d 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -112,12 +112,13 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) { modelName := gjson.GetBytes(rawJSON, "model").String() - resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -165,7 +166,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO modelName := gjson.GetBytes(rawJSON, "model").String() - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) @@ -194,6 +195,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO } } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -225,7 +227,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ // This allows proper cleanup and cancellation of ongoing requests cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") @@ -257,6 +259,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ if !ok { // Stream closed without data? Send DONE or just headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) flusher.Flush() cliCancel(nil) return @@ -264,6 +267,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ // Success! Set headers now. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write the first chunk if len(chunk) > 0 { diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go index 07cedc55..b5fd4943 100644 --- a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ b/sdk/api/handlers/gemini/gemini-cli_handlers.go @@ -159,7 +159,8 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context modelName := modelResult.String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan) return } @@ -172,12 +173,13 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ modelName := modelResult.String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index a5eb337d..e51ad19b 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -188,7 +188,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName } cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") @@ -223,6 +223,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName if alt == "" { setSSEHeaders() } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) flusher.Flush() cliCancel(nil) return @@ -232,6 +233,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName if alt == "" { setSSEHeaders() } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write first chunk if alt == "" { @@ -262,12 +264,13 @@ func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, r c.Header("Content-Type", "application/json") alt := h.GetAlt(c) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -286,13 +289,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin alt := h.GetAlt(c) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 23ef6535..d3359353 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -179,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int { return retries } +// PassthroughHeadersEnabled returns whether upstream response headers should be forwarded to clients. +// Default is false. +func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool { + return cfg != nil && cfg.PassthroughHeaders +} + func requestExecutionMetadata(ctx context.Context) map[string]any { // Idempotency-Key is an optional client-supplied header used to correlate retries. // It is forwarded as execution metadata; when absent we generate a UUID. @@ -461,10 +467,10 @@ func appendAPIResponse(c *gin.Context, data []byte) { // ExecuteWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. -func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { +func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) { providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { - return nil, errMsg + return nil, nil, errMsg } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel @@ -497,17 +503,20 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType addon = hdr.Clone() } } - return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} + return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } - return resp.Payload, nil + if !PassthroughHeadersEnabled(h.Cfg) { + return resp.Payload, nil, nil + } + return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil } // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. -func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { +func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) { providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { - return nil, errMsg + return nil, nil, errMsg } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel @@ -540,20 +549,24 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle addon = hdr.Clone() } } - return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} + return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } - return resp.Payload, nil + if !PassthroughHeadersEnabled(h.Cfg) { + return resp.Payload, nil, nil + } + return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil } // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. // This path is the only supported execution route. -func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) { +// The returned http.Header carries upstream response headers captured before streaming begins. +func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) { providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { errChan := make(chan *interfaces.ErrorMessage, 1) errChan <- errMsg close(errChan) - return nil, errChan + return nil, nil, errChan } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel @@ -572,7 +585,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl SourceFormat: sdktranslator.FromString(handlerType), } opts.Metadata = reqMeta - chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { errChan := make(chan *interfaces.ErrorMessage, 1) status := http.StatusInternalServerError @@ -589,8 +602,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl } errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} close(errChan) - return nil, errChan + return nil, nil, errChan } + passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg) + // Capture upstream headers from the initial connection synchronously before the goroutine starts. + // Keep a mutable map so bootstrap retries can replace it before first payload is sent. + var upstreamHeaders http.Header + if passthroughHeadersEnabled { + upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers)) + if upstreamHeaders == nil { + upstreamHeaders = make(http.Header) + } + } + chunks := streamResult.Chunks dataChan := make(chan []byte) errChan := make(chan *interfaces.ErrorMessage, 1) go func() { @@ -664,9 +688,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl if !sentPayload { if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) { bootstrapRetries++ - retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if retryErr == nil { - chunks = retryChunks + if passthroughHeadersEnabled { + replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers)) + } + chunks = retryResult.Chunks continue outer } streamErr = retryErr @@ -697,7 +724,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl } } }() - return dataChan, errChan + return dataChan, upstreamHeaders, errChan } func statusFromError(err error) int { @@ -757,6 +784,26 @@ func cloneBytes(src []byte) []byte { return dst } +func cloneHeader(src http.Header) http.Header { + if src == nil { + return nil + } + dst := make(http.Header, len(src)) + for key, values := range src { + dst[key] = append([]string(nil), values...) + } + return dst +} + +func replaceHeader(dst http.Header, src http.Header) { + for key := range dst { + delete(dst, key) + } + for key, values := range src { + dst[key] = append([]string(nil), values...) + } +} + // WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message. func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) { status := http.StatusInternalServerError diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index 66a49e52..ba9dcac5 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -23,7 +23,7 @@ func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreex return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} } -func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { +func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { e.mu.Lock() e.calls++ call := e.calls @@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, }, } close(ch) - return ch, nil + return &coreexecutor.StreamResult{ + Headers: http.Header{"X-Upstream-Attempt": {"1"}}, + Chunks: ch, + }, nil } ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} close(ch) - return ch, nil + return &coreexecutor.StreamResult{ + Headers: http.Header{"X-Upstream-Attempt": {"2"}}, + Chunks: ch, + }, nil } func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { @@ -81,7 +87,7 @@ func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} } -func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { +func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { e.mu.Lock() e.calls++ e.mu.Unlock() @@ -97,7 +103,7 @@ func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreaut }, } close(ch) - return ch, nil + return &coreexecutor.StreamResult{Chunks: ch}, nil } func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { @@ -134,7 +140,7 @@ func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coree return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} } -func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { +func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) { _ = ctx _ = req _ = opts @@ -160,12 +166,12 @@ func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *corea }, } close(ch) - return ch, nil + return &coreexecutor.StreamResult{Chunks: ch}, nil } ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} close(ch) - return ch, nil + return &coreexecutor.StreamResult{Chunks: ch}, nil } func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { @@ -231,11 +237,12 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { }) handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ + PassthroughHeaders: true, Streaming: sdkconfig.StreamingConfig{ BootstrapRetries: 1, }, }, manager) - dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } @@ -257,6 +264,70 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { if executor.Calls() != 2 { t.Fatalf("expected 2 stream attempts, got %d", executor.Calls()) } + upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt") + if upstreamAttemptHeader != "2" { + t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader) + } +} + +func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) { + executor := &failOnceStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + auth2 := &coreauth.Auth{ + ID: "auth2", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test2@example.com"}, + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("manager.Register(auth2): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + registry.GetGlobalRegistry().UnregisterClient(auth2.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ + Streaming: sdkconfig.StreamingConfig{ + BootstrapRetries: 1, + }, + }, manager) + dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + for msg := range errChan { + if msg != nil { + t.Fatalf("unexpected error: %+v", msg) + } + } + + if string(got) != "ok" { + t.Fatalf("expected payload ok, got %q", string(got)) + } + if upstreamHeaders != nil { + t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders) + } } func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) { @@ -296,7 +367,7 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) { BootstrapRetries: 1, }, }, manager) - dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } @@ -367,7 +438,7 @@ func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) }, }, manager) ctx := WithPinnedAuthID(context.Background(), "auth1") - dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } @@ -431,7 +502,7 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) { selectedAuthID = authID }) - dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } diff --git a/sdk/api/handlers/header_filter.go b/sdk/api/handlers/header_filter.go new file mode 100644 index 00000000..135223a7 --- /dev/null +++ b/sdk/api/handlers/header_filter.go @@ -0,0 +1,80 @@ +package handlers + +import ( + "net/http" + "strings" +) + +// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT +// be forwarded by proxies, plus security-sensitive headers that should not leak. +var hopByHopHeaders = map[string]struct{}{ + // RFC 7230 hop-by-hop + "Connection": {}, + "Keep-Alive": {}, + "Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Te": {}, + "Trailer": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, + // Security-sensitive + "Set-Cookie": {}, + // CPA-managed (set by handlers, not upstream) + "Content-Length": {}, + "Content-Encoding": {}, +} + +// FilterUpstreamHeaders returns a copy of src with hop-by-hop and security-sensitive +// headers removed. Returns nil if src is nil or empty after filtering. +func FilterUpstreamHeaders(src http.Header) http.Header { + if src == nil { + return nil + } + connectionScoped := connectionScopedHeaders(src) + dst := make(http.Header) + for key, values := range src { + canonicalKey := http.CanonicalHeaderKey(key) + if _, blocked := hopByHopHeaders[canonicalKey]; blocked { + continue + } + if _, scoped := connectionScoped[canonicalKey]; scoped { + continue + } + dst[key] = values + } + if len(dst) == 0 { + return nil + } + return dst +} + +func connectionScopedHeaders(src http.Header) map[string]struct{} { + scoped := make(map[string]struct{}) + for _, rawValue := range src.Values("Connection") { + for _, token := range strings.Split(rawValue, ",") { + headerName := strings.TrimSpace(token) + if headerName == "" { + continue + } + scoped[http.CanonicalHeaderKey(headerName)] = struct{}{} + } + } + return scoped +} + +// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer. +// Headers already set by CPA (e.g., Content-Type) are NOT overwritten. +func WriteUpstreamHeaders(dst http.Header, src http.Header) { + if src == nil { + return + } + for key, values := range src { + // Don't overwrite headers already set by CPA handlers + if dst.Get(key) != "" { + continue + } + for _, v := range values { + dst.Add(key, v) + } + } +} diff --git a/sdk/api/handlers/header_filter_test.go b/sdk/api/handlers/header_filter_test.go new file mode 100644 index 00000000..a87e65a1 --- /dev/null +++ b/sdk/api/handlers/header_filter_test.go @@ -0,0 +1,55 @@ +package handlers + +import ( + "net/http" + "testing" +) + +func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) { + src := http.Header{} + src.Add("Connection", "keep-alive, x-hop-a, x-hop-b") + src.Add("Connection", "x-hop-c") + src.Set("Keep-Alive", "timeout=5") + src.Set("X-Hop-A", "a") + src.Set("X-Hop-B", "b") + src.Set("X-Hop-C", "c") + src.Set("X-Request-Id", "req-1") + src.Set("Set-Cookie", "session=secret") + + filtered := FilterUpstreamHeaders(src) + if filtered == nil { + t.Fatalf("expected filtered headers, got nil") + } + + requestID := filtered.Get("X-Request-Id") + if requestID != "req-1" { + t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID) + } + + blockedHeaderKeys := []string{ + "Connection", + "Keep-Alive", + "X-Hop-A", + "X-Hop-B", + "X-Hop-C", + "Set-Cookie", + } + for _, key := range blockedHeaderKeys { + value := filtered.Get(key) + if value != "" { + t.Fatalf("expected %s to be removed, got %q", key, value) + } + } +} + +func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) { + src := http.Header{} + src.Add("Connection", "x-hop-a") + src.Set("X-Hop-A", "a") + src.Set("Set-Cookie", "session=secret") + + filtered := FilterUpstreamHeaders(src) + if filtered != nil { + t.Fatalf("expected nil when all headers are filtered, got %#v", filtered) + } +} diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index 9c161a1c..5991341e 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -431,12 +431,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON [] modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -463,7 +464,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") @@ -496,6 +497,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt if !ok { // Stream closed without data? Send DONE or just headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() cliCancel(nil) @@ -504,6 +506,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt // Success! Commit to streaming headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) flusher.Flush() @@ -531,13 +534,14 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, modelName := gjson.GetBytes(chatCompletionsJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) completionsResp := convertChatCompletionsResponseToCompletions(resp) _, _ = c.Writer.Write(completionsResp) cliCancel() @@ -568,7 +572,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra modelName := gjson.GetBytes(chatCompletionsJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") @@ -599,6 +603,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra case chunk, ok := <-dataChan: if !ok { setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() cliCancel(nil) @@ -607,6 +612,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra // Success! Set headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write the first chunk converted := convertChatCompletionsStreamChunkToCompletions(chunk) diff --git a/sdk/api/handlers/openai/openai_responses_compact_test.go b/sdk/api/handlers/openai/openai_responses_compact_test.go index a62a9682..dcfcc99a 100644 --- a/sdk/api/handlers/openai/openai_responses_compact_test.go +++ b/sdk/api/handlers/openai/openai_responses_compact_test.go @@ -31,7 +31,7 @@ func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Aut return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil } -func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { +func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { return nil, errors.New("not implemented") } diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index 4b611af3..1cd7e04f 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -124,13 +124,14 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) { modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact") + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact") stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -149,13 +150,14 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) - resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") stopKeepAlive() if errMsg != nil { h.WriteErrorResponse(c, errMsg) cliCancel(errMsg.Error) return } + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() } @@ -183,7 +185,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ // New core execution path modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") + dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") setSSEHeaders := func() { c.Header("Content-Type", "text/event-stream") @@ -216,6 +218,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ if !ok { // Stream closed without data? Send headers and done. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write([]byte("\n")) flusher.Flush() cliCancel(nil) @@ -224,6 +227,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ // Success! Set headers. setSSEHeaders() + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write first chunk logic (matching forwardResponsesStream) if bytes.HasPrefix(chunk, []byte("event:")) { diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index bcf09311..f2d44f05 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -153,7 +153,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { pinnedAuthID = strings.TrimSpace(authID) }) } - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") + dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) if errForward != nil { diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 76aae228..cd447e68 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -30,8 +30,9 @@ type ProviderExecutor interface { Identifier() string // Execute handles non-streaming execution and returns the provider response payload. Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) - // ExecuteStream handles streaming execution and returns a channel of provider chunks. - ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) + // ExecuteStream handles streaming execution and returns a StreamResult containing + // upstream headers and a channel of provider chunks. + ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) // Refresh attempts to refresh provider credentials and returns the updated auth state. Refresh(ctx context.Context, auth *Auth) (*Auth, error) // CountTokens returns the token count for the given request. @@ -558,7 +559,7 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip // ExecuteStream performs a streaming execution using the configured selector and executor. // It supports multiple providers for the same model and round-robins the starting provider per model. -func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { normalized := m.normalizeProviders(providers) if len(normalized) == 0 { return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} @@ -568,9 +569,9 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli var lastErr error for attempt := 0; ; attempt++ { - chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts) + result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts) if errStream == nil { - return chunks, nil + return result, nil } lastErr = errStream wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait) @@ -699,7 +700,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, } } -func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { if len(providers) == 0 { return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} } @@ -730,7 +731,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execReq.Model = rewriteModelForAuth(routeModel, auth) execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) + streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx @@ -778,8 +779,11 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string if !failed { m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) } - }(execCtx, auth.Clone(), provider, chunks) - return out, nil + }(execCtx, auth.Clone(), provider, streamResult.Chunks) + return &cliproxyexecutor.StreamResult{ + Headers: streamResult.Headers, + Chunks: out, + }, nil } } diff --git a/sdk/cliproxy/auth/conductor_executor_replace_test.go b/sdk/cliproxy/auth/conductor_executor_replace_test.go index 3854f341..2ee91a87 100644 --- a/sdk/cliproxy/auth/conductor_executor_replace_test.go +++ b/sdk/cliproxy/auth/conductor_executor_replace_test.go @@ -24,10 +24,10 @@ func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor. return cliproxyexecutor.Response{}, nil } -func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { ch := make(chan cliproxyexecutor.StreamChunk) close(ch) - return ch, nil + return &cliproxyexecutor.StreamResult{Chunks: ch}, nil } func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { @@ -89,7 +89,11 @@ func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) { if !okResolved { t.Fatal("expected registered executor to be found") } - if resolved != current { + resolvedExecutor, okResolvedExecutor := resolved.(*replaceAwareExecutor) + if !okResolvedExecutor { + t.Fatalf("expected resolved executor type %T, got %T", current, resolved) + } + if resolvedExecutor != current { t.Fatal("expected resolved executor to match registered executor") } diff --git a/sdk/cliproxy/executor/types.go b/sdk/cliproxy/executor/types.go index 4e917eb7..4ea81039 100644 --- a/sdk/cliproxy/executor/types.go +++ b/sdk/cliproxy/executor/types.go @@ -57,6 +57,8 @@ type Response struct { Payload []byte // Metadata exposes optional structured data for translators. Metadata map[string]any + // Headers carries upstream HTTP response headers for passthrough to clients. + Headers http.Header } // StreamChunk represents a single streaming payload unit emitted by provider executors. @@ -67,6 +69,15 @@ type StreamChunk struct { Err error } +// StreamResult wraps the streaming response, providing both the chunk channel +// and the upstream HTTP response headers captured before streaming begins. +type StreamResult struct { + // Headers carries upstream HTTP response headers from the initial connection. + Headers http.Header + // Chunks is the channel of streaming payload units. + Chunks <-chan StreamChunk +} + // StatusError represents an error that carries an HTTP-like status code. // Provider executors should implement this when possible to enable // better auth state updates on failures (e.g., 401/402/429).