mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 08:23:03 +08:00
Merge branch 'codex/pr-1626' into dev
This commit is contained in:
@@ -68,6 +68,10 @@ proxy-url: ""
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# When true, forward filtered upstream response headers to downstream clients.
|
||||
# Default is false (disabled).
|
||||
passthrough-headers: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
|
||||
@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||
}()
|
||||
return ch, nil
|
||||
return &clipexec.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
|
||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ type SDKConfig struct {
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
|
||||
// Default is false (disabled).
|
||||
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
|
||||
|
||||
@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(first wsrelay.StreamEvent) {
|
||||
defer close(out)
|
||||
var param any
|
||||
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
}
|
||||
}(firstEvent)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||
|
||||
@@ -232,7 +232,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -436,7 +436,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
@@ -645,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -775,7 +775,6 @@ attemptLoop:
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -820,7 +819,7 @@ attemptLoop:
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
@@ -968,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
|
||||
@@ -223,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
data,
|
||||
¶m,
|
||||
)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -330,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -399,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -488,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
|
||||
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||
}
|
||||
@@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
@@ -363,7 +363,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
}
|
||||
|
||||
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
@@ -436,7 +436,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
})
|
||||
|
||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
var upstreamHeaders http.Header
|
||||
if respHS != nil {
|
||||
upstreamHeaders = respHS.Header.Clone()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
||||
}
|
||||
if errDial != nil {
|
||||
@@ -516,7 +518,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
terminateReason := "completed"
|
||||
var terminateErr error
|
||||
@@ -627,7 +628,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
|
||||
@@ -1343,7 +1344,7 @@ func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
return e.httpExec.Execute(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if e == nil || e.httpExec == nil || e.wsExec == nil {
|
||||
return nil, fmt.Errorf("codex auto executor: executor is nil")
|
||||
}
|
||||
|
||||
@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
if len(lastBody) > 0 {
|
||||
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
|
||||
@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini API.
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the Gemini API.
|
||||
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||
|
||||
@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
|
||||
@@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request.
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for Kimi requests.
|
||||
|
||||
@@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
// Ensure we record the request if no usage chunk was ever seen
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
@@ -150,11 +150,11 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -236,7 +236,6 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -268,7 +267,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
@@ -112,12 +112,13 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
|
||||
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -165,7 +166,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
@@ -194,6 +195,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
||||
}
|
||||
}
|
||||
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -225,7 +227,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
// This allows proper cleanup and cancellation of ongoing requests
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@@ -257,6 +259,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
if !ok {
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
@@ -264,6 +267,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
|
||||
// Success! Set headers now.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write the first chunk
|
||||
if len(chunk) > 0 {
|
||||
|
||||
@@ -159,7 +159,8 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
}
|
||||
@@ -172,12 +173,13 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
||||
}
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -223,6 +223,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
||||
if alt == "" {
|
||||
setSSEHeaders()
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
@@ -232,6 +233,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
||||
if alt == "" {
|
||||
setSSEHeaders()
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write first chunk
|
||||
if alt == "" {
|
||||
@@ -262,12 +264,13 @@ func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, r
|
||||
c.Header("Content-Type", "application/json")
|
||||
alt := h.GetAlt(c)
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -286,13 +289,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
||||
alt := h.GetAlt(c)
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
@@ -179,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||
return retries
|
||||
}
|
||||
|
||||
// PassthroughHeadersEnabled returns whether upstream response headers should be forwarded to clients.
|
||||
// Default is false.
|
||||
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
|
||||
return cfg != nil && cfg.PassthroughHeaders
|
||||
}
|
||||
|
||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||
// It is forwarded as execution metadata; when absent we generate a UUID.
|
||||
@@ -461,10 +467,10 @@ func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
return nil, nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
@@ -497,17 +503,20 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
}
|
||||
return resp.Payload, nil
|
||||
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||
return resp.Payload, nil, nil
|
||||
}
|
||||
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||
}
|
||||
|
||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
return nil, nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
@@ -540,20 +549,24 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
}
|
||||
return resp.Payload, nil
|
||||
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||
return resp.Payload, nil, nil
|
||||
}
|
||||
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||
}
|
||||
|
||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
||||
// The returned http.Header carries upstream response headers captured before streaming begins.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
errChan <- errMsg
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
return nil, nil, errChan
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
@@ -572,7 +585,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = reqMeta
|
||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
status := http.StatusInternalServerError
|
||||
@@ -589,8 +602,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
}
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
return nil, nil, errChan
|
||||
}
|
||||
passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
|
||||
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
|
||||
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
|
||||
var upstreamHeaders http.Header
|
||||
if passthroughHeadersEnabled {
|
||||
upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
|
||||
if upstreamHeaders == nil {
|
||||
upstreamHeaders = make(http.Header)
|
||||
}
|
||||
}
|
||||
chunks := streamResult.Chunks
|
||||
dataChan := make(chan []byte)
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
go func() {
|
||||
@@ -664,9 +688,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
if !sentPayload {
|
||||
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
||||
bootstrapRetries++
|
||||
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if retryErr == nil {
|
||||
chunks = retryChunks
|
||||
if passthroughHeadersEnabled {
|
||||
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
|
||||
}
|
||||
chunks = retryResult.Chunks
|
||||
continue outer
|
||||
}
|
||||
streamErr = retryErr
|
||||
@@ -697,7 +724,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
}
|
||||
}
|
||||
}()
|
||||
return dataChan, errChan
|
||||
return dataChan, upstreamHeaders, errChan
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
@@ -757,6 +784,26 @@ func cloneBytes(src []byte) []byte {
|
||||
return dst
|
||||
}
|
||||
|
||||
func cloneHeader(src http.Header) http.Header {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
dst := make(http.Header, len(src))
|
||||
for key, values := range src {
|
||||
dst[key] = append([]string(nil), values...)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func replaceHeader(dst http.Header, src http.Header) {
|
||||
for key := range dst {
|
||||
delete(dst, key)
|
||||
}
|
||||
for key, values := range src {
|
||||
dst[key] = append([]string(nil), values...)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
||||
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||
status := http.StatusInternalServerError
|
||||
|
||||
@@ -23,7 +23,7 @@ func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreex
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
call := e.calls
|
||||
@@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth,
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{
|
||||
Headers: http.Header{"X-Upstream-Attempt": {"1"}},
|
||||
Chunks: ch,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{
|
||||
Headers: http.Header{"X-Upstream-Attempt": {"2"}},
|
||||
Chunks: ch,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
@@ -81,7 +87,7 @@ func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
e.mu.Unlock()
|
||||
@@ -97,7 +103,7 @@ func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreaut
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
@@ -134,7 +140,7 @@ func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coree
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
_ = opts
|
||||
@@ -160,12 +166,12 @@ func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *corea
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
@@ -231,11 +237,12 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
PassthroughHeaders: true,
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
@@ -257,6 +264,70 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
if executor.Calls() != 2 {
|
||||
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
||||
}
|
||||
upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
|
||||
if upstreamAttemptHeader != "2" {
|
||||
t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
|
||||
executor := &failOnceStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
t.Fatalf("unexpected error: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got) != "ok" {
|
||||
t.Fatalf("expected payload ok, got %q", string(got))
|
||||
}
|
||||
if upstreamHeaders != nil {
|
||||
t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||
@@ -296,7 +367,7 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
@@ -367,7 +438,7 @@ func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T)
|
||||
},
|
||||
}, manager)
|
||||
ctx := WithPinnedAuthID(context.Background(), "auth1")
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
@@ -431,7 +502,7 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
||||
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
|
||||
selectedAuthID = authID
|
||||
})
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
80
sdk/api/handlers/header_filter.go
Normal file
80
sdk/api/handlers/header_filter.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
|
||||
// be forwarded by proxies, plus security-sensitive headers that should not leak.
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
// RFC 7230 hop-by-hop
|
||||
"Connection": {},
|
||||
"Keep-Alive": {},
|
||||
"Proxy-Authenticate": {},
|
||||
"Proxy-Authorization": {},
|
||||
"Te": {},
|
||||
"Trailer": {},
|
||||
"Transfer-Encoding": {},
|
||||
"Upgrade": {},
|
||||
// Security-sensitive
|
||||
"Set-Cookie": {},
|
||||
// CPA-managed (set by handlers, not upstream)
|
||||
"Content-Length": {},
|
||||
"Content-Encoding": {},
|
||||
}
|
||||
|
||||
// FilterUpstreamHeaders returns a copy of src with hop-by-hop and security-sensitive
|
||||
// headers removed. Returns nil if src is nil or empty after filtering.
|
||||
func FilterUpstreamHeaders(src http.Header) http.Header {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
connectionScoped := connectionScopedHeaders(src)
|
||||
dst := make(http.Header)
|
||||
for key, values := range src {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
|
||||
continue
|
||||
}
|
||||
if _, scoped := connectionScoped[canonicalKey]; scoped {
|
||||
continue
|
||||
}
|
||||
dst[key] = values
|
||||
}
|
||||
if len(dst) == 0 {
|
||||
return nil
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func connectionScopedHeaders(src http.Header) map[string]struct{} {
|
||||
scoped := make(map[string]struct{})
|
||||
for _, rawValue := range src.Values("Connection") {
|
||||
for _, token := range strings.Split(rawValue, ",") {
|
||||
headerName := strings.TrimSpace(token)
|
||||
if headerName == "" {
|
||||
continue
|
||||
}
|
||||
scoped[http.CanonicalHeaderKey(headerName)] = struct{}{}
|
||||
}
|
||||
}
|
||||
return scoped
|
||||
}
|
||||
|
||||
// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer.
|
||||
// Headers already set by CPA (e.g., Content-Type) are NOT overwritten.
|
||||
func WriteUpstreamHeaders(dst http.Header, src http.Header) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
for key, values := range src {
|
||||
// Don't overwrite headers already set by CPA handlers
|
||||
if dst.Get(key) != "" {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
dst.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
55
sdk/api/handlers/header_filter_test.go
Normal file
55
sdk/api/handlers/header_filter_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
|
||||
src.Add("Connection", "x-hop-c")
|
||||
src.Set("Keep-Alive", "timeout=5")
|
||||
src.Set("X-Hop-A", "a")
|
||||
src.Set("X-Hop-B", "b")
|
||||
src.Set("X-Hop-C", "c")
|
||||
src.Set("X-Request-Id", "req-1")
|
||||
src.Set("Set-Cookie", "session=secret")
|
||||
|
||||
filtered := FilterUpstreamHeaders(src)
|
||||
if filtered == nil {
|
||||
t.Fatalf("expected filtered headers, got nil")
|
||||
}
|
||||
|
||||
requestID := filtered.Get("X-Request-Id")
|
||||
if requestID != "req-1" {
|
||||
t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID)
|
||||
}
|
||||
|
||||
blockedHeaderKeys := []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"X-Hop-A",
|
||||
"X-Hop-B",
|
||||
"X-Hop-C",
|
||||
"Set-Cookie",
|
||||
}
|
||||
for _, key := range blockedHeaderKeys {
|
||||
value := filtered.Get(key)
|
||||
if value != "" {
|
||||
t.Fatalf("expected %s to be removed, got %q", key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "x-hop-a")
|
||||
src.Set("X-Hop-A", "a")
|
||||
src.Set("Set-Cookie", "session=secret")
|
||||
|
||||
filtered := FilterUpstreamHeaders(src)
|
||||
if filtered != nil {
|
||||
t.Fatalf("expected nil when all headers are filtered, got %#v", filtered)
|
||||
}
|
||||
}
|
||||
@@ -431,12 +431,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -463,7 +464,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -496,6 +497,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
if !ok {
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
@@ -504,6 +506,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
|
||||
// Success! Commit to streaming headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||
flusher.Flush()
|
||||
@@ -531,13 +534,14 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
completionsResp := convertChatCompletionsResponseToCompletions(resp)
|
||||
_, _ = c.Writer.Write(completionsResp)
|
||||
cliCancel()
|
||||
@@ -568,7 +572,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
||||
|
||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -599,6 +603,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
@@ -607,6 +612,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
||||
|
||||
// Success! Set headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write the first chunk
|
||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||
|
||||
@@ -31,7 +31,7 @@ func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Aut
|
||||
return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
|
||||
}
|
||||
|
||||
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -124,13 +124,14 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) {
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -149,13 +150,14 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -183,7 +185,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
// New core execution path
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -216,6 +218,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
if !ok {
|
||||
// Stream closed without data? Send headers and done.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
@@ -224,6 +227,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
|
||||
// Success! Set headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write first chunk logic (matching forwardResponsesStream)
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
|
||||
@@ -153,7 +153,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
pinnedAuthID = strings.TrimSpace(authID)
|
||||
})
|
||||
}
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||
if errForward != nil {
|
||||
|
||||
@@ -30,8 +30,9 @@ type ProviderExecutor interface {
|
||||
Identifier() string
|
||||
// Execute handles non-streaming execution and returns the provider response payload.
|
||||
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
|
||||
// ExecuteStream handles streaming execution and returns a channel of provider chunks.
|
||||
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error)
|
||||
// ExecuteStream handles streaming execution and returns a StreamResult containing
|
||||
// upstream headers and a channel of provider chunks.
|
||||
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error)
|
||||
// Refresh attempts to refresh provider credentials and returns the updated auth state.
|
||||
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
|
||||
// CountTokens returns the token count for the given request.
|
||||
@@ -558,7 +559,7 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
|
||||
// ExecuteStream performs a streaming execution using the configured selector and executor.
|
||||
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
||||
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
normalized := m.normalizeProviders(providers)
|
||||
if len(normalized) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
@@ -568,9 +569,9 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; ; attempt++ {
|
||||
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||
result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||
if errStream == nil {
|
||||
return chunks, nil
|
||||
return result, nil
|
||||
}
|
||||
lastErr = errStream
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
|
||||
@@ -699,7 +700,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
@@ -730,7 +731,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
@@ -778,8 +779,11 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: streamResult.Headers,
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,10 +24,10 @@ func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
ch := make(chan cliproxyexecutor.StreamChunk)
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &cliproxyexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
@@ -89,7 +89,11 @@ func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
|
||||
if !okResolved {
|
||||
t.Fatal("expected registered executor to be found")
|
||||
}
|
||||
if resolved != current {
|
||||
resolvedExecutor, okResolvedExecutor := resolved.(*replaceAwareExecutor)
|
||||
if !okResolvedExecutor {
|
||||
t.Fatalf("expected resolved executor type %T, got %T", current, resolved)
|
||||
}
|
||||
if resolvedExecutor != current {
|
||||
t.Fatal("expected resolved executor to match registered executor")
|
||||
}
|
||||
|
||||
|
||||
@@ -57,6 +57,8 @@ type Response struct {
|
||||
Payload []byte
|
||||
// Metadata exposes optional structured data for translators.
|
||||
Metadata map[string]any
|
||||
// Headers carries upstream HTTP response headers for passthrough to clients.
|
||||
Headers http.Header
|
||||
}
|
||||
|
||||
// StreamChunk represents a single streaming payload unit emitted by provider executors.
|
||||
@@ -67,6 +69,15 @@ type StreamChunk struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// StreamResult wraps the streaming response, providing both the chunk channel
|
||||
// and the upstream HTTP response headers captured before streaming begins.
|
||||
type StreamResult struct {
|
||||
// Headers carries upstream HTTP response headers from the initial connection.
|
||||
Headers http.Header
|
||||
// Chunks is the channel of streaming payload units.
|
||||
Chunks <-chan StreamChunk
|
||||
}
|
||||
|
||||
// StatusError represents an error that carries an HTTP-like status code.
|
||||
// Provider executors should implement this when possible to enable
|
||||
// better auth state updates on failures (e.g., 401/402/429).
|
||||
|
||||
Reference in New Issue
Block a user