mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 09:03:59 +08:00
Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6ec33e8e1 | ||
|
|
081cfe806e | ||
|
|
c1c62a6c04 | ||
|
|
2fdf5d2793 | ||
|
|
7b0eb41ebc | ||
|
|
ef5901c81b | ||
|
|
d4829c82f7 | ||
|
|
a5f4166a9b | ||
|
|
1cc21cc45b | ||
|
|
07cf616e2b | ||
|
|
4445a165e9 | ||
|
|
e92e2af71a | ||
|
|
a6bdd9a652 | ||
|
|
349a6349b3 | ||
|
|
00822770ec | ||
|
|
1a0ceda0fc | ||
|
|
72add453d2 | ||
|
|
2789396435 | ||
|
|
61da7bd981 | ||
|
|
9c040445af | ||
|
|
fff866424e | ||
|
|
2d12becfd6 | ||
|
|
252f7e0751 | ||
|
|
b2b17528cb | ||
|
|
55f938164b | ||
|
|
76294f0c59 | ||
|
|
5fa23c7f41 | ||
|
|
73dc0b10b8 | ||
|
|
2ea95266e3 | ||
|
|
1f8f198c45 | ||
|
|
9261b0c20b | ||
|
|
7cc725496e | ||
|
|
709d999f9f | ||
|
|
24c18614f0 | ||
|
|
603f06a762 | ||
|
|
98f0a3e3bd |
@@ -161,6 +161,12 @@ Those projects are ports of CLIProxyAPI or inspired by it:
|
||||
|
||||
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
|
||||
|
||||
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
|
||||
|
||||
Never stop coding. Smart routing to FREE & low-cost AI models with automatic fallback.
|
||||
|
||||
OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoint with smart routing, load balancing, retries, and fallbacks. Add policies, rate limits, caching, and observability for reliable, cost-aware inference.
|
||||
|
||||
> [!NOTE]
|
||||
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
|
||||
|
||||
|
||||
@@ -160,6 +160,12 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
|
||||
|
||||
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
|
||||
|
||||
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
|
||||
|
||||
代码不止,创新不停。智能路由至免费及低成本 AI 模型,并支持自动故障转移。
|
||||
|
||||
OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼容 OpenAI 的端点,具备智能路由、负载均衡、重试及回退机制。通过添加策略、速率限制、缓存和可观测性,确保推理过程既可靠又具备成本意识。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
|
||||
|
||||
|
||||
@@ -68,6 +68,10 @@ proxy-url: ""
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# When true, forward filtered upstream response headers to downstream clients.
|
||||
# Default is false (disabled).
|
||||
passthrough-headers: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -155,6 +159,15 @@ nonstream-keepalive-interval: 0
|
||||
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# These are used as fallbacks when the client does not send its own headers.
|
||||
# claude-header-defaults:
|
||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||
# package-version: "0.74.0"
|
||||
# runtime-version: "v24.3.0"
|
||||
# timeout: "600"
|
||||
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
|
||||
@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||
}()
|
||||
return ch, nil
|
||||
return &clipexec.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
|
||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
// OAuth configuration constants for Claude/Anthropic
|
||||
const (
|
||||
AuthURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
TokenURL = "https://api.anthropic.com/v1/oauth/token"
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
RedirectURI = "http://localhost:54545/callback"
|
||||
)
|
||||
|
||||
@@ -90,6 +90,10 @@ type Config struct {
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values for Claude API requests.
|
||||
// These are used as fallbacks when the client does not send its own headers.
|
||||
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
|
||||
|
||||
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
|
||||
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
|
||||
|
||||
@@ -117,6 +121,15 @@ type Config struct {
|
||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
||||
type ClaudeHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
@@ -288,6 +301,10 @@ type CloakConfig struct {
|
||||
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||
// This can help bypass certain content filters.
|
||||
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||
|
||||
// CacheUserID controls whether Claude user_id values are cached per API key.
|
||||
// When false, a fresh random user_id is generated for every request.
|
||||
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeKey represents the configuration for a Claude API key,
|
||||
|
||||
@@ -20,6 +20,10 @@ type SDKConfig struct {
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
|
||||
// Default is false (disabled).
|
||||
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
|
||||
|
||||
@@ -184,6 +184,21 @@ func GetGeminiModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -294,6 +309,21 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 1, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
Object: "model",
|
||||
@@ -517,6 +547,21 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
|
||||
@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(first wsrelay.StreamEvent) {
|
||||
defer close(out)
|
||||
var param any
|
||||
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
}
|
||||
}(firstEvent)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||
|
||||
@@ -232,7 +232,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
@@ -436,7 +436,7 @@ attemptLoop:
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
@@ -645,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -775,7 +775,6 @@ attemptLoop:
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -820,7 +819,7 @@ attemptLoop:
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
@@ -968,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -116,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
|
||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
} else {
|
||||
reporter.publish(ctx, parseClaudeUsage(data))
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||
}
|
||||
var param any
|
||||
@@ -222,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
data,
|
||||
¶m,
|
||||
)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -257,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -329,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -348,7 +348,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
// Forward the line as-is to preserve SSE format
|
||||
@@ -375,7 +375,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(
|
||||
@@ -398,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
@@ -423,7 +423,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -487,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
@@ -638,7 +638,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if cfgVal != "" {
|
||||
return cfgVal
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
var hd config.ClaudeHeaderDefaults
|
||||
if cfg != nil {
|
||||
hd = cfg.ClaudeHeaderDefaults
|
||||
}
|
||||
|
||||
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
|
||||
if isAnthropicBase && useAPIKey {
|
||||
@@ -685,16 +727,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
if stream {
|
||||
@@ -702,6 +745,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
} else {
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
// Keep OS/Arch mapping dynamic (not configurable).
|
||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
@@ -753,11 +798,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// Collect built-in tool names (those with a non-empty "type" field) so we can
|
||||
// skip them consistently in both tools and message history.
|
||||
builtinTools := map[string]bool{}
|
||||
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
|
||||
builtinTools[name] = true
|
||||
}
|
||||
|
||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||
// Skip built-in tools (web_search, code_execution, etc.) which have
|
||||
// a "type" field and require their name to remain unchanged.
|
||||
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
||||
if n := tool.Get("name").String(); n != "" {
|
||||
builtinTools[n] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
name := tool.Get("name").String()
|
||||
@@ -772,7 +827,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
|
||||
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
|
||||
name := gjson.GetBytes(body, "tool_choice.name").String()
|
||||
if name != "" && !strings.HasPrefix(name, prefix) {
|
||||
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
|
||||
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
|
||||
}
|
||||
}
|
||||
@@ -784,15 +839,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return true
|
||||
}
|
||||
content.ForEach(func(contentIndex, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+toolName)
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
nestedContent := part.Get("content")
|
||||
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||
if nestedPart.Get("type").String() == "tool_reference" {
|
||||
nestedToolName := nestedPart.Get("tool_name").String()
|
||||
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
|
||||
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
@@ -811,15 +889,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
content.ForEach(func(index, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if !strings.HasPrefix(toolName, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.tool_name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
nestedContent := part.Get("content")
|
||||
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||
if nestedPart.Get("type").String() == "tool_reference" {
|
||||
nestedToolName := nestedPart.Get("tool_name").String()
|
||||
if strings.HasPrefix(nestedToolName, prefix) {
|
||||
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
return true
|
||||
})
|
||||
return body
|
||||
@@ -834,15 +935,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
||||
return line
|
||||
}
|
||||
contentBlock := gjson.GetBytes(payload, "content_block")
|
||||
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
|
||||
if !contentBlock.Exists() {
|
||||
return line
|
||||
}
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
|
||||
blockType := contentBlock.Get("type").String()
|
||||
var updated []byte
|
||||
var err error
|
||||
|
||||
switch blockType {
|
||||
case "tool_use":
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
case "tool_reference":
|
||||
toolName := contentBlock.Get("tool_name").String()
|
||||
if !strings.HasPrefix(toolName, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
default:
|
||||
return line
|
||||
}
|
||||
|
||||
@@ -862,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string {
|
||||
}
|
||||
|
||||
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
||||
// Returns (cloakMode, strictMode, sensitiveWords).
|
||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
||||
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
|
||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
|
||||
if auth == nil || auth.Attributes == nil {
|
||||
return "auto", false, nil
|
||||
return "auto", false, nil, false
|
||||
}
|
||||
|
||||
cloakMode := auth.Attributes["cloak_mode"]
|
||||
@@ -883,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
||||
}
|
||||
}
|
||||
|
||||
return cloakMode, strictMode, sensitiveWords
|
||||
cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true")
|
||||
|
||||
return cloakMode, strictMode, sensitiveWords, cacheUserID
|
||||
}
|
||||
|
||||
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||
@@ -916,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c
|
||||
}
|
||||
|
||||
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
||||
func injectFakeUserID(payload []byte) []byte {
|
||||
// When useCache is false, a new user ID is generated for every call.
|
||||
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||
generateID := func() string {
|
||||
if useCache {
|
||||
return cachedUserID(apiKey)
|
||||
}
|
||||
return generateFakeUserID()
|
||||
}
|
||||
|
||||
metadata := gjson.GetBytes(payload, "metadata")
|
||||
if !metadata.Exists() {
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||
return payload
|
||||
}
|
||||
|
||||
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||
if existingUserID == "" || !isValidUserID(existingUserID) {
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@@ -962,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
|
||||
// applyCloaking applies cloaking transformations to the payload based on config and client.
|
||||
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
||||
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
|
||||
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
|
||||
clientUserAgent := getClientUserAgent(ctx)
|
||||
|
||||
// Get cloak config from ClaudeKey configuration
|
||||
@@ -972,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
var cloakMode string
|
||||
var strictMode bool
|
||||
var sensitiveWords []string
|
||||
var cacheUserID bool
|
||||
|
||||
if cloakCfg != nil {
|
||||
cloakMode = cloakCfg.Mode
|
||||
strictMode = cloakCfg.StrictMode
|
||||
sensitiveWords = cloakCfg.SensitiveWords
|
||||
if cloakCfg.CacheUserID != nil {
|
||||
cacheUserID = *cloakCfg.CacheUserID
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to auth attributes if no config found
|
||||
if cloakMode == "" {
|
||||
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
|
||||
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
|
||||
cloakMode = attrMode
|
||||
if !strictMode {
|
||||
strictMode = attrStrict
|
||||
@@ -989,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
if len(sensitiveWords) == 0 {
|
||||
sensitiveWords = attrWords
|
||||
}
|
||||
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||
cacheUserID = attrCache
|
||||
}
|
||||
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||
_, _, _, attrCache := getCloakConfigFromAuth(auth)
|
||||
cacheUserID = attrCache
|
||||
}
|
||||
|
||||
// Determine if cloaking should be applied
|
||||
@@ -1002,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
// Inject fake user ID
|
||||
payload = injectFakeUserID(payload)
|
||||
payload = injectFakeUserID(payload, apiKey, cacheUserID)
|
||||
|
||||
// Apply sensitive word obfuscation
|
||||
if len(sensitiveWords) > 0 {
|
||||
|
||||
@@ -2,9 +2,18 @@ package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
@@ -25,6 +34,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
|
||||
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
|
||||
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
@@ -37,6 +58,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [{"name": "Read"}, {"name": "Write"}],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
|
||||
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search"},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"tool_choice": {"type": "tool", "name": "web_search"}
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
@@ -49,6 +161,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
|
||||
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
|
||||
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
@@ -61,3 +185,166 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
|
||||
payload := bytes.TrimSpace(out)
|
||||
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
|
||||
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "proxy_mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
var userIDs []string
|
||||
var requestModels []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
userID := gjson.GetBytes(body, "metadata.user_id").String()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
userIDs = append(userIDs, userID)
|
||||
requestModels = append(requestModels, model)
|
||||
t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String())
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)
|
||||
|
||||
cacheEnabled := true
|
||||
executor := NewClaudeExecutor(&config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{
|
||||
APIKey: "key-123",
|
||||
BaseURL: server.URL,
|
||||
Cloak: &config.CloakConfig{
|
||||
CacheUserID: &cacheEnabled,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"}
|
||||
for _, model := range models {
|
||||
t.Logf("Sending request for model: %s", model)
|
||||
modelPayload, _ := sjson.SetBytes(payload, "model", model)
|
||||
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: model,
|
||||
Payload: modelPayload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute(%s) error: %v", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(userIDs) != 2 {
|
||||
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||
}
|
||||
if userIDs[0] == "" || userIDs[1] == "" {
|
||||
t.Fatal("expected user_id to be populated")
|
||||
}
|
||||
t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0])
|
||||
t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1])
|
||||
if userIDs[0] != userIDs[1] {
|
||||
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
|
||||
}
|
||||
if !isValidUserID(userIDs[0]) {
|
||||
t.Fatalf("user_id %q is not valid", userIDs[0])
|
||||
}
|
||||
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
var userIDs []string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String())
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
}); err != nil {
|
||||
t.Fatalf("Execute call %d error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(userIDs) != 2 {
|
||||
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||
}
|
||||
if userIDs[0] == "" || userIDs[1] == "" {
|
||||
t.Fatal("expected user_id to be populated")
|
||||
}
|
||||
if userIDs[0] == userIDs[1] {
|
||||
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
|
||||
}
|
||||
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
|
||||
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
|
||||
if got != "mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
|
||||
// tool_result.content can be a string - should not be processed
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
|
||||
if got != "plain string result" {
|
||||
t.Fatalf("string content should remain unchanged = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "web_search" {
|
||||
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||
}
|
||||
@@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
@@ -363,7 +363,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
}
|
||||
|
||||
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
@@ -436,7 +436,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
})
|
||||
|
||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
var upstreamHeaders http.Header
|
||||
if respHS != nil {
|
||||
upstreamHeaders = respHS.Header.Clone()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
||||
}
|
||||
if errDial != nil {
|
||||
@@ -516,7 +518,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
terminateReason := "completed"
|
||||
var terminateErr error
|
||||
@@ -627,7 +628,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
}
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
|
||||
@@ -1343,7 +1344,7 @@ func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
return e.httpExec.Execute(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if e == nil || e.httpExec == nil || e.wsExec == nil {
|
||||
return nil, fmt.Errorf("codex auto executor: executor is nil")
|
||||
}
|
||||
|
||||
@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
if len(lastBody) > 0 {
|
||||
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
|
||||
@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Gemini API.
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request using the Gemini API.
|
||||
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||
|
||||
@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
|
||||
@@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request.
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for Kimi requests.
|
||||
|
||||
@@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
// Ensure we record the request if no usage chunk was ever seen
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
@@ -150,11 +150,11 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||
}
|
||||
@@ -236,7 +236,6 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
@@ -268,7 +267,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
|
||||
89
internal/runtime/executor/user_id_cache.go
Normal file
89
internal/runtime/executor/user_id_cache.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type userIDCacheEntry struct {
|
||||
value string
|
||||
expire time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
userIDCache = make(map[string]userIDCacheEntry)
|
||||
userIDCacheMu sync.RWMutex
|
||||
userIDCacheCleanupOnce sync.Once
|
||||
)
|
||||
|
||||
const (
|
||||
userIDTTL = time.Hour
|
||||
userIDCacheCleanupPeriod = 15 * time.Minute
|
||||
)
|
||||
|
||||
func startUserIDCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(userIDCacheCleanupPeriod)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredUserIDs()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func purgeExpiredUserIDs() {
|
||||
now := time.Now()
|
||||
userIDCacheMu.Lock()
|
||||
for key, entry := range userIDCache {
|
||||
if !entry.expire.After(now) {
|
||||
delete(userIDCache, key)
|
||||
}
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func userIDCacheKey(apiKey string) string {
|
||||
sum := sha256.Sum256([]byte(apiKey))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func cachedUserID(apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return generateFakeUserID()
|
||||
}
|
||||
|
||||
userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)
|
||||
|
||||
key := userIDCacheKey(apiKey)
|
||||
now := time.Now()
|
||||
|
||||
userIDCacheMu.RLock()
|
||||
entry, ok := userIDCache[key]
|
||||
valid := ok && entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value)
|
||||
userIDCacheMu.RUnlock()
|
||||
if valid {
|
||||
userIDCacheMu.Lock()
|
||||
entry = userIDCache[key]
|
||||
if entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) {
|
||||
entry.expire = now.Add(userIDTTL)
|
||||
userIDCache[key] = entry
|
||||
userIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
newID := generateFakeUserID()
|
||||
|
||||
userIDCacheMu.Lock()
|
||||
entry, ok = userIDCache[key]
|
||||
if !ok || entry.value == "" || !entry.expire.After(now) || !isValidUserID(entry.value) {
|
||||
entry.value = newID
|
||||
}
|
||||
entry.expire = now.Add(userIDTTL)
|
||||
userIDCache[key] = entry
|
||||
userIDCacheMu.Unlock()
|
||||
return entry.value
|
||||
}
|
||||
86
internal/runtime/executor/user_id_cache_test.go
Normal file
86
internal/runtime/executor/user_id_cache_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func resetUserIDCache() {
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache = make(map[string]userIDCacheEntry)
|
||||
userIDCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
first := cachedUserID("api-key-1")
|
||||
second := cachedUserID("api-key-1")
|
||||
|
||||
if first == "" {
|
||||
t.Fatal("expected generated user_id to be non-empty")
|
||||
}
|
||||
if first != second {
|
||||
t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
expiredID := cachedUserID("api-key-expired")
|
||||
cacheKey := userIDCacheKey("api-key-expired")
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache[cacheKey] = userIDCacheEntry{
|
||||
value: expiredID,
|
||||
expire: time.Now().Add(-time.Minute),
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
|
||||
newID := cachedUserID("api-key-expired")
|
||||
if newID == expiredID {
|
||||
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
|
||||
}
|
||||
if newID == "" {
|
||||
t.Fatal("expected regenerated user_id to be non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
first := cachedUserID("api-key-1")
|
||||
second := cachedUserID("api-key-2")
|
||||
|
||||
if first == second {
|
||||
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
||||
resetUserIDCache()
|
||||
|
||||
key := "api-key-renew"
|
||||
id := cachedUserID(key)
|
||||
cacheKey := userIDCacheKey(key)
|
||||
|
||||
soon := time.Now()
|
||||
userIDCacheMu.Lock()
|
||||
userIDCache[cacheKey] = userIDCacheEntry{
|
||||
value: id,
|
||||
expire: soon.Add(2 * time.Second),
|
||||
}
|
||||
userIDCacheMu.Unlock()
|
||||
|
||||
if refreshed := cachedUserID(key); refreshed != id {
|
||||
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
|
||||
}
|
||||
|
||||
userIDCacheMu.RLock()
|
||||
entry := userIDCache[cacheKey]
|
||||
userIDCacheMu.RUnlock()
|
||||
|
||||
if entry.expire.Sub(soon) < 30*time.Minute {
|
||||
t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon))
|
||||
}
|
||||
}
|
||||
@@ -10,10 +10,53 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// validReasoningEffortLevels contains the standard values accepted by the
|
||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
||||
// auto) are NOT in this set and must be clamped before use.
|
||||
var validReasoningEffortLevels = map[string]struct{}{
|
||||
"none": {},
|
||||
"low": {},
|
||||
"medium": {},
|
||||
"high": {},
|
||||
}
|
||||
|
||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
||||
// mapped to the nearest standard equivalent.
|
||||
//
|
||||
// Mapping rules:
|
||||
// - none / low / medium / high → returned as-is (already valid)
|
||||
// - xhigh → "high" (nearest lower standard level)
|
||||
// - minimal → "low" (nearest higher standard level)
|
||||
// - auto → "medium" (reasonable default)
|
||||
// - anything else → "medium" (safe default)
|
||||
func clampReasoningEffort(level string) string {
|
||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
||||
return level
|
||||
}
|
||||
var clamped string
|
||||
switch level {
|
||||
case string(thinking.LevelXHigh):
|
||||
clamped = string(thinking.LevelHigh)
|
||||
case string(thinking.LevelMinimal):
|
||||
clamped = string(thinking.LevelLow)
|
||||
case string(thinking.LevelAuto):
|
||||
clamped = string(thinking.LevelMedium)
|
||||
default:
|
||||
clamped = string(thinking.LevelMedium)
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"original": level,
|
||||
"clamped": clamped,
|
||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
||||
return clamped
|
||||
}
|
||||
|
||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||
//
|
||||
// OpenAI-specific behavior:
|
||||
@@ -58,7 +101,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
}
|
||||
|
||||
if config.Mode == thinking.ModeLevel {
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -79,7 +122,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -114,7 +157,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -231,8 +231,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
} else if functionResponseResult.Raw != "" {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
// Content field is missing entirely — .Raw is empty which
|
||||
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
|
||||
@@ -661,6 +661,85 @@ func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
|
||||
// Bug repro: tool_result with no content field produces invalid JSON
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "MyTool-123-456",
|
||||
"name": "MyTool",
|
||||
"input": {"key": "value"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "MyTool-123-456"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
|
||||
// Verify the functionResponse has a valid result value
|
||||
fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
|
||||
if !fr.Exists() {
|
||||
t.Error("functionResponse.response.result should exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
||||
// Bug repro: tool_result with null content produces invalid JSON
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-opus-4-6-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "MyTool-123-456",
|
||||
"name": "MyTool",
|
||||
"input": {"key": "value"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "MyTool-123-456",
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||
outputStr := string(output)
|
||||
|
||||
if !gjson.Valid(outputStr) {
|
||||
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||
// When tools + thinking but no system instruction, should create one with hint
|
||||
inputJSON := []byte(`{
|
||||
|
||||
@@ -22,8 +22,9 @@ var (
|
||||
|
||||
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
||||
type ConvertCodexResponseToClaudeParams struct {
|
||||
HasToolCall bool
|
||||
BlockIndex int
|
||||
HasToolCall bool
|
||||
BlockIndex int
|
||||
HasReceivedArgumentsDelta bool
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
||||
@@ -137,6 +138,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
||||
@@ -171,12 +173,29 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
}
|
||||
} else if typeStr == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
} else if typeStr == "response.function_call_arguments.done" {
|
||||
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
|
||||
// in a single "done" event without preceding "delta" events.
|
||||
// Emit the full arguments as a single input_json_delta so the
|
||||
// downstream Claude client receives the complete tool input.
|
||||
// When delta events were already received, skip to avoid duplicating arguments.
|
||||
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
|
||||
if args := rootResult.Get("arguments").String(); args != "" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", args)
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []string{output}
|
||||
|
||||
@@ -20,10 +20,12 @@ var (
|
||||
|
||||
// ConvertCliToOpenAIParams holds parameters for response conversion.
|
||||
type ConvertCliToOpenAIParams struct {
|
||||
ResponseID string
|
||||
CreatedAt int64
|
||||
Model string
|
||||
FunctionCallIndex int
|
||||
ResponseID string
|
||||
CreatedAt int64
|
||||
Model string
|
||||
FunctionCallIndex int
|
||||
HasReceivedArgumentsDelta bool
|
||||
HasToolCallAnnounced bool
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
|
||||
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
|
||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
if *param == nil {
|
||||
*param = &ConvertCliToOpenAIParams{
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
FunctionCallIndex: -1,
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
FunctionCallIndex: -1,
|
||||
HasReceivedArgumentsDelta: false,
|
||||
HasToolCallAnnounced: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
} else if dataType == "response.output_item.done" {
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
} else if dataType == "response.output_item.added" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if itemResult.Exists() {
|
||||
if itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// set the index
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened
|
||||
name := itemResult.Get("name").String()
|
||||
// Build reverse map on demand from original request tools
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Increment index for this new function call item.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
|
||||
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
|
||||
|
||||
deltaValue := rootResult.Get("delta").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.done" {
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
|
||||
// Arguments were already streamed via delta events; nothing to emit.
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Fallback: no delta events were received, emit the full arguments as a single chunk.
|
||||
fullArgs := rootResult.Get("arguments").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
|
||||
// Tool call was already announced via output_item.added; skip emission.
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Fallback path: model skipped output_item.added, so emit complete tool call now.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
@@ -112,12 +112,13 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
|
||||
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -165,7 +166,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
@@ -194,6 +195,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
||||
}
|
||||
}
|
||||
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -225,7 +227,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
// This allows proper cleanup and cancellation of ongoing requests
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@@ -257,6 +259,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
if !ok {
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
@@ -264,6 +267,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
|
||||
// Success! Set headers now.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write the first chunk
|
||||
if len(chunk) > 0 {
|
||||
|
||||
@@ -159,7 +159,8 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
}
|
||||
@@ -172,12 +173,13 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
||||
}
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -223,6 +223,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
||||
if alt == "" {
|
||||
setSSEHeaders()
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
@@ -232,6 +233,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
||||
if alt == "" {
|
||||
setSSEHeaders()
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write first chunk
|
||||
if alt == "" {
|
||||
@@ -262,12 +264,13 @@ func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, r
|
||||
c.Header("Content-Type", "application/json")
|
||||
alt := h.GetAlt(c)
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -286,13 +289,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
||||
alt := h.GetAlt(c)
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
@@ -179,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||
return retries
|
||||
}
|
||||
|
||||
// PassthroughHeadersEnabled returns whether upstream response headers should be forwarded to clients.
|
||||
// Default is false.
|
||||
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
|
||||
return cfg != nil && cfg.PassthroughHeaders
|
||||
}
|
||||
|
||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||
// It is forwarded as execution metadata; when absent we generate a UUID.
|
||||
@@ -461,10 +467,10 @@ func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
return nil, nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
@@ -497,17 +503,20 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
}
|
||||
return resp.Payload, nil
|
||||
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||
return resp.Payload, nil, nil
|
||||
}
|
||||
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||
}
|
||||
|
||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
return nil, nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
@@ -540,20 +549,24 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
}
|
||||
return resp.Payload, nil
|
||||
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||
return resp.Payload, nil, nil
|
||||
}
|
||||
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||
}
|
||||
|
||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
||||
// The returned http.Header carries upstream response headers captured before streaming begins.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
errChan <- errMsg
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
return nil, nil, errChan
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
@@ -572,7 +585,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = reqMeta
|
||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
status := http.StatusInternalServerError
|
||||
@@ -589,8 +602,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
}
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
return nil, nil, errChan
|
||||
}
|
||||
passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
|
||||
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
|
||||
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
|
||||
var upstreamHeaders http.Header
|
||||
if passthroughHeadersEnabled {
|
||||
upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
|
||||
if upstreamHeaders == nil {
|
||||
upstreamHeaders = make(http.Header)
|
||||
}
|
||||
}
|
||||
chunks := streamResult.Chunks
|
||||
dataChan := make(chan []byte)
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
go func() {
|
||||
@@ -664,9 +688,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
if !sentPayload {
|
||||
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
||||
bootstrapRetries++
|
||||
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if retryErr == nil {
|
||||
chunks = retryChunks
|
||||
if passthroughHeadersEnabled {
|
||||
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
|
||||
}
|
||||
chunks = retryResult.Chunks
|
||||
continue outer
|
||||
}
|
||||
streamErr = retryErr
|
||||
@@ -697,7 +724,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
}
|
||||
}
|
||||
}()
|
||||
return dataChan, errChan
|
||||
return dataChan, upstreamHeaders, errChan
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
@@ -757,13 +784,33 @@ func cloneBytes(src []byte) []byte {
|
||||
return dst
|
||||
}
|
||||
|
||||
func cloneHeader(src http.Header) http.Header {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
dst := make(http.Header, len(src))
|
||||
for key, values := range src {
|
||||
dst[key] = append([]string(nil), values...)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func replaceHeader(dst http.Header, src http.Header) {
|
||||
for key := range dst {
|
||||
delete(dst, key)
|
||||
}
|
||||
for key, values := range src {
|
||||
dst[key] = append([]string(nil), values...)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
||||
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||
status := http.StatusInternalServerError
|
||||
if msg != nil && msg.StatusCode > 0 {
|
||||
status = msg.StatusCode
|
||||
}
|
||||
if msg != nil && msg.Addon != nil {
|
||||
if msg != nil && msg.Addon != nil && PassthroughHeadersEnabled(h.Cfg) {
|
||||
for key, values := range msg.Addon {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
|
||||
68
sdk/api/handlers/handlers_error_response_test.go
Normal file
68
sdk/api/handlers/handlers_error_response_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestWriteErrorResponse_AddonHeadersDisabledByDefault(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
handler := NewBaseAPIHandlers(nil, nil)
|
||||
handler.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Error: errors.New("rate limit"),
|
||||
Addon: http.Header{
|
||||
"Retry-After": {"30"},
|
||||
"X-Request-Id": {"req-1"},
|
||||
},
|
||||
})
|
||||
|
||||
if recorder.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
if got := recorder.Header().Get("Retry-After"); got != "" {
|
||||
t.Fatalf("Retry-After should be empty when passthrough is disabled, got %q", got)
|
||||
}
|
||||
if got := recorder.Header().Get("X-Request-Id"); got != "" {
|
||||
t.Fatalf("X-Request-Id should be empty when passthrough is disabled, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Writer.Header().Set("X-Request-Id", "old-value")
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{PassthroughHeaders: true}, nil)
|
||||
handler.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Error: errors.New("rate limit"),
|
||||
Addon: http.Header{
|
||||
"Retry-After": {"30"},
|
||||
"X-Request-Id": {"new-1", "new-2"},
|
||||
},
|
||||
})
|
||||
|
||||
if recorder.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
if got := recorder.Header().Get("Retry-After"); got != "30" {
|
||||
t.Fatalf("Retry-After = %q, want %q", got, "30")
|
||||
}
|
||||
if got := recorder.Header().Values("X-Request-Id"); !reflect.DeepEqual(got, []string{"new-1", "new-2"}) {
|
||||
t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"})
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreex
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
call := e.calls
|
||||
@@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth,
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{
|
||||
Headers: http.Header{"X-Upstream-Attempt": {"1"}},
|
||||
Chunks: ch,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{
|
||||
Headers: http.Header{"X-Upstream-Attempt": {"2"}},
|
||||
Chunks: ch,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
@@ -81,7 +87,7 @@ func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
e.mu.Unlock()
|
||||
@@ -97,7 +103,7 @@ func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreaut
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
@@ -134,7 +140,7 @@ func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coree
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
_ = opts
|
||||
@@ -160,12 +166,12 @@ func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *corea
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
@@ -231,11 +237,12 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
PassthroughHeaders: true,
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
@@ -257,6 +264,70 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
if executor.Calls() != 2 {
|
||||
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
||||
}
|
||||
upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
|
||||
if upstreamAttemptHeader != "2" {
|
||||
t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
|
||||
executor := &failOnceStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
t.Fatalf("unexpected error: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got) != "ok" {
|
||||
t.Fatalf("expected payload ok, got %q", string(got))
|
||||
}
|
||||
if upstreamHeaders != nil {
|
||||
t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||
@@ -296,7 +367,7 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
@@ -367,7 +438,7 @@ func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T)
|
||||
},
|
||||
}, manager)
|
||||
ctx := WithPinnedAuthID(context.Background(), "auth1")
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
@@ -431,7 +502,7 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
||||
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
|
||||
selectedAuthID = authID
|
||||
})
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
80
sdk/api/handlers/header_filter.go
Normal file
80
sdk/api/handlers/header_filter.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
|
||||
// be forwarded by proxies, plus security-sensitive headers that should not leak.
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
// RFC 7230 hop-by-hop
|
||||
"Connection": {},
|
||||
"Keep-Alive": {},
|
||||
"Proxy-Authenticate": {},
|
||||
"Proxy-Authorization": {},
|
||||
"Te": {},
|
||||
"Trailer": {},
|
||||
"Transfer-Encoding": {},
|
||||
"Upgrade": {},
|
||||
// Security-sensitive
|
||||
"Set-Cookie": {},
|
||||
// CPA-managed (set by handlers, not upstream)
|
||||
"Content-Length": {},
|
||||
"Content-Encoding": {},
|
||||
}
|
||||
|
||||
// FilterUpstreamHeaders returns a copy of src with hop-by-hop and security-sensitive
|
||||
// headers removed. Returns nil if src is nil or empty after filtering.
|
||||
func FilterUpstreamHeaders(src http.Header) http.Header {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
connectionScoped := connectionScopedHeaders(src)
|
||||
dst := make(http.Header)
|
||||
for key, values := range src {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
|
||||
continue
|
||||
}
|
||||
if _, scoped := connectionScoped[canonicalKey]; scoped {
|
||||
continue
|
||||
}
|
||||
dst[key] = values
|
||||
}
|
||||
if len(dst) == 0 {
|
||||
return nil
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func connectionScopedHeaders(src http.Header) map[string]struct{} {
|
||||
scoped := make(map[string]struct{})
|
||||
for _, rawValue := range src.Values("Connection") {
|
||||
for _, token := range strings.Split(rawValue, ",") {
|
||||
headerName := strings.TrimSpace(token)
|
||||
if headerName == "" {
|
||||
continue
|
||||
}
|
||||
scoped[http.CanonicalHeaderKey(headerName)] = struct{}{}
|
||||
}
|
||||
}
|
||||
return scoped
|
||||
}
|
||||
|
||||
// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer.
|
||||
// Headers already set by CPA (e.g., Content-Type) are NOT overwritten.
|
||||
func WriteUpstreamHeaders(dst http.Header, src http.Header) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
for key, values := range src {
|
||||
// Don't overwrite headers already set by CPA handlers
|
||||
if dst.Get(key) != "" {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
dst.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
55
sdk/api/handlers/header_filter_test.go
Normal file
55
sdk/api/handlers/header_filter_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
|
||||
src.Add("Connection", "x-hop-c")
|
||||
src.Set("Keep-Alive", "timeout=5")
|
||||
src.Set("X-Hop-A", "a")
|
||||
src.Set("X-Hop-B", "b")
|
||||
src.Set("X-Hop-C", "c")
|
||||
src.Set("X-Request-Id", "req-1")
|
||||
src.Set("Set-Cookie", "session=secret")
|
||||
|
||||
filtered := FilterUpstreamHeaders(src)
|
||||
if filtered == nil {
|
||||
t.Fatalf("expected filtered headers, got nil")
|
||||
}
|
||||
|
||||
requestID := filtered.Get("X-Request-Id")
|
||||
if requestID != "req-1" {
|
||||
t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID)
|
||||
}
|
||||
|
||||
blockedHeaderKeys := []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"X-Hop-A",
|
||||
"X-Hop-B",
|
||||
"X-Hop-C",
|
||||
"Set-Cookie",
|
||||
}
|
||||
for _, key := range blockedHeaderKeys {
|
||||
value := filtered.Get(key)
|
||||
if value != "" {
|
||||
t.Fatalf("expected %s to be removed, got %q", key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "x-hop-a")
|
||||
src.Set("X-Hop-A", "a")
|
||||
src.Set("Set-Cookie", "session=secret")
|
||||
|
||||
filtered := FilterUpstreamHeaders(src)
|
||||
if filtered != nil {
|
||||
t.Fatalf("expected nil when all headers are filtered, got %#v", filtered)
|
||||
}
|
||||
}
|
||||
@@ -332,6 +332,7 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
|
||||
// Check if this chunk has any meaningful content
|
||||
hasContent := false
|
||||
hasUsage := root.Get("usage").Exists()
|
||||
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
|
||||
chatChoices.ForEach(func(_, choice gjson.Result) bool {
|
||||
// Check if delta has content or finish_reason
|
||||
@@ -350,8 +351,8 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
})
|
||||
}
|
||||
|
||||
// If no meaningful content, return nil to indicate this chunk should be skipped
|
||||
if !hasContent {
|
||||
// If no meaningful content and no usage, return nil to indicate this chunk should be skipped
|
||||
if !hasContent && !hasUsage {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -410,6 +411,11 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
|
||||
}
|
||||
|
||||
// Copy usage if present
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
@@ -425,12 +431,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -457,7 +464,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -490,6 +497,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
if !ok {
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
@@ -498,6 +506,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
|
||||
// Success! Commit to streaming headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||
flusher.Flush()
|
||||
@@ -525,13 +534,14 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
completionsResp := convertChatCompletionsResponseToCompletions(resp)
|
||||
_, _ = c.Writer.Write(completionsResp)
|
||||
cliCancel()
|
||||
@@ -562,7 +572,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
||||
|
||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -593,6 +603,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
@@ -601,6 +612,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
||||
|
||||
// Success! Set headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write the first chunk
|
||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||
|
||||
@@ -31,7 +31,7 @@ func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Aut
|
||||
return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
|
||||
}
|
||||
|
||||
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -124,13 +124,14 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) {
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -149,13 +150,14 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
}
|
||||
@@ -183,7 +185,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
// New core execution path
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -216,6 +218,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
if !ok {
|
||||
// Stream closed without data? Send headers and done.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
@@ -224,6 +227,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
|
||||
// Success! Set headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
|
||||
// Write first chunk logic (matching forwardResponsesStream)
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
|
||||
@@ -153,7 +153,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
pinnedAuthID = strings.TrimSpace(authID)
|
||||
})
|
||||
}
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||
if errForward != nil {
|
||||
|
||||
@@ -30,8 +30,9 @@ type ProviderExecutor interface {
|
||||
Identifier() string
|
||||
// Execute handles non-streaming execution and returns the provider response payload.
|
||||
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
|
||||
// ExecuteStream handles streaming execution and returns a channel of provider chunks.
|
||||
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error)
|
||||
// ExecuteStream handles streaming execution and returns a StreamResult containing
|
||||
// upstream headers and a channel of provider chunks.
|
||||
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error)
|
||||
// Refresh attempts to refresh provider credentials and returns the updated auth state.
|
||||
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
|
||||
// CountTokens returns the token count for the given request.
|
||||
@@ -558,7 +559,7 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
|
||||
// ExecuteStream performs a streaming execution using the configured selector and executor.
|
||||
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
||||
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
normalized := m.normalizeProviders(providers)
|
||||
if len(normalized) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
@@ -568,9 +569,9 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; ; attempt++ {
|
||||
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||
result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||
if errStream == nil {
|
||||
return chunks, nil
|
||||
return result, nil
|
||||
}
|
||||
lastErr = errStream
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
|
||||
@@ -699,7 +700,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
@@ -730,7 +731,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
@@ -778,8 +779,11 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: streamResult.Headers,
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,10 +24,10 @@ func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
ch := make(chan cliproxyexecutor.StreamChunk)
|
||||
close(ch)
|
||||
return ch, nil
|
||||
return &cliproxyexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
@@ -89,7 +89,11 @@ func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
|
||||
if !okResolved {
|
||||
t.Fatal("expected registered executor to be found")
|
||||
}
|
||||
if resolved != current {
|
||||
resolvedExecutor, okResolvedExecutor := resolved.(*replaceAwareExecutor)
|
||||
if !okResolvedExecutor {
|
||||
t.Fatalf("expected resolved executor type %T, got %T", current, resolved)
|
||||
}
|
||||
if resolvedExecutor != current {
|
||||
t.Fatal("expected resolved executor to match registered executor")
|
||||
}
|
||||
|
||||
|
||||
@@ -213,6 +213,23 @@ func (a *Auth) DisableCoolingOverride() (bool, bool) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// ToolPrefixDisabled returns whether the proxy_ tool name prefix should be
|
||||
// skipped for this auth. When true, tool names are sent to Anthropic unchanged.
|
||||
// The value is read from metadata key "tool_prefix_disabled" (or "tool-prefix-disabled").
|
||||
func (a *Auth) ToolPrefixDisabled() bool {
|
||||
if a == nil || a.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
for _, key := range []string{"tool_prefix_disabled", "tool-prefix-disabled"} {
|
||||
if val, ok := a.Metadata[key]; ok {
|
||||
if parsed, okParse := parseBoolAny(val); okParse {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
|
||||
// The value is read from metadata key "request_retry" (or legacy "request-retry").
|
||||
func (a *Auth) RequestRetryOverride() (int, bool) {
|
||||
|
||||
35
sdk/cliproxy/auth/types_test.go
Normal file
35
sdk/cliproxy/auth/types_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package auth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestToolPrefixDisabled(t *testing.T) {
|
||||
var a *Auth
|
||||
if a.ToolPrefixDisabled() {
|
||||
t.Error("nil auth should return false")
|
||||
}
|
||||
|
||||
a = &Auth{}
|
||||
if a.ToolPrefixDisabled() {
|
||||
t.Error("empty auth should return false")
|
||||
}
|
||||
|
||||
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": true}}
|
||||
if !a.ToolPrefixDisabled() {
|
||||
t.Error("should return true when set to true")
|
||||
}
|
||||
|
||||
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": "true"}}
|
||||
if !a.ToolPrefixDisabled() {
|
||||
t.Error("should return true when set to string 'true'")
|
||||
}
|
||||
|
||||
a = &Auth{Metadata: map[string]any{"tool-prefix-disabled": true}}
|
||||
if !a.ToolPrefixDisabled() {
|
||||
t.Error("should return true with kebab-case key")
|
||||
}
|
||||
|
||||
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": false}}
|
||||
if a.ToolPrefixDisabled() {
|
||||
t.Error("should return false when set to false")
|
||||
}
|
||||
}
|
||||
@@ -57,6 +57,8 @@ type Response struct {
|
||||
Payload []byte
|
||||
// Metadata exposes optional structured data for translators.
|
||||
Metadata map[string]any
|
||||
// Headers carries upstream HTTP response headers for passthrough to clients.
|
||||
Headers http.Header
|
||||
}
|
||||
|
||||
// StreamChunk represents a single streaming payload unit emitted by provider executors.
|
||||
@@ -67,6 +69,15 @@ type StreamChunk struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// StreamResult wraps the streaming response, providing both the chunk channel
|
||||
// and the upstream HTTP response headers captured before streaming begins.
|
||||
type StreamResult struct {
|
||||
// Headers carries upstream HTTP response headers from the initial connection.
|
||||
Headers http.Header
|
||||
// Chunks is the channel of streaming payload units.
|
||||
Chunks <-chan StreamChunk
|
||||
}
|
||||
|
||||
// StatusError represents an error that carries an HTTP-like status code.
|
||||
// Provider executors should implement this when possible to enable
|
||||
// better auth state updates on failures (e.g., 401/402/429).
|
||||
|
||||
Reference in New Issue
Block a user