fix: ensure connection-scoped headers are filtered in upstream requests

- Added `connectionScopedHeaders` utility to respect "Connection" header directives.
- Updated `FilterUpstreamHeaders` to remove connection-scoped headers dynamically.
- Refactored and tested upstream header filtering with additional validations.
- Adjusted upstream header handling during retries to replace headers safely.
This commit is contained in:
Luis Pater
2026-02-19 13:19:10 +08:00
parent 61da7bd981
commit 2789396435
7 changed files with 136 additions and 19 deletions

View File

@@ -363,7 +363,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
} }
} }
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model) log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
@@ -436,7 +436,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}) })
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
var upstreamHeaders http.Header
if respHS != nil { if respHS != nil {
upstreamHeaders = respHS.Header.Clone()
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone()) recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
} }
if errDial != nil { if errDial != nil {
@@ -516,7 +518,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
markCodexWebsocketCreateSent(sess, conn, wsReqBody) markCodexWebsocketCreateSent(sess, conn, wsReqBody)
out := make(chan cliproxyexecutor.StreamChunk) out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() { go func() {
terminateReason := "completed" terminateReason := "completed"
var terminateErr error var terminateErr error
@@ -627,7 +628,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
} }
}() }()
return stream, nil return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil
} }
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
@@ -1343,7 +1344,7 @@ func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return e.httpExec.Execute(ctx, auth, req, opts) return e.httpExec.Execute(ctx, auth, req, opts)
} }
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
if e == nil || e.httpExec == nil || e.wsExec == nil { if e == nil || e.httpExec == nil || e.wsExec == nil {
return nil, fmt.Errorf("codex auto executor: executor is nil") return nil, fmt.Errorf("codex auto executor: executor is nil")
} }

View File

@@ -593,7 +593,11 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return nil, nil, errChan return nil, nil, errChan
} }
// Capture upstream headers from the initial connection synchronously before the goroutine starts. // Capture upstream headers from the initial connection synchronously before the goroutine starts.
upstreamHeaders := FilterUpstreamHeaders(streamResult.Headers) // Keep a mutable map so bootstrap retries can replace it before first payload is sent.
upstreamHeaders := cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
if upstreamHeaders == nil {
upstreamHeaders = make(http.Header)
}
chunks := streamResult.Chunks chunks := streamResult.Chunks
dataChan := make(chan []byte) dataChan := make(chan []byte)
errChan := make(chan *interfaces.ErrorMessage, 1) errChan := make(chan *interfaces.ErrorMessage, 1)
@@ -670,6 +674,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
bootstrapRetries++ bootstrapRetries++
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if retryErr == nil { if retryErr == nil {
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
chunks = retryResult.Chunks chunks = retryResult.Chunks
continue outer continue outer
} }
@@ -761,6 +766,26 @@ func cloneBytes(src []byte) []byte {
return dst return dst
} }
func cloneHeader(src http.Header) http.Header {
if src == nil {
return nil
}
dst := make(http.Header, len(src))
for key, values := range src {
dst[key] = append([]string(nil), values...)
}
return dst
}
func replaceHeader(dst http.Header, src http.Header) {
for key := range dst {
delete(dst, key)
}
for key, values := range src {
dst[key] = append([]string(nil), values...)
}
}
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message. // WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) { func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
status := http.StatusInternalServerError status := http.StatusInternalServerError

View File

@@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth,
}, },
} }
close(ch) close(ch)
return &coreexecutor.StreamResult{Chunks: ch}, nil return &coreexecutor.StreamResult{
Headers: http.Header{"X-Upstream-Attempt": {"1"}},
Chunks: ch,
}, nil
} }
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch) close(ch)
return &coreexecutor.StreamResult{Chunks: ch}, nil return &coreexecutor.StreamResult{
Headers: http.Header{"X-Upstream-Attempt": {"2"}},
Chunks: ch,
}, nil
} }
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -134,7 +140,7 @@ func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coree
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
} }
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) {
_ = ctx _ = ctx
_ = req _ = req
_ = opts _ = opts
@@ -160,12 +166,12 @@ func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *corea
}, },
} }
close(ch) close(ch)
return ch, nil return &coreexecutor.StreamResult{Chunks: ch}, nil
} }
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch) close(ch)
return ch, nil return &coreexecutor.StreamResult{Chunks: ch}, nil
} }
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -235,7 +241,7 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
BootstrapRetries: 1, BootstrapRetries: 1,
}, },
}, manager) }, manager)
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil { if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels") t.Fatalf("expected non-nil channels")
} }
@@ -257,6 +263,10 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
if executor.Calls() != 2 { if executor.Calls() != 2 {
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls()) t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
} }
upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
if upstreamAttemptHeader != "2" {
t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
}
} }
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) { func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
@@ -367,7 +377,7 @@ func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T)
}, },
}, manager) }, manager)
ctx := WithPinnedAuthID(context.Background(), "auth1") ctx := WithPinnedAuthID(context.Background(), "auth1")
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil { if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels") t.Fatalf("expected non-nil channels")
} }
@@ -431,7 +441,7 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) { ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
selectedAuthID = authID selectedAuthID = authID
}) })
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil { if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels") t.Fatalf("expected non-nil channels")
} }

