From 2789396435b046258e7605577745b6b2423d78fb Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 19 Feb 2026 13:19:10 +0800 Subject: [PATCH] fix: ensure connection-scoped headers are filtered in upstream requests - Added `connectionScopedHeaders` utility to respect "Connection" header directives. - Updated `FilterUpstreamHeaders` to remove connection-scoped headers dynamically. - Refactored and tested upstream header filtering with additional validations. - Adjusted upstream header handling during retries to replace headers safely. --- .../executor/codex_websockets_executor.go | 9 +-- sdk/api/handlers/handlers.go | 27 ++++++++- .../handlers_stream_bootstrap_test.go | 26 ++++++--- sdk/api/handlers/header_filter.go | 26 ++++++++- sdk/api/handlers/header_filter_test.go | 55 +++++++++++++++++++ .../openai/openai_responses_websocket.go | 2 +- .../auth/conductor_executor_replace_test.go | 10 +++- 7 files changed, 136 insertions(+), 19 deletions(-) create mode 100644 sdk/api/handlers/header_filter_test.go diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 38ffad77..7c887221 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -363,7 +363,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } } -func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { +func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model) if ctx == nil { ctx = context.Background() @@ -436,7 +436,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr }) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + var upstreamHeaders http.Header if respHS != nil { + upstreamHeaders = respHS.Header.Clone() recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone()) } if errDial != nil { @@ -516,7 +518,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr markCodexWebsocketCreateSent(sess, conn, wsReqBody) out := make(chan cliproxyexecutor.StreamChunk) - stream = out go func() { terminateReason := "completed" var terminateErr error @@ -627,7 +628,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr } }() - return stream, nil + return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil } func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { @@ -1343,7 +1344,7 @@ func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth return e.httpExec.Execute(ctx, auth, req, opts) } -func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { if e == nil || e.httpExec == nil || e.wsExec == nil { return nil, fmt.Errorf("codex auto executor: executor is nil") } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index c7e578cf..54bd09cd 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -593,7 +593,11 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return nil, nil, errChan } // Capture upstream headers from the initial connection synchronously before the goroutine starts. - upstreamHeaders := FilterUpstreamHeaders(streamResult.Headers) + // Keep a mutable map so bootstrap retries can replace it before first payload is sent. + upstreamHeaders := cloneHeader(FilterUpstreamHeaders(streamResult.Headers)) + if upstreamHeaders == nil { + upstreamHeaders = make(http.Header) + } chunks := streamResult.Chunks dataChan := make(chan []byte) errChan := make(chan *interfaces.ErrorMessage, 1) @@ -670,6 +674,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl bootstrapRetries++ retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if retryErr == nil { + replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers)) chunks = retryResult.Chunks continue outer } @@ -761,6 +766,26 @@ func cloneBytes(src []byte) []byte { return dst } +func cloneHeader(src http.Header) http.Header { + if src == nil { + return nil + } + dst := make(http.Header, len(src)) + for key, values := range src { + dst[key] = append([]string(nil), values...) + } + return dst +} + +func replaceHeader(dst http.Header, src http.Header) { + for key := range dst { + delete(dst, key) + } + for key, values := range src { + dst[key] = append([]string(nil), values...) + } +} + // WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message. func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) { status := http.StatusInternalServerError diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index 4642d2be..20274124 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, }, } close(ch) - return &coreexecutor.StreamResult{Chunks: ch}, nil + return &coreexecutor.StreamResult{ + Headers: http.Header{"X-Upstream-Attempt": {"1"}}, + Chunks: ch, + }, nil } ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} close(ch) - return &coreexecutor.StreamResult{Chunks: ch}, nil + return &coreexecutor.StreamResult{ + Headers: http.Header{"X-Upstream-Attempt": {"2"}}, + Chunks: ch, + }, nil } func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { @@ -134,7 +140,7 @@ func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coree return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} } -func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { +func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) { _ = ctx _ = req _ = opts @@ -160,12 +166,12 @@ func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *corea }, } close(ch) - return ch, nil + return &coreexecutor.StreamResult{Chunks: ch}, nil } ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} close(ch) - return ch, nil + return &coreexecutor.StreamResult{Chunks: ch}, nil } func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { @@ -235,7 +241,7 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { BootstrapRetries: 1, }, }, manager) - dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } @@ -257,6 +263,10 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { if executor.Calls() != 2 { t.Fatalf("expected 2 stream attempts, got %d", executor.Calls()) } + upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt") + if upstreamAttemptHeader != "2" { + t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader) + } } func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) { @@ -367,7 +377,7 @@ func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) }, }, manager) ctx := WithPinnedAuthID(context.Background(), "auth1") - dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } @@ -431,7 +441,7 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) { selectedAuthID = authID }) - dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } diff --git a/sdk/api/handlers/header_filter.go b/sdk/api/handlers/header_filter.go index e2fdf8a7..135223a7 100644 --- a/sdk/api/handlers/header_filter.go +++ b/sdk/api/handlers/header_filter.go @@ -1,6 +1,9 @@ package handlers -import "net/http" +import ( + "net/http" + "strings" +) // hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT // be forwarded by proxies, plus security-sensitive headers that should not leak. @@ -27,9 +30,14 @@ func FilterUpstreamHeaders(src http.Header) http.Header { if src == nil { return nil } + connectionScoped := connectionScopedHeaders(src) dst := make(http.Header) for key, values := range src { - if _, blocked := hopByHopHeaders[http.CanonicalHeaderKey(key)]; blocked { + canonicalKey := http.CanonicalHeaderKey(key) + if _, blocked := hopByHopHeaders[canonicalKey]; blocked { + continue + } + if _, scoped := connectionScoped[canonicalKey]; scoped { continue } dst[key] = values @@ -40,6 +48,20 @@ func FilterUpstreamHeaders(src http.Header) http.Header { return dst } +func connectionScopedHeaders(src http.Header) map[string]struct{} { + scoped := make(map[string]struct{}) + for _, rawValue := range src.Values("Connection") { + for _, token := range strings.Split(rawValue, ",") { + headerName := strings.TrimSpace(token) + if headerName == "" { + continue + } + scoped[http.CanonicalHeaderKey(headerName)] = struct{}{} + } + } + return scoped +} + // WriteUpstreamHeaders writes filtered upstream headers to the gin response writer. // Headers already set by CPA (e.g., Content-Type) are NOT overwritten. func WriteUpstreamHeaders(dst http.Header, src http.Header) { diff --git a/sdk/api/handlers/header_filter_test.go b/sdk/api/handlers/header_filter_test.go new file mode 100644 index 00000000..a87e65a1 --- /dev/null +++ b/sdk/api/handlers/header_filter_test.go @@ -0,0 +1,55 @@ +package handlers + +import ( + "net/http" + "testing" +) + +func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) { + src := http.Header{} + src.Add("Connection", "keep-alive, x-hop-a, x-hop-b") + src.Add("Connection", "x-hop-c") + src.Set("Keep-Alive", "timeout=5") + src.Set("X-Hop-A", "a") + src.Set("X-Hop-B", "b") + src.Set("X-Hop-C", "c") + src.Set("X-Request-Id", "req-1") + src.Set("Set-Cookie", "session=secret") + + filtered := FilterUpstreamHeaders(src) + if filtered == nil { + t.Fatalf("expected filtered headers, got nil") + } + + requestID := filtered.Get("X-Request-Id") + if requestID != "req-1" { + t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID) + } + + blockedHeaderKeys := []string{ + "Connection", + "Keep-Alive", + "X-Hop-A", + "X-Hop-B", + "X-Hop-C", + "Set-Cookie", + } + for _, key := range blockedHeaderKeys { + value := filtered.Get(key) + if value != "" { + t.Fatalf("expected %s to be removed, got %q", key, value) + } + } +} + +func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) { + src := http.Header{} + src.Add("Connection", "x-hop-a") + src.Set("X-Hop-A", "a") + src.Set("Set-Cookie", "session=secret") + + filtered := FilterUpstreamHeaders(src) + if filtered != nil { + t.Fatalf("expected nil when all headers are filtered, got %#v", filtered) + } +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index bcf09311..f2d44f05 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -153,7 +153,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { pinnedAuthID = strings.TrimSpace(authID) }) } - dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") + dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) if errForward != nil { diff --git a/sdk/cliproxy/auth/conductor_executor_replace_test.go b/sdk/cliproxy/auth/conductor_executor_replace_test.go index 3854f341..2ee91a87 100644 --- a/sdk/cliproxy/auth/conductor_executor_replace_test.go +++ b/sdk/cliproxy/auth/conductor_executor_replace_test.go @@ -24,10 +24,10 @@ func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor. return cliproxyexecutor.Response{}, nil } -func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { ch := make(chan cliproxyexecutor.StreamChunk) close(ch) - return ch, nil + return &cliproxyexecutor.StreamResult{Chunks: ch}, nil } func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { @@ -89,7 +89,11 @@ func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) { if !okResolved { t.Fatal("expected registered executor to be found") } - if resolved != current { + resolvedExecutor, okResolvedExecutor := resolved.(*replaceAwareExecutor) + if !okResolvedExecutor { + t.Fatalf("expected resolved executor type %T, got %T", current, resolved) + } + if resolvedExecutor != current { t.Fatal("expected resolved executor to match registered executor") }