mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 13:36:08 +08:00
feat: add passthrough headers configuration
- Introduced `passthrough-headers` option in configuration to control forwarding of upstream response headers. - Updated handlers to respect the passthrough headers setting. - Added tests to verify behavior when passthrough is enabled or disabled.
This commit is contained in:
@@ -68,6 +68,10 @@ proxy-url: ""
|
|||||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||||
force-model-prefix: false
|
force-model-prefix: false
|
||||||
|
|
||||||
|
# When true, forward filtered upstream response headers to downstream clients.
|
||||||
|
# Default is false (disabled).
|
||||||
|
passthrough-headers: false
|
||||||
|
|
||||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||||
request-retry: 3
|
request-retry: 3
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,10 @@ type SDKConfig struct {
|
|||||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||||
|
|
||||||
|
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
|
||||||
|
// Default is false (disabled).
|
||||||
|
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
|
||||||
|
|
||||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||||
|
|
||||||
|
|||||||
@@ -179,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
|||||||
return retries
|
return retries
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PassthroughHeadersEnabled returns whether upstream response headers should be forwarded to clients.
|
||||||
|
// Default is false.
|
||||||
|
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
|
||||||
|
return cfg != nil && cfg.PassthroughHeaders
|
||||||
|
}
|
||||||
|
|
||||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||||
// It is forwarded as execution metadata; when absent we generate a UUID.
|
// It is forwarded as execution metadata; when absent we generate a UUID.
|
||||||
@@ -499,6 +505,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
}
|
}
|
||||||
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||||
}
|
}
|
||||||
|
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||||
|
return resp.Payload, nil, nil
|
||||||
|
}
|
||||||
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -542,6 +551,9 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
}
|
}
|
||||||
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||||
}
|
}
|
||||||
|
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||||
|
return resp.Payload, nil, nil
|
||||||
|
}
|
||||||
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -592,11 +604,15 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
close(errChan)
|
close(errChan)
|
||||||
return nil, nil, errChan
|
return nil, nil, errChan
|
||||||
}
|
}
|
||||||
|
passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
|
||||||
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
|
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
|
||||||
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
|
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
|
||||||
upstreamHeaders := cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
|
var upstreamHeaders http.Header
|
||||||
if upstreamHeaders == nil {
|
if passthroughHeadersEnabled {
|
||||||
upstreamHeaders = make(http.Header)
|
upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
|
||||||
|
if upstreamHeaders == nil {
|
||||||
|
upstreamHeaders = make(http.Header)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
chunks := streamResult.Chunks
|
chunks := streamResult.Chunks
|
||||||
dataChan := make(chan []byte)
|
dataChan := make(chan []byte)
|
||||||
@@ -674,7 +690,9 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
bootstrapRetries++
|
bootstrapRetries++
|
||||||
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
|
if passthroughHeadersEnabled {
|
||||||
|
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
|
||||||
|
}
|
||||||
chunks = retryResult.Chunks
|
chunks = retryResult.Chunks
|
||||||
continue outer
|
continue outer
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -237,6 +237,7 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
PassthroughHeaders: true,
|
||||||
Streaming: sdkconfig.StreamingConfig{
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
BootstrapRetries: 1,
|
BootstrapRetries: 1,
|
||||||
},
|
},
|
||||||
@@ -269,6 +270,66 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
|
||||||
|
executor := &failOnceStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth2 := &coreauth.Auth{
|
||||||
|
ID: "auth2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test2@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
|
BootstrapRetries: 1,
|
||||||
|
},
|
||||||
|
}, manager)
|
||||||
|
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("unexpected error: %+v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(got) != "ok" {
|
||||||
|
t.Fatalf("expected payload ok, got %q", string(got))
|
||||||
|
}
|
||||||
|
if upstreamHeaders != nil {
|
||||||
|
t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||||
executor := &payloadThenErrorStreamExecutor{}
|
executor := &payloadThenErrorStreamExecutor{}
|
||||||
manager := coreauth.NewManager(nil, nil, nil)
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
|||||||
Reference in New Issue
Block a user