feat: add passthrough headers configuration

- Introduced `passthrough-headers` option in configuration to control forwarding of upstream response headers.
- Updated handlers to respect the passthrough headers setting.
- Added tests to verify behavior when passthrough is enabled or disabled.
This commit is contained in:
Luis Pater
2026-02-19 21:31:29 +08:00
parent 2789396435
commit a6bdd9a652
4 changed files with 91 additions and 4 deletions

View File

@@ -68,6 +68,10 @@ proxy-url: ""
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
# When true, forward filtered upstream response headers to downstream clients.
# Default is false (disabled).
passthrough-headers: false
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
request-retry: 3

View File

@@ -20,6 +20,10 @@ type SDKConfig struct {
// APIKeys is a list of keys for authenticating clients to this proxy server.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
// Default is false (disabled).
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`

View File

@@ -179,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
return retries
}
// PassthroughHeadersEnabled returns whether upstream response headers should be forwarded to clients.
// Default is false.
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
return cfg != nil && cfg.PassthroughHeaders
}
func requestExecutionMetadata(ctx context.Context) map[string]any {
// Idempotency-Key is an optional client-supplied header used to correlate retries.
// It is forwarded as execution metadata; when absent we generate a UUID.
@@ -499,6 +505,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
}
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
}
if !PassthroughHeadersEnabled(h.Cfg) {
return resp.Payload, nil, nil
}
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
}
@@ -542,6 +551,9 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
}
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
}
if !PassthroughHeadersEnabled(h.Cfg) {
return resp.Payload, nil, nil
}
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
}
@@ -592,11 +604,15 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
close(errChan)
return nil, nil, errChan
}
passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
upstreamHeaders := cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
if upstreamHeaders == nil {
upstreamHeaders = make(http.Header)
var upstreamHeaders http.Header
if passthroughHeadersEnabled {
upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
if upstreamHeaders == nil {
upstreamHeaders = make(http.Header)
}
}
chunks := streamResult.Chunks
dataChan := make(chan []byte)
@@ -674,7 +690,9 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
bootstrapRetries++
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if retryErr == nil {
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
if passthroughHeadersEnabled {
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
}
chunks = retryResult.Chunks
continue outer
}

View File

@@ -237,6 +237,7 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
PassthroughHeaders: true,
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
@@ -269,6 +270,66 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
}
}
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
}
}
if string(got) != "ok" {
t.Fatalf("expected payload ok, got %q", string(got))
}
if upstreamHeaders != nil {
t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
}
}
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
executor := &payloadThenErrorStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)