View File

@@ -1,6 +1,9 @@
package handlers package handlers
import "net/http" import (
"net/http"
"strings"
)
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT // hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
// be forwarded by proxies, plus security-sensitive headers that should not leak. // be forwarded by proxies, plus security-sensitive headers that should not leak.
@@ -27,9 +30,14 @@ func FilterUpstreamHeaders(src http.Header) http.Header {
if src == nil { if src == nil {
return nil return nil
} }
connectionScoped := connectionScopedHeaders(src)
dst := make(http.Header) dst := make(http.Header)
for key, values := range src { for key, values := range src {
if _, blocked := hopByHopHeaders[http.CanonicalHeaderKey(key)]; blocked { canonicalKey := http.CanonicalHeaderKey(key)
if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
continue
}
if _, scoped := connectionScoped[canonicalKey]; scoped {
continue continue
} }
dst[key] = values dst[key] = values
@@ -40,6 +48,20 @@ func FilterUpstreamHeaders(src http.Header) http.Header {
return dst return dst
} }
func connectionScopedHeaders(src http.Header) map[string]struct{} {
scoped := make(map[string]struct{})
for _, rawValue := range src.Values("Connection") {
for _, token := range strings.Split(rawValue, ",") {
headerName := strings.TrimSpace(token)
if headerName == "" {
continue
}
scoped[http.CanonicalHeaderKey(headerName)] = struct{}{}
}
}
return scoped
}
// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer. // WriteUpstreamHeaders writes filtered upstream headers to the gin response writer.
// Headers already set by CPA (e.g., Content-Type) are NOT overwritten. // Headers already set by CPA (e.g., Content-Type) are NOT overwritten.
func WriteUpstreamHeaders(dst http.Header, src http.Header) { func WriteUpstreamHeaders(dst http.Header, src http.Header) {

View File

@@ -0,0 +1,55 @@
package handlers
import (
"net/http"
"testing"
)
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
src := http.Header{}
src.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
src.Add("Connection", "x-hop-c")
src.Set("Keep-Alive", "timeout=5")
src.Set("X-Hop-A", "a")
src.Set("X-Hop-B", "b")
src.Set("X-Hop-C", "c")
src.Set("X-Request-Id", "req-1")
src.Set("Set-Cookie", "session=secret")
filtered := FilterUpstreamHeaders(src)
if filtered == nil {
t.Fatalf("expected filtered headers, got nil")
}
requestID := filtered.Get("X-Request-Id")
if requestID != "req-1" {
t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID)
}
blockedHeaderKeys := []string{
"Connection",
"Keep-Alive",
"X-Hop-A",
"X-Hop-B",
"X-Hop-C",
"Set-Cookie",
}
for _, key := range blockedHeaderKeys {
value := filtered.Get(key)
if value != "" {
t.Fatalf("expected %s to be removed, got %q", key, value)
}
}
}
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
src := http.Header{}
src.Add("Connection", "x-hop-a")
src.Set("X-Hop-A", "a")
src.Set("Set-Cookie", "session=secret")
filtered := FilterUpstreamHeaders(src)
if filtered != nil {
t.Fatalf("expected nil when all headers are filtered, got %#v", filtered)
}
}

View File

@@ -153,7 +153,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
pinnedAuthID = strings.TrimSpace(authID) pinnedAuthID = strings.TrimSpace(authID)
}) })
} }
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if errForward != nil { if errForward != nil {

View File

@@ -24,10 +24,10 @@ func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.
return cliproxyexecutor.Response{}, nil return cliproxyexecutor.Response{}, nil
} }
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
ch := make(chan cliproxyexecutor.StreamChunk) ch := make(chan cliproxyexecutor.StreamChunk)
close(ch) close(ch)
return ch, nil return &cliproxyexecutor.StreamResult{Chunks: ch}, nil
} }
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
@@ -89,7 +89,11 @@ func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
if !okResolved { if !okResolved {
t.Fatal("expected registered executor to be found") t.Fatal("expected registered executor to be found")
} }
if resolved != current { resolvedExecutor, okResolvedExecutor := resolved.(*replaceAwareExecutor)
if !okResolvedExecutor {
t.Fatalf("expected resolved executor type %T, got %T", current, resolved)
}
if resolvedExecutor != current {
t.Fatal("expected resolved executor to match registered executor") t.Fatal("expected resolved executor to match registered executor")
} }