mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-04 05:20:52 +08:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0defb68c6c | ||
|
|
d6272d3300 | ||
|
|
c99d0dfb33 | ||
|
|
663b9b35ab | ||
|
|
5dced4c0a6 | ||
|
|
5891785125 | ||
|
|
ac3d47e8c0 | ||
|
|
e5ed2cba4a | ||
|
|
847c2502a5 | ||
|
|
c7196ba7dc | ||
|
|
6f9c23af5e | ||
|
|
2d5d06c809 | ||
|
|
3e20b00357 | ||
|
|
e370f86f63 | ||
|
|
7f266aa19e | ||
|
|
f3f31274e8 | ||
|
|
7061cd6058 | ||
|
|
5da5674ae2 | ||
|
|
7459c2c81a | ||
|
|
cd4706f60e | ||
|
|
359b8de44e | ||
|
|
ea6065f1b1 | ||
|
|
8aaed4cf09 | ||
|
|
c32e013605 | ||
|
|
3839d93ba0 | ||
|
|
a552a45b81 |
@@ -95,7 +95,7 @@ If a plaintext key is detected in the config at startup, it will be bcrypt‑has
|
||||
```
|
||||
- Response:
|
||||
```json
|
||||
{"debug":true,"proxy-url":"","api-keys":["1...5","JS...W"],"quota-exceeded":{"switch-project":true,"switch-preview-model":true},"generative-language-api-key":["AI...01","AI...02","AI...03"],"request-log":true,"request-retry":3,"claude-api-key":[{"api-key":"cr...56","base-url":"https://example.com/api","proxy-url":"socks5://proxy.example.com:1080"},{"api-key":"cr...e3","base-url":"http://example.com:3000/api","proxy-url":""},{"api-key":"sk-...q2","base-url":"https://example.com","proxy-url":""}],"codex-api-key":[{"api-key":"sk...01","base-url":"https://example/v1","proxy-url":""}],"openai-compatibility":[{"name":"openrouter","base-url":"https://openrouter.ai/api/v1","api-key-entries":[{"api-key":"sk...01","proxy-url":""}],"models":[{"name":"moonshotai/kimi-k2:free","alias":"kimi-k2"}]},{"name":"iflow","base-url":"https://apis.iflow.cn/v1","api-key-entries":[{"api-key":"sk...7e","proxy-url":"socks5://proxy.example.com:1080"}],"models":[{"name":"deepseek-v3.1","alias":"deepseek-v3.1"},{"name":"glm-4.5","alias":"glm-4.5"},{"name":"kimi-k2","alias":"kimi-k2"}]}]}
|
||||
{"debug":true,"proxy-url":"","api-keys":["1...5","JS...W"],"quota-exceeded":{"switch-project":true,"switch-preview-model":true},"generative-language-api-key":["AI...01","AI...02","AI...03"],"request-log":true,"request-retry":3,"claude-api-key":[{"api-key":"cr...56","base-url":"https://example.com/api","proxy-url":"socks5://proxy.example.com:1080","models":[{"name":"claude-3-5-sonnet-20241022","alias":"claude-sonnet-latest"}]},{"api-key":"cr...e3","base-url":"http://example.com:3000/api","proxy-url":""},{"api-key":"sk-...q2","base-url":"https://example.com","proxy-url":""}],"codex-api-key":[{"api-key":"sk...01","base-url":"https://example/v1","proxy-url":""}],"openai-compatibility":[{"name":"openrouter","base-url":"https://openrouter.ai/api/v1","api-key-entries":[{"api-key":"sk...01","proxy-url":""}],"models":[{"name":"moonshotai/kimi-k2:free","alias":"kimi-k2"}]},{"name":"iflow","base-url":"https://apis.iflow.cn/v1","api-key-entries":[{"api-key":"sk...7e","proxy-url":"socks5://proxy.example.com:1080"}],"models":[{"name":"deepseek-v3.1","alias":"deepseek-v3.1"},{"name":"glm-4.5","alias":"glm-4.5"},{"name":"kimi-k2","alias":"kimi-k2"}]}]}
|
||||
```
|
||||
|
||||
### Debug
|
||||
|
||||
@@ -95,7 +95,7 @@
|
||||
```
|
||||
- 响应:
|
||||
```json
|
||||
{"debug":true,"proxy-url":"","api-keys":["1...5","JS...W"],"quota-exceeded":{"switch-project":true,"switch-preview-model":true},"generative-language-api-key":["AI...01","AI...02","AI...03"],"request-log":true,"request-retry":3,"claude-api-key":[{"api-key":"cr...56","base-url":"https://example.com/api","proxy-url":"socks5://proxy.example.com:1080"},{"api-key":"cr...e3","base-url":"http://example.com:3000/api","proxy-url":""},{"api-key":"sk-...q2","base-url":"https://example.com","proxy-url":""}],"codex-api-key":[{"api-key":"sk...01","base-url":"https://example/v1","proxy-url":""}],"openai-compatibility":[{"name":"openrouter","base-url":"https://openrouter.ai/api/v1","api-key-entries":[{"api-key":"sk...01","proxy-url":""}],"models":[{"name":"moonshotai/kimi-k2:free","alias":"kimi-k2"}]},{"name":"iflow","base-url":"https://apis.iflow.cn/v1","api-key-entries":[{"api-key":"sk...7e","proxy-url":"socks5://proxy.example.com:1080"}],"models":[{"name":"deepseek-v3.1","alias":"deepseek-v3.1"},{"name":"glm-4.5","alias":"glm-4.5"},{"name":"kimi-k2","alias":"kimi-k2"}]}]}
|
||||
{"debug":true,"proxy-url":"","api-keys":["1...5","JS...W"],"quota-exceeded":{"switch-project":true,"switch-preview-model":true},"generative-language-api-key":["AI...01","AI...02","AI...03"],"request-log":true,"request-retry":3,"claude-api-key":[{"api-key":"cr...56","base-url":"https://example.com/api","proxy-url":"socks5://proxy.example.com:1080","models":[{"name":"claude-3-5-sonnet-20241022","alias":"claude-sonnet-latest"}]},{"api-key":"cr...e3","base-url":"http://example.com:3000/api","proxy-url":""},{"api-key":"sk-...q2","base-url":"https://example.com","proxy-url":""}],"codex-api-key":[{"api-key":"sk...01","base-url":"https://example/v1","proxy-url":""}],"openai-compatibility":[{"name":"openrouter","base-url":"https://openrouter.ai/api/v1","api-key-entries":[{"api-key":"sk...01","proxy-url":""}],"models":[{"name":"moonshotai/kimi-k2:free","alias":"kimi-k2"}]},{"name":"iflow","base-url":"https://apis.iflow.cn/v1","api-key-entries":[{"api-key":"sk...7e","proxy-url":"socks5://proxy.example.com:1080"}],"models":[{"name":"deepseek-v3.1","alias":"deepseek-v3.1"},{"name":"glm-4.5","alias":"glm-4.5"},{"name":"kimi-k2","alias":"kimi-k2"}]}]}
|
||||
```
|
||||
|
||||
### Debug
|
||||
|
||||
45
README.md
45
README.md
@@ -318,6 +318,9 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
|
||||
| `claude-api-key.api-key` | string | "" | Claude API key. |
|
||||
| `claude-api-key.base-url` | string | "" | Custom Claude API endpoint, if you use a third-party API endpoint. |
|
||||
| `claude-api-key.proxy-url` | string | "" | Proxy URL for this specific API key. Overrides the global proxy-url setting. Supports socks5/http/https protocols. |
|
||||
| `claude-api-key.models` | object[] | [] | Model alias entries for this key. |
|
||||
| `claude-api-key.models.*.name` | string | "" | Upstream Claude model name invoked against the API. |
|
||||
| `claude-api-key.models.*.alias` | string | "" | Client-facing alias that maps to the upstream model name. |
|
||||
| `openai-compatibility` | object[] | [] | Upstream OpenAI-compatible providers configuration (name, base-url, api-keys, models). |
|
||||
| `openai-compatibility.*.name` | string | "" | The name of the provider. It will be used in the user agent and other places. |
|
||||
| `openai-compatibility.*.base-url` | string | "" | The base URL of the provider. |
|
||||
@@ -325,9 +328,11 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
|
||||
| `openai-compatibility.*.api-key-entries` | object[] | [] | API key entries with optional per-key proxy configuration. Preferred over api-keys. |
|
||||
| `openai-compatibility.*.api-key-entries.*.api-key` | string | "" | The API key for this entry. |
|
||||
| `openai-compatibility.*.api-key-entries.*.proxy-url` | string | "" | Proxy URL for this specific API key. Overrides the global proxy-url setting. Supports socks5/http/https protocols. |
|
||||
| `openai-compatibility.*.models` | object[] | [] | The actual model name. |
|
||||
| `openai-compatibility.*.models.*.name` | string | "" | The models supported by the provider. |
|
||||
| `openai-compatibility.*.models.*.alias` | string | "" | The alias used in the API. |
|
||||
| `openai-compatibility.*.models` | object[] | [] | Model alias definitions routing client aliases to upstream names. |
|
||||
| `openai-compatibility.*.models.*.name` | string | "" | Upstream model name invoked against the provider. |
|
||||
| `openai-compatibility.*.models.*.alias` | string | "" | Client alias routed to the upstream model. |
|
||||
|
||||
When `claude-api-key.models` is specified, only the provided aliases are registered in the model registry (mirroring OpenAI compatibility behaviour), and the default Claude catalog is suppressed for that credential.
|
||||
|
||||
### Example Configuration File
|
||||
|
||||
@@ -410,7 +415,7 @@ openai-compatibility:
|
||||
# api-keys:
|
||||
# - "sk-or-v1-...b780"
|
||||
# - "sk-or-v1-...b781"
|
||||
models: # The models supported by the provider.
|
||||
models: # The models supported by the provider. Or you can use a format such as openrouter://moonshotai/kimi-k2:free to request undefined models
|
||||
- name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
alias: "kimi-k2" # The alias used in the API.
|
||||
```
|
||||
@@ -556,12 +561,17 @@ The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` req
|
||||
|
||||
## Claude Code with multiple account load balancing
|
||||
|
||||
Start CLI Proxy API server, and then set the `ANTHROPIC_BASE_URL`, `ANTHROPIC_AUTH_TOKEN`, `ANTHROPIC_MODEL`, `ANTHROPIC_SMALL_FAST_MODEL` environment variables.
|
||||
Start CLI Proxy API server, and then set the `ANTHROPIC_BASE_URL`, `ANTHROPIC_AUTH_TOKEN`, `ANTHROPIC_DEFAULT_OPUS_MODEL`, `ANTHROPIC_DEFAULT_SONNET_MODEL`, `ANTHROPIC_DEFAULT_HAIKU_MODEL` (or `ANTHROPIC_MODEL`, `ANTHROPIC_SMALL_FAST_MODEL` for version 1.x.x) environment variables.
|
||||
|
||||
Using Gemini models:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# version 2.x.x
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=gemini-2.5-pro
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=gemini-2.5-flash
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gemini-2.5-flash-lite
|
||||
# version 1.x.x
|
||||
export ANTHROPIC_MODEL=gemini-2.5-pro
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=gemini-2.5-flash
|
||||
```
|
||||
@@ -570,6 +580,11 @@ Using OpenAI GPT 5 models:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# version 2.x.x
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=gpt-5-high
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=gpt-5-medium
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gpt-5-minimal
|
||||
# version 1.x.x
|
||||
export ANTHROPIC_MODEL=gpt-5
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-minimal
|
||||
```
|
||||
@@ -578,6 +593,11 @@ Using OpenAI GPT 5 Codex models:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# version 2.x.x
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=gpt-5-codex-high
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=gpt-5-codex-medium
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gpt-5-codex-low
|
||||
# version 1.x.x
|
||||
export ANTHROPIC_MODEL=gpt-5-codex
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-codex-low
|
||||
```
|
||||
@@ -586,6 +606,11 @@ Using Claude models:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# version 2.x.x
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=claude-opus-4-1-20250805
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=claude-sonnet-4-5-20250929
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=claude-3-5-haiku-20241022
|
||||
# version 1.x.x
|
||||
export ANTHROPIC_MODEL=claude-sonnet-4-20250514
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022
|
||||
```
|
||||
@@ -594,6 +619,11 @@ Using Qwen models:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# version 2.x.x
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=qwen3-coder-flash
|
||||
# version 1.x.x
|
||||
export ANTHROPIC_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash
|
||||
```
|
||||
@@ -602,6 +632,11 @@ Using iFlow models:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# version 2.x.x
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=qwen3-max
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=qwen3-235b-a22b-instruct
|
||||
# version 1.x.x
|
||||
export ANTHROPIC_MODEL=qwen3-max
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-235b-a22b-instruct
|
||||
```
|
||||
|
||||
46
README_CN.md
46
README_CN.md
@@ -331,6 +331,9 @@ console.log(await claudeResponse.json());
|
||||
| `claude-api-key.api-key` | string | "" | Claude API密钥。 |
|
||||
| `claude-api-key.base-url` | string | "" | 自定义的Claude API端点,如果您使用第三方的API端点。 |
|
||||
| `claude-api-key.proxy-url` | string | "" | 针对该API密钥的代理URL。会覆盖全局proxy-url设置。支持socks5/http/https协议。 |
|
||||
| `claude-api-key.models` | object[] | [] | Model alias entries for this key. |
|
||||
| `claude-api-key.models.*.name` | string | "" | Upstream Claude model name invoked against the API. |
|
||||
| `claude-api-key.models.*.alias` | string | "" | Client-facing alias that maps to the upstream model name. |
|
||||
| `openai-compatibility` | object[] | [] | 上游OpenAI兼容提供商的配置(名称、基础URL、API密钥、模型)。 |
|
||||
| `openai-compatibility.*.name` | string | "" | 提供商的名称。它将被用于用户代理(User Agent)和其他地方。 |
|
||||
| `openai-compatibility.*.base-url` | string | "" | 提供商的基础URL。 |
|
||||
@@ -338,9 +341,11 @@ console.log(await claudeResponse.json());
|
||||
| `openai-compatibility.*.api-key-entries` | object[] | [] | API密钥条目,支持可选的每密钥代理配置。优先于api-keys。 |
|
||||
| `openai-compatibility.*.api-key-entries.*.api-key` | string | "" | 该条目的API密钥。 |
|
||||
| `openai-compatibility.*.api-key-entries.*.proxy-url` | string | "" | 针对该API密钥的代理URL。会覆盖全局proxy-url设置。支持socks5/http/https协议。 |
|
||||
| `openai-compatibility.*.models` | object[] | [] | 实际的模型名称。 |
|
||||
| `openai-compatibility.*.models.*.name` | string | "" | 提供商支持的模型。 |
|
||||
| `openai-compatibility.*.models.*.alias` | string | "" | 在API中使用的别名。 |
|
||||
| `openai-compatibility.*.models` | object[] | [] | Model alias definitions routing client aliases to upstream names. |
|
||||
| `openai-compatibility.*.models.*.name` | string | "" | Upstream model name invoked against the provider. |
|
||||
| `openai-compatibility.*.models.*.alias` | string | "" | Client alias routed to the upstream model. |
|
||||
|
||||
When `claude-api-key.models` is provided, only the listed aliases are registered for that credential, and the default Claude model catalog is skipped.
|
||||
|
||||
### 配置文件示例
|
||||
|
||||
@@ -423,7 +428,7 @@ openai-compatibility:
|
||||
# api-keys:
|
||||
# - "sk-or-v1-...b780"
|
||||
# - "sk-or-v1-...b781"
|
||||
models: # 提供商支持的模型。
|
||||
models: # 提供商支持的模型。或者你可以使用类似 openrouter://moonshotai/kimi-k2:free 这样的格式来请求未在这里定义的模型
|
||||
- name: "moonshotai/kimi-k2:free" # 实际的模型名称。
|
||||
alias: "kimi-k2" # 在API中使用的别名。
|
||||
```
|
||||
@@ -564,12 +569,17 @@ export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
|
||||
|
||||
## Claude Code 的使用方法
|
||||
|
||||
启动 CLI Proxy API 服务器, 设置如下系统环境变量 `ANTHROPIC_BASE_URL`, `ANTHROPIC_AUTH_TOKEN`, `ANTHROPIC_MODEL`, `ANTHROPIC_SMALL_FAST_MODEL`
|
||||
启动 CLI Proxy API 服务器, 设置如下系统环境变量 `ANTHROPIC_BASE_URL`, `ANTHROPIC_AUTH_TOKEN`, `ANTHROPIC_DEFAULT_OPUS_MODEL`, `ANTHROPIC_DEFAULT_SONNET_MODEL`, `ANTHROPIC_DEFAULT_HAIKU_MODEL` (或 `ANTHROPIC_MODEL`, `ANTHROPIC_SMALL_FAST_MODEL` 对应 1.x.x 版本)
|
||||
|
||||
使用 Gemini 模型:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# 2.x.x 版本
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=gemini-2.5-pro
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=gemini-2.5-flash
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gemini-2.5-flash-lite
|
||||
# 1.x.x 版本
|
||||
export ANTHROPIC_MODEL=gemini-2.5-pro
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=gemini-2.5-flash
|
||||
```
|
||||
@@ -578,6 +588,11 @@ export ANTHROPIC_SMALL_FAST_MODEL=gemini-2.5-flash
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# 2.x.x 版本
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=gpt-5-high
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=gpt-5-medium
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gpt-5-minimal
|
||||
# 1.x.x 版本
|
||||
export ANTHROPIC_MODEL=gpt-5
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-minimal
|
||||
```
|
||||
@@ -586,15 +601,24 @@ export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-minimal
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# 2.x.x 版本
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=gpt-5-codex-high
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=gpt-5-codex-medium
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gpt-5-codex-low
|
||||
# 1.x.x 版本
|
||||
export ANTHROPIC_MODEL=gpt-5-codex
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-codex-low
|
||||
```
|
||||
|
||||
|
||||
使用 Claude 模型:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# 2.x.x 版本
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=claude-opus-4-1-20250805
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=claude-sonnet-4-5-20250929
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=claude-3-5-haiku-20241022
|
||||
# 1.x.x 版本
|
||||
export ANTHROPIC_MODEL=claude-sonnet-4-20250514
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022
|
||||
```
|
||||
@@ -603,6 +627,11 @@ export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# 2.x.x 版本
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=qwen3-coder-flash
|
||||
# 1.x.x 版本
|
||||
export ANTHROPIC_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash
|
||||
```
|
||||
@@ -611,6 +640,11 @@ export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
# 2.x.x 版本
|
||||
export ANTHROPIC_DEFAULT_OPUS_MODEL=qwen3-max
|
||||
export ANTHROPIC_DEFAULT_SONNET_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_DEFAULT_HAIKU_MODEL=qwen3-235b-a22b-instruct
|
||||
# 1.x.x 版本
|
||||
export ANTHROPIC_MODEL=qwen3-max
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-235b-a22b-instruct
|
||||
```
|
||||
|
||||
@@ -43,6 +43,9 @@ quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
|
||||
# API keys for official Generative Language API
|
||||
#generative-language-api-key:
|
||||
# - "AIzaSy...01"
|
||||
@@ -62,6 +65,9 @@ quota-exceeded:
|
||||
# - api-key: "sk-atSM..."
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
|
||||
# OpenAI compatibility providers
|
||||
#openai-compatibility:
|
||||
|
||||
@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
6
go.mod
6
go.mod
@@ -7,14 +7,16 @@ require (
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/klauspost/compress v1.17.4
|
||||
github.com/minio/minio-go/v7 v7.0.66
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.7.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
@@ -26,12 +28,14 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/andybalholm/brotli v1.0.6 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/cyphar/filepath-securejoin v0.4.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -4,6 +4,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
|
||||
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
|
||||
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
||||
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
@@ -23,6 +25,8 @@ github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGL
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
|
||||
@@ -64,6 +68,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -78,8 +84,6 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
|
||||
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
||||
github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
|
||||
github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
@@ -147,6 +151,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw=
|
||||
github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
|
||||
@@ -57,10 +57,12 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||
authHeaderAnthropic := r.Header.Get("X-Api-Key")
|
||||
queryKey := ""
|
||||
queryAuthToken := ""
|
||||
if r.URL != nil {
|
||||
queryKey = r.URL.Query().Get("key")
|
||||
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||
}
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" {
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
}
|
||||
|
||||
@@ -74,6 +76,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
{authHeaderGoogle, "x-goog-api-key"},
|
||||
{authHeaderAnthropic, "x-api-key"},
|
||||
{queryKey, "query-key"},
|
||||
{queryAuthToken, "query-auth-token"},
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
|
||||
@@ -150,6 +150,9 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) {
|
||||
}
|
||||
arr = obj.Items
|
||||
}
|
||||
for i := range arr {
|
||||
normalizeClaudeKey(&arr[i])
|
||||
}
|
||||
h.cfg.ClaudeKey = arr
|
||||
h.persist(c)
|
||||
}
|
||||
@@ -163,6 +166,7 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
normalizeClaudeKey(body.Value)
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
|
||||
h.cfg.ClaudeKey[*body.Index] = *body.Value
|
||||
h.persist(c)
|
||||
@@ -472,3 +476,26 @@ func normalizedOpenAICompatibilityEntries(entries []config.OpenAICompatibility)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeClaudeKey(entry *config.ClaudeKey) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
normalized := make([]config.ClaudeModel, 0, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
model.Name = strings.TrimSpace(model.Name)
|
||||
model.Alias = strings.TrimSpace(model.Alias)
|
||||
if model.Name == "" && model.Alias == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, model)
|
||||
}
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
@@ -63,13 +64,11 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
// It captures the URL, method, headers, and body. The request body is read and then
|
||||
// restored so that it can be processed by subsequent handlers.
|
||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
// Capture URL
|
||||
url := c.Request.URL.String()
|
||||
if c.Request.URL.Path != "" {
|
||||
url = c.Request.URL.Path
|
||||
if c.Request.URL.RawQuery != "" {
|
||||
url += "?" + c.Request.URL.RawQuery
|
||||
}
|
||||
// Capture URL with sensitive query parameters masked
|
||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
url := c.Request.URL.Path
|
||||
if maskedQuery != "" {
|
||||
url += "?" + maskedQuery
|
||||
}
|
||||
|
||||
// Capture method
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -138,6 +139,12 @@ type Server struct {
|
||||
// currentPath is the absolute path to the current working directory.
|
||||
currentPath string
|
||||
|
||||
// wsRoutes tracks registered websocket upgrade paths.
|
||||
wsRouteMu sync.Mutex
|
||||
wsRoutes map[string]struct{}
|
||||
wsAuthChanged func(bool, bool)
|
||||
wsAuthEnabled atomic.Bool
|
||||
|
||||
// management handler
|
||||
mgmt *managementHandlers.Handler
|
||||
|
||||
@@ -218,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -228,7 +239,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
configFilePath: configFilePath,
|
||||
currentPath: wd,
|
||||
envManagementSecret: envManagementSecret,
|
||||
wsRoutes: make(map[string]struct{}),
|
||||
}
|
||||
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
||||
// Save initial YAML snapshot
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
s.applyAccessConfig(nil, cfg)
|
||||
@@ -371,6 +384,43 @@ func (s *Server) setupRoutes() {
|
||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||
}
|
||||
|
||||
// AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine.
|
||||
// The handler is served as-is without additional middleware beyond the standard stack already configured.
|
||||
func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
|
||||
if s == nil || s.engine == nil || handler == nil {
|
||||
return
|
||||
}
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if trimmed == "" {
|
||||
trimmed = "/v1/ws"
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "/") {
|
||||
trimmed = "/" + trimmed
|
||||
}
|
||||
s.wsRouteMu.Lock()
|
||||
if _, exists := s.wsRoutes[trimmed]; exists {
|
||||
s.wsRouteMu.Unlock()
|
||||
return
|
||||
}
|
||||
s.wsRoutes[trimmed] = struct{}{}
|
||||
s.wsRouteMu.Unlock()
|
||||
|
||||
authMiddleware := AuthMiddleware(s.accessManager)
|
||||
conditionalAuth := func(c *gin.Context) {
|
||||
if !s.wsAuthEnabled.Load() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
authMiddleware(c)
|
||||
}
|
||||
finalHandler := func(c *gin.Context) {
|
||||
handler.ServeHTTP(c.Writer, c.Request)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
s.engine.GET(trimmed, conditionalAuth, finalHandler)
|
||||
}
|
||||
|
||||
func (s *Server) registerManagementRoutes() {
|
||||
if s == nil || s.engine == nil || s.mgmt == nil {
|
||||
return
|
||||
@@ -479,7 +529,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
filePath := managementasset.FilePath(s.currentPath)
|
||||
filePath := managementasset.FilePath(s.configFilePath)
|
||||
if strings.TrimSpace(filePath) == "" {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
@@ -487,7 +537,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.currentPath), cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
@@ -770,13 +820,24 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
|
||||
s.applyAccessConfig(oldCfg, cfg)
|
||||
s.cfg = cfg
|
||||
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
||||
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
|
||||
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.OpenAICompatProviders = providerNames
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.currentPath)
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
@@ -810,6 +871,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.wsAuthChanged = fn
|
||||
}
|
||||
|
||||
// (management handlers moved to internal/api/handlers/management)
|
||||
|
||||
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
||||
@@ -846,5 +914,3 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// legacy clientsToSlice removed; handlers no longer consume legacy client slices
|
||||
|
||||
@@ -40,6 +40,9 @@ type Config struct {
|
||||
// QuotaExceeded defines the behavior when a quota is exceeded.
|
||||
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
|
||||
|
||||
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
||||
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
||||
|
||||
// GlAPIKey is the API key for the generative language API.
|
||||
GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"`
|
||||
|
||||
@@ -91,6 +94,18 @@ type ClaudeKey struct {
|
||||
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// Models defines upstream model names and aliases for request routing.
|
||||
Models []ClaudeModel `yaml:"models" json:"models"`
|
||||
}
|
||||
|
||||
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
|
||||
type ClaudeModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the client-facing model name that maps to Name.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
// CodexKey represents the configuration for a Codex API key,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -23,7 +24,7 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
|
||||
c.Next()
|
||||
|
||||
|
||||
@@ -15,6 +15,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
@@ -411,6 +415,10 @@ func (l *FileRequestLogger) decompressResponse(responseHeaders map[string][]stri
|
||||
return l.decompressGzip(response)
|
||||
case "deflate":
|
||||
return l.decompressDeflate(response)
|
||||
case "br":
|
||||
return l.decompressBrotli(response)
|
||||
case "zstd":
|
||||
return l.decompressZstd(response)
|
||||
default:
|
||||
// No compression or unsupported compression
|
||||
return response, nil
|
||||
@@ -431,7 +439,9 @@ func (l *FileRequestLogger) decompressGzip(data []byte) ([]byte, error) {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = reader.Close()
|
||||
if errClose := reader.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close gzip reader in request logger")
|
||||
}
|
||||
}()
|
||||
|
||||
decompressed, err := io.ReadAll(reader)
|
||||
@@ -453,7 +463,9 @@ func (l *FileRequestLogger) decompressGzip(data []byte) ([]byte, error) {
|
||||
func (l *FileRequestLogger) decompressDeflate(data []byte) ([]byte, error) {
|
||||
reader := flate.NewReader(bytes.NewReader(data))
|
||||
defer func() {
|
||||
_ = reader.Close()
|
||||
if errClose := reader.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close deflate reader in request logger")
|
||||
}
|
||||
}()
|
||||
|
||||
decompressed, err := io.ReadAll(reader)
|
||||
@@ -464,6 +476,48 @@ func (l *FileRequestLogger) decompressDeflate(data []byte) ([]byte, error) {
|
||||
return decompressed, nil
|
||||
}
|
||||
|
||||
// decompressBrotli decompresses brotli-encoded data.
|
||||
//
|
||||
// Parameters:
|
||||
// - data: The brotli-encoded data to decompress
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The decompressed data
|
||||
// - error: An error if decompression fails, nil otherwise
|
||||
func (l *FileRequestLogger) decompressBrotli(data []byte) ([]byte, error) {
|
||||
reader := brotli.NewReader(bytes.NewReader(data))
|
||||
|
||||
decompressed, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decompress brotli data: %w", err)
|
||||
}
|
||||
|
||||
return decompressed, nil
|
||||
}
|
||||
|
||||
// decompressZstd decompresses zstd-encoded data.
|
||||
//
|
||||
// Parameters:
|
||||
// - data: The zstd-encoded data to decompress
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The decompressed data
|
||||
// - error: An error if decompression fails, nil otherwise
|
||||
func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
|
||||
decoder, err := zstd.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create zstd reader: %w", err)
|
||||
}
|
||||
defer decoder.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(decoder)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decompress zstd data: %w", err)
|
||||
}
|
||||
|
||||
return decompressed, nil
|
||||
}
|
||||
|
||||
// formatRequestInfo creates the request information section of the log.
|
||||
//
|
||||
// Parameters:
|
||||
|
||||
@@ -68,84 +68,8 @@ func GetClaudeModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// GetGeminiModels returns the standard Gemini model definitions
|
||||
func GetGeminiModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "gemini-2.5-flash",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-2.5-flash",
|
||||
Version: "001",
|
||||
DisplayName: "Gemini 2.5 Flash",
|
||||
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-pro",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-2.5-pro",
|
||||
Version: "2.5",
|
||||
DisplayName: "Gemini 2.5 Pro",
|
||||
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-flash-lite",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-2.5-flash-lite",
|
||||
Version: "2.5",
|
||||
DisplayName: "Gemini 2.5 Flash Lite",
|
||||
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Flash Lite",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-flash-image-preview",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-2.5-flash-image-preview",
|
||||
Version: "2.5",
|
||||
DisplayName: "Gemini 2.5 Flash Image Preview",
|
||||
Description: "State-of-the-art image generation and editing model.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 8192,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-flash-image",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-2.5-flash-image",
|
||||
Version: "2.5",
|
||||
DisplayName: "Gemini 2.5 Flash Image",
|
||||
Description: "State-of-the-art image generation and editing model.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 8192,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetGeminiCLIModels returns the standard Gemini model definitions
|
||||
func GetGeminiCLIModels() []*ModelInfo {
|
||||
// GeminiModels returns the shared base Gemini model set used by multiple providers.
|
||||
func GeminiModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "gemini-2.5-flash",
|
||||
@@ -220,6 +144,63 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// GetGeminiModels returns the standard Gemini model definitions
|
||||
func GetGeminiModels() []*ModelInfo { return GeminiModels() }
|
||||
|
||||
// GetGeminiCLIModels returns the standard Gemini model definitions
|
||||
func GetGeminiCLIModels() []*ModelInfo { return GeminiModels() }
|
||||
|
||||
// GetAIStudioModels returns the Gemini model definitions for AI Studio integrations
|
||||
func GetAIStudioModels() []*ModelInfo {
|
||||
models := make([]*ModelInfo, 0, 8)
|
||||
models = append(models, GeminiModels()...)
|
||||
models = append(models,
|
||||
&ModelInfo{
|
||||
ID: "gemini-pro-latest",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-pro-latest",
|
||||
Version: "2.5",
|
||||
DisplayName: "Gemini Pro Latest",
|
||||
Description: "Latest release of Gemini Pro",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
&ModelInfo{
|
||||
ID: "gemini-flash-latest",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-flash-latest",
|
||||
Version: "2.5",
|
||||
DisplayName: "Gemini Flash Latest",
|
||||
Description: "Latest release of Gemini Flash",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
&ModelInfo{
|
||||
ID: "gemini-flash-lite-latest",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-flash-lite-latest",
|
||||
Version: "2.5",
|
||||
DisplayName: "Gemini Flash-Lite Latest",
|
||||
Description: "Latest release of Gemini Flash-Lite",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
)
|
||||
return models
|
||||
}
|
||||
|
||||
// GetOpenAIModels returns the standard OpenAI model definitions
|
||||
func GetOpenAIModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
@@ -385,6 +366,19 @@ func GetQwenModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 2048,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "vision-model",
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: "qwen",
|
||||
Type: "qwen",
|
||||
Version: "3.0",
|
||||
DisplayName: "Qwen3 Vision Model",
|
||||
Description: "Vision model model",
|
||||
ContextLength: 32768,
|
||||
MaxCompletionTokens: 2048,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -404,7 +398,6 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language"},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build"},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905"},
|
||||
{ID: "glm-4.5", DisplayName: "GLM-4.5", Description: "Zhipu GLM 4.5 general model"},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model"},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model"},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental"},
|
||||
|
||||
396
internal/runtime/executor/aistudio_executor.go
Normal file
396
internal/runtime/executor/aistudio_executor.go
Normal file
@@ -0,0 +1,396 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// AIStudioExecutor routes AI Studio requests through a websocket-backed transport.
|
||||
type AIStudioExecutor struct {
|
||||
provider string
|
||||
relay *wsrelay.Manager
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAIStudioExecutor constructs a websocket executor for the provider name.
|
||||
func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor {
|
||||
return &AIStudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the logical provider key for routing.
|
||||
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
|
||||
|
||||
// PrepareRequest is a no-op because websocket transport already injects headers.
|
||||
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
translatedReq, body, err := e.translateRequest(req, opts, false)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: http.MethodPost,
|
||||
URL: endpoint,
|
||||
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: body.payload,
|
||||
}
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||
if len(wsResp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
|
||||
}
|
||||
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
||||
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
translatedReq, body, err := e.translateRequest(req, opts, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: http.MethodPost,
|
||||
URL: endpoint,
|
||||
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: body.payload,
|
||||
}
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
wsStream, err := e.relay.Stream(ctx, authID, wsReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
var param any
|
||||
metadataLogged := false
|
||||
for event := range wsStream {
|
||||
if event.Err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||
return
|
||||
}
|
||||
switch event.Type {
|
||||
case wsrelay.MessageTypeStreamStart:
|
||||
if !metadataLogged && event.Status > 0 {
|
||||
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||
metadataLogged = true
|
||||
}
|
||||
case wsrelay.MessageTypeStreamChunk:
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
filtered := filterAIStudioUsageMetadata(event.Payload)
|
||||
if detail, ok := parseGeminiStreamUsage(filtered); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
break
|
||||
}
|
||||
case wsrelay.MessageTypeStreamEnd:
|
||||
return
|
||||
case wsrelay.MessageTypeHTTPResp:
|
||||
if !metadataLogged && event.Status > 0 {
|
||||
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(event.Payload))
|
||||
return
|
||||
case wsrelay.MessageTypeError:
|
||||
recordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_, body, err := e.translateRequest(req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
}

body.payload, _ = sjson.DeleteBytes(body.payload, "generationConfig")
body.payload, _ = sjson.DeleteBytes(body.payload, "tools")

endpoint := e.buildEndpoint(req.Model, "countTokens", "")
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload,
}
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
resp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
if len(resp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
}
if resp.Status < 200 || resp.Status >= 300 {
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
}
totalTokens := gjson.GetBytes(resp.Body, "totalTokens").Int()
if totalTokens <= 0 {
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
}
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}

func (e *AIStudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
_ = ctx
return auth, nil
}

type translatedPayload struct {
payload []byte
action string
toFormat sdktranslator.Format
}

func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok {
payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
}
payload = disableGeminiThinkingConfig(payload, req.Model)
payload = fixGeminiImageAspectRatio(req.Model, payload)
metadataAction := "generateContent"
if req.Metadata != nil {
if action, _ := req.Metadata["action"].(string); action == "countTokens" {
metadataAction = action
}
}
action := metadataAction
if stream && action != "countTokens" {
action = "streamGenerateContent"
}
payload, _ = sjson.DeleteBytes(payload, "session_id")
return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil
}

func (e *AIStudioExecutor) buildEndpoint(model, action, alt string) string {
base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action)
if action == "streamGenerateContent" {
if alt == "" {
return base + "?alt=sse"
}
return base + "?$alt=" + url.QueryEscape(alt)
}
if alt != "" && action != "countTokens" {
return base + "?$alt=" + url.QueryEscape(alt)
}
return base
}

// filterAIStudioUsageMetadata removes usageMetadata from intermediate SSE events so that
// only the terminal chunk retains token statistics.
func filterAIStudioUsageMetadata(payload []byte) []byte {
if len(payload) == 0 {
return payload
}

lines := bytes.Split(payload, []byte("\n"))
modified := false
for idx, line := range lines {
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
dataIdx := bytes.Index(line, []byte("data:"))
if dataIdx < 0 {
continue
}
rawJSON := bytes.TrimSpace(line[dataIdx+5:])
cleaned, changed := stripUsageMetadataFromJSON(rawJSON)
if !changed {
continue
}
var rebuilt []byte
rebuilt = append(rebuilt, line[:dataIdx]...)
rebuilt = append(rebuilt, []byte("data:")...)
if len(cleaned) > 0 {
rebuilt = append(rebuilt, ' ')
rebuilt = append(rebuilt, cleaned...)
}
lines[idx] = rebuilt
modified = true
}
if !modified {
return payload
}
return bytes.Join(lines, []byte("\n"))
}

// stripUsageMetadataFromJSON drops usageMetadata when no finishReason is present.
func stripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
jsonBytes := bytes.TrimSpace(rawJSON)
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
return rawJSON, false
}
finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
if finishReason.Exists() && finishReason.String() != "" {
return rawJSON, false
}
if !gjson.GetBytes(jsonBytes, "usageMetadata").Exists() {
return rawJSON, false
}
cleaned, err := sjson.DeleteBytes(jsonBytes, "usageMetadata")
if err != nil {
return rawJSON, false
}
return cleaned, true
}

// ensureColonSpacedJSON normalizes JSON objects so that colons are followed by a single space while
// keeping the payload otherwise compact. Non-JSON inputs are returned unchanged.
func ensureColonSpacedJSON(payload []byte) []byte {
trimmed := bytes.TrimSpace(payload)
if len(trimmed) == 0 {
return payload
}

var decoded any
if err := json.Unmarshal(trimmed, &decoded); err != nil {
return payload
}

indented, err := json.MarshalIndent(decoded, "", " ")
if err != nil {
return payload
}

compacted := make([]byte, 0, len(indented))
inString := false
skipSpace := false

for i := 0; i < len(indented); i++ {
ch := indented[i]
if ch == '"' && (i == 0 || indented[i-1] != '\\') {
inString = !inString
}

if !inString {
if ch == '\n' || ch == '\r' {
skipSpace = true
continue
}
if skipSpace {
if ch == ' ' || ch == '\t' {
continue
}
skipSpace = false
}
}

compacted = append(compacted, ch)
}

return compacted
}

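To make the effect of this normalization concrete, here is a minimal, hypothetical test-style sketch (not part of the repository) that feeds a compact Gemini chunk through `ensureColonSpacedJSON`; the expected string follows from `encoding/json` re-marshalling via `map[string]any`, so key order is alphabetical rather than input order:

```go
package executor

import "testing"

// Hypothetical sketch: compact JSON in, the same object back with a single space
// after each colon; string contents are left untouched.
func TestEnsureColonSpacedJSONSketch(t *testing.T) {
	in := []byte(`{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"totalTokenCount":42}}`)
	got := string(ensureColonSpacedJSON(in))
	want := `{"candidates": [{"finishReason": "STOP"}],"usageMetadata": {"totalTokenCount": 42}}`
	if got != want {
		t.Fatalf("unexpected normalization: %s", got)
	}
}
```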
@@ -3,6 +3,8 @@ package executor
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -49,8 +52,13 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
modelForUpstream := req.Model
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", modelOverride)
|
||||
modelForUpstream = modelOverride
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(req.Model, "claude-3-5-haiku") {
|
||||
if !strings.HasPrefix(modelForUpstream, "claude-3-5-haiku") {
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
}
|
||||
|
||||
@@ -84,31 +92,31 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, string(b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
reader := io.Reader(httpResp.Body)
|
||||
var decoder *zstd.Decoder
|
||||
if hasZSTDEcoding(httpResp.Header.Get("Content-Encoding")) {
|
||||
decoder, err = zstd.NewReader(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, fmt.Errorf("failed to initialize zstd decoder: %w", err)
|
||||
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
reader = decoder
|
||||
defer decoder.Close()
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(reader)
|
||||
defer func() {
|
||||
if errClose := decodedBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
data, err := io.ReadAll(decodedBody)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
@@ -141,6 +149,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", modelOverride)
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
@@ -184,19 +195,27 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
if errClose := decodedBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// If from == to (Claude → Claude), directly forward the SSE stream without translation
|
||||
if from == to {
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner := bufio.NewScanner(decodedBody)
|
||||
buf := make([]byte, 20_971_520)
|
||||
scanner.Buffer(buf, 20_971_520)
|
||||
for scanner.Scan() {
|
||||
@@ -220,7 +239,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
// For other formats, use translation
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner := bufio.NewScanner(decodedBody)
|
||||
buf := make([]byte, 20_971_520)
|
||||
scanner.Buffer(buf, 20_971_520)
|
||||
var param any
|
||||
@@ -256,8 +275,13 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
modelForUpstream := req.Model
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", modelOverride)
|
||||
modelForUpstream = modelOverride
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(req.Model, "claude-3-5-haiku") {
|
||||
if !strings.HasPrefix(modelForUpstream, "claude-3-5-haiku") {
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
}
|
||||
|
||||
@@ -291,29 +315,29 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)}
|
||||
}
|
||||
reader := io.Reader(resp.Body)
|
||||
var decoder *zstd.Decoder
|
||||
if hasZSTDEcoding(resp.Header.Get("Content-Encoding")) {
|
||||
decoder, err = zstd.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("failed to initialize zstd decoder: %w", err)
|
||||
decodedBody, err := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
reader = decoder
|
||||
defer decoder.Close()
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
data, err := io.ReadAll(reader)
|
||||
defer func() {
|
||||
if errClose := decodedBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
data, err := io.ReadAll(decodedBody)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return cliproxyexecutor.Response{}, err
|
||||
@@ -358,17 +382,151 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func hasZSTDEcoding(contentEncoding string) bool {
|
||||
if contentEncoding == "" {
|
||||
return false
|
||||
func (e *ClaudeExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
if alias == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(contentEncoding, ",")
|
||||
for i := range parts {
|
||||
if strings.EqualFold(strings.TrimSpace(parts[i]), "zstd") {
|
||||
return true
|
||||
entry := e.resolveClaudeConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
if modelAlias != "" {
|
||||
if strings.EqualFold(modelAlias, alias) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return alias
|
||||
}
|
||||
continue
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, alias) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) resolveClaudeConfig(auth *cliproxyauth.Auth) *config.ClaudeKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.ClaudeKey {
|
||||
entry := &e.cfg.ClaudeKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.ClaudeKey {
|
||||
entry := &e.cfg.ClaudeKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type compositeReadCloser struct {
|
||||
io.Reader
|
||||
closers []func() error
|
||||
}
|
||||
|
||||
func (c *compositeReadCloser) Close() error {
|
||||
var firstErr error
|
||||
for i := range c.closers {
|
||||
if c.closers[i] == nil {
|
||||
continue
|
||||
}
|
||||
if err := c.closers[i](); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) {
|
||||
if body == nil {
|
||||
return nil, fmt.Errorf("response body is nil")
|
||||
}
|
||||
if contentEncoding == "" {
|
||||
return body, nil
|
||||
}
|
||||
encodings := strings.Split(contentEncoding, ",")
|
||||
for _, raw := range encodings {
|
||||
encoding := strings.TrimSpace(strings.ToLower(raw))
|
||||
switch encoding {
|
||||
case "", "identity":
|
||||
continue
|
||||
case "gzip":
|
||||
gzipReader, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
_ = body.Close()
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
return &compositeReadCloser{
|
||||
Reader: gzipReader,
|
||||
closers: []func() error{
|
||||
gzipReader.Close,
|
||||
func() error { return body.Close() },
|
||||
},
|
||||
}, nil
|
||||
case "deflate":
|
||||
deflateReader := flate.NewReader(body)
|
||||
return &compositeReadCloser{
|
||||
Reader: deflateReader,
|
||||
closers: []func() error{
|
||||
deflateReader.Close,
|
||||
func() error { return body.Close() },
|
||||
},
|
||||
}, nil
|
||||
case "br":
|
||||
return &compositeReadCloser{
|
||||
Reader: brotli.NewReader(body),
|
||||
closers: []func() error{
|
||||
func() error { return body.Close() },
|
||||
},
|
||||
}, nil
|
||||
case "zstd":
|
||||
decoder, err := zstd.NewReader(body)
|
||||
if err != nil {
|
||||
_ = body.Close()
|
||||
return nil, fmt.Errorf("failed to create zstd reader: %w", err)
|
||||
}
|
||||
return &compositeReadCloser{
|
||||
Reader: decoder,
|
||||
closers: []func() error{
|
||||
func() error { decoder.Close(); return nil },
|
||||
func() error { return body.Close() },
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
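Since several call sites now repeat the same pattern, here is a short hypothetical sketch (not repository code; the request construction is illustrative) of how `decodeResponseBody` is meant to be used: pick the decoder from the `Content-Encoding` header, close the underlying body yourself only on error, and otherwise close just the returned reader, which also closes the wrapped body:

```go
// fetchDecoded is a hypothetical helper showing the intended call pattern for
// decodeResponseBody; it is not part of the repository.
func fetchDecoded(ctx context.Context, client *http.Client, endpoint string) ([]byte, error) {
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
	if err != nil {
		// Mirrors the executors above: close the raw body when decoding setup fails.
		_ = httpResp.Body.Close()
		return nil, err
	}
	defer func() {
		if errClose := decodedBody.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()
	return io.ReadAll(decodedBody)
}
```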
|
||||
func applyClaudeHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -277,7 +278,180 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented")
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
modelForCounting := req.Model
|
||||
|
||||
if util.InArray([]string{"gpt-5", "gpt-5-minimal", "gpt-5-low", "gpt-5-medium", "gpt-5-high"}, req.Model) {
|
||||
modelForCounting = "gpt-5"
|
||||
body, _ = sjson.SetBytes(body, "model", "gpt-5")
|
||||
switch req.Model {
|
||||
case "gpt-5-minimal":
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "minimal")
|
||||
case "gpt-5-low":
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "low")
|
||||
case "gpt-5-medium":
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "medium")
|
||||
case "gpt-5-high":
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "high")
|
||||
default:
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "low")
|
||||
}
|
||||
} else if util.InArray([]string{"gpt-5-codex", "gpt-5-codex-low", "gpt-5-codex-medium", "gpt-5-codex-high"}, req.Model) {
|
||||
modelForCounting = "gpt-5"
|
||||
body, _ = sjson.SetBytes(body, "model", "gpt-5-codex")
|
||||
switch req.Model {
|
||||
case "gpt-5-codex-low":
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "low")
|
||||
case "gpt-5-codex-medium":
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "medium")
|
||||
case "gpt-5-codex-high":
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "high")
|
||||
default:
|
||||
body, _ = sjson.SetBytes(body, "reasoning.effort", "low")
|
||||
}
|
||||
}
|
||||
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
enc, err := tokenizerForCodexModel(modelForCounting)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := countCodexInputTokens(enc, body)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
switch {
|
||||
case sanitized == "":
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4)
|
||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
default:
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
}
|
||||
}
|
||||
|
||||
func countCodexInputTokens(enc tokenizer.Codec, body []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
if len(body) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(body)
|
||||
var segments []string
|
||||
|
||||
if inst := strings.TrimSpace(root.Get("instructions").String()); inst != "" {
|
||||
segments = append(segments, inst)
|
||||
}
|
||||
|
||||
inputItems := root.Get("input")
|
||||
if inputItems.IsArray() {
|
||||
arr := inputItems.Array()
|
||||
for i := range arr {
|
||||
item := arr[i]
|
||||
switch item.Get("type").String() {
|
||||
case "message":
|
||||
content := item.Get("content")
|
||||
if content.IsArray() {
|
||||
parts := content.Array()
|
||||
for j := range parts {
|
||||
part := parts[j]
|
||||
if text := strings.TrimSpace(part.Get("text").String()); text != "" {
|
||||
segments = append(segments, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
if name := strings.TrimSpace(item.Get("name").String()); name != "" {
|
||||
segments = append(segments, name)
|
||||
}
|
||||
if args := strings.TrimSpace(item.Get("arguments").String()); args != "" {
|
||||
segments = append(segments, args)
|
||||
}
|
||||
case "function_call_output":
|
||||
if out := strings.TrimSpace(item.Get("output").String()); out != "" {
|
||||
segments = append(segments, out)
|
||||
}
|
||||
default:
|
||||
if text := strings.TrimSpace(item.Get("text").String()); text != "" {
|
||||
segments = append(segments, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tools := root.Get("tools")
|
||||
if tools.IsArray() {
|
||||
tarr := tools.Array()
|
||||
for i := range tarr {
|
||||
tool := tarr[i]
|
||||
if name := strings.TrimSpace(tool.Get("name").String()); name != "" {
|
||||
segments = append(segments, name)
|
||||
}
|
||||
if desc := strings.TrimSpace(tool.Get("description").String()); desc != "" {
|
||||
segments = append(segments, desc)
|
||||
}
|
||||
if params := tool.Get("parameters"); params.Exists() {
|
||||
val := params.Raw
|
||||
if params.Type == gjson.String {
|
||||
val = params.String()
|
||||
}
|
||||
if trimmed := strings.TrimSpace(val); trimmed != "" {
|
||||
segments = append(segments, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textFormat := root.Get("text.format")
|
||||
if textFormat.Exists() {
|
||||
if name := strings.TrimSpace(textFormat.Get("name").String()); name != "" {
|
||||
segments = append(segments, name)
|
||||
}
|
||||
if schema := textFormat.Get("schema"); schema.Exists() {
|
||||
val := schema.Raw
|
||||
if schema.Type == gjson.String {
|
||||
val = schema.String()
|
||||
}
|
||||
if trimmed := strings.TrimSpace(val); trimmed != "" {
|
||||
segments = append(segments, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text := strings.Join(segments, "\n")
|
||||
if text == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
count, err := enc.Count(text)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
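To illustrate the local counting path end to end, a small hypothetical sketch (not repository code) that picks a tokenizer via `tokenizerForCodexModel` and counts a minimal Responses-style payload; the resulting number is an approximation that depends on the tiktoken vocabulary:

```go
// printApproxCodexTokens is a hypothetical sketch combining the helpers above.
func printApproxCodexTokens() error {
	enc, err := tokenizerForCodexModel("gpt-5-codex")
	if err != nil {
		return err
	}
	body := []byte(`{"instructions":"You are a coding assistant.","input":[{"type":"message","content":[{"type":"input_text","text":"write hello world in Go"}]}]}`)
	count, err := countCodexInputTokens(enc, body)
	if err != nil {
		return err
	}
	fmt.Printf("approximate input tokens: %d\n", count)
	return nil
}
```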
|
||||
func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
|
||||
@@ -703,7 +703,7 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["Image", "Text"]`))
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
||||
}
|
||||
}
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "request.generationConfig.imageConfig")
|
||||
|
||||
@@ -494,7 +494,7 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["Image", "Text"]`))
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
||||
}
|
||||
}
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "generationConfig.imageConfig")
|
||||
|
||||
@@ -221,9 +221,24 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// CountTokens is not implemented for iFlow.
|
||||
func (e *IFlowExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{Payload: nil}, fmt.Errorf("not implemented")
|
||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
enc, err := tokenizerForModel(req.Model)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := countOpenAIChatTokens(enc, body)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes OAuth tokens and updates the stored API key.
|
||||
|
||||
@@ -219,7 +219,29 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented")
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
modelForCounting := req.Model
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
modelForCounting = modelOverride
|
||||
}
|
||||
|
||||
enc, err := tokenizerForModel(modelForCounting)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := countOpenAIChatTokens(enc, translated)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translatedUsage)}, nil
|
||||
}
|
||||
|
||||
// Refresh is a no-op for API-key based compatibility providers.
|
||||
|
||||
@@ -207,7 +207,28 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented")
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
modelName = req.Model
|
||||
}
|
||||
|
||||
enc, err := tokenizerForModel(modelName)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := countOpenAIChatTokens(enc, body)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
|
||||
internal/runtime/executor/token_helpers.go (new file, 234 lines)
@@ -0,0 +1,234 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
switch {
|
||||
case sanitized == "":
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4)
|
||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
case strings.HasPrefix(sanitized, "o1"):
|
||||
return tokenizer.ForModel(tokenizer.O1)
|
||||
case strings.HasPrefix(sanitized, "o3"):
|
||||
return tokenizer.ForModel(tokenizer.O3)
|
||||
case strings.HasPrefix(sanitized, "o4"):
|
||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
||||
default:
|
||||
return tokenizer.Get(tokenizer.O200kBase)
|
||||
}
|
||||
}
|
||||
|
||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
segments := make([]string, 0, 32)
|
||||
|
||||
collectOpenAIMessages(root.Get("messages"), &segments)
|
||||
collectOpenAITools(root.Get("tools"), &segments)
|
||||
collectOpenAIFunctions(root.Get("functions"), &segments)
|
||||
collectOpenAIToolChoice(root.Get("tool_choice"), &segments)
|
||||
collectOpenAIResponseFormat(root.Get("response_format"), &segments)
|
||||
addIfNotEmpty(&segments, root.Get("input").String())
|
||||
addIfNotEmpty(&segments, root.Get("prompt").String())
|
||||
|
||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
if joined == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
func buildOpenAIUsageJSON(count int64) []byte {
|
||||
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
|
||||
}
|
||||
|
||||
func collectOpenAIMessages(messages gjson.Result, segments *[]string) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
addIfNotEmpty(segments, message.Get("name").String())
|
||||
collectOpenAIContent(message.Get("content"), segments)
|
||||
collectOpenAIToolCalls(message.Get("tool_calls"), segments)
|
||||
collectOpenAIFunctionCall(message.Get("function_call"), segments)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func collectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text", "input_text", "output_text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image_url":
|
||||
addIfNotEmpty(segments, part.Get("image_url.url").String())
|
||||
case "input_audio", "output_audio", "audio":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
collectOpenAIContent(part.Get("content"), segments)
|
||||
default:
|
||||
if part.IsArray() {
|
||||
collectOpenAIContent(part, segments)
|
||||
return true
|
||||
}
|
||||
if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
return true
|
||||
}
|
||||
addIfNotEmpty(segments, part.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, content.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
|
||||
if !calls.Exists() || !calls.IsArray() {
|
||||
return
|
||||
}
|
||||
calls.ForEach(func(_, call gjson.Result) bool {
|
||||
addIfNotEmpty(segments, call.Get("id").String())
|
||||
addIfNotEmpty(segments, call.Get("type").String())
|
||||
function := call.Get("function")
|
||||
if function.Exists() {
|
||||
addIfNotEmpty(segments, function.Get("name").String())
|
||||
addIfNotEmpty(segments, function.Get("description").String())
|
||||
addIfNotEmpty(segments, function.Get("arguments").String())
|
||||
if params := function.Get("parameters"); params.Exists() {
|
||||
addIfNotEmpty(segments, params.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func collectOpenAIFunctionCall(call gjson.Result, segments *[]string) {
|
||||
if !call.Exists() {
|
||||
return
|
||||
}
|
||||
addIfNotEmpty(segments, call.Get("name").String())
|
||||
addIfNotEmpty(segments, call.Get("arguments").String())
|
||||
}
|
||||
|
||||
func collectOpenAITools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() {
|
||||
return
|
||||
}
|
||||
if tools.IsArray() {
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
appendToolPayload(tool, segments)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
appendToolPayload(tools, segments)
|
||||
}
|
||||
|
||||
func collectOpenAIFunctions(functions gjson.Result, segments *[]string) {
|
||||
if !functions.Exists() || !functions.IsArray() {
|
||||
return
|
||||
}
|
||||
functions.ForEach(func(_, function gjson.Result) bool {
|
||||
addIfNotEmpty(segments, function.Get("name").String())
|
||||
addIfNotEmpty(segments, function.Get("description").String())
|
||||
if params := function.Get("parameters"); params.Exists() {
|
||||
addIfNotEmpty(segments, params.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func collectOpenAIToolChoice(choice gjson.Result, segments *[]string) {
|
||||
if !choice.Exists() {
|
||||
return
|
||||
}
|
||||
if choice.Type == gjson.String {
|
||||
addIfNotEmpty(segments, choice.String())
|
||||
return
|
||||
}
|
||||
addIfNotEmpty(segments, choice.Raw)
|
||||
}
|
||||
|
||||
func collectOpenAIResponseFormat(format gjson.Result, segments *[]string) {
|
||||
if !format.Exists() {
|
||||
return
|
||||
}
|
||||
addIfNotEmpty(segments, format.Get("type").String())
|
||||
addIfNotEmpty(segments, format.Get("name").String())
|
||||
if schema := format.Get("json_schema"); schema.Exists() {
|
||||
addIfNotEmpty(segments, schema.Raw)
|
||||
}
|
||||
if schema := format.Get("schema"); schema.Exists() {
|
||||
addIfNotEmpty(segments, schema.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func appendToolPayload(tool gjson.Result, segments *[]string) {
|
||||
if !tool.Exists() {
|
||||
return
|
||||
}
|
||||
addIfNotEmpty(segments, tool.Get("type").String())
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if function := tool.Get("function"); function.Exists() {
|
||||
addIfNotEmpty(segments, function.Get("name").String())
|
||||
addIfNotEmpty(segments, function.Get("description").String())
|
||||
if params := function.Get("parameters"); params.Exists() {
|
||||
addIfNotEmpty(segments, params.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addIfNotEmpty(segments *[]string, value string) {
|
||||
if segments == nil {
|
||||
return
|
||||
}
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
*segments = append(*segments, trimmed)
|
||||
}
|
||||
}
|
||||
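The helpers in this new file are shared by the iFlow, Qwen, and OpenAI-compatible executors. A hypothetical sketch (same package, not repository code) of how they combine into a synthesized usage payload when the upstream exposes no countTokens endpoint:

```go
// approxChatUsage is a hypothetical sketch wiring the helpers above together the
// way the CountTokens implementations do.
func approxChatUsage(model string, chatPayload []byte) ([]byte, error) {
	enc, err := tokenizerForModel(model)
	if err != nil {
		return nil, err
	}
	count, err := countOpenAIChatTokens(enc, chatPayload)
	if err != nil {
		return nil, err
	}
	// Yields e.g. {"usage":{"prompt_tokens":42,"completion_tokens":0,"total_tokens":42}}.
	return buildOpenAIUsageJSON(count), nil
}
```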
@@ -354,3 +354,7 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
|
||||
}
|
||||
return rev
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||
}
|
||||
|
||||
@@ -12,8 +12,9 @@ func init() {
|
||||
Codex,
|
||||
ConvertClaudeRequestToCodex,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertCodexResponseToClaude,
|
||||
NonStream: ConvertCodexResponseToClaudeNonStream,
|
||||
Stream: ConvertCodexResponseToClaude,
|
||||
NonStream: ConvertCodexResponseToClaudeNonStream,
|
||||
TokenCount: ClaudeTokenCount,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package geminiCLI
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -54,3 +55,7 @@ func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName str
|
||||
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
|
||||
return strJSON
|
||||
}
|
||||
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
}
|
||||
|
||||
@@ -12,8 +12,9 @@ func init() {
|
||||
Codex,
|
||||
ConvertGeminiCLIRequestToCodex,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertCodexResponseToGeminiCLI,
|
||||
NonStream: ConvertCodexResponseToGeminiCLINonStream,
|
||||
Stream: ConvertCodexResponseToGeminiCLI,
|
||||
NonStream: ConvertCodexResponseToGeminiCLINonStream,
|
||||
TokenCount: GeminiCLITokenCount,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -330,3 +331,7 @@ func mustMarshalJSON(v interface{}) string {
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
}
|
||||
|
||||
@@ -12,8 +12,9 @@ func init() {
|
||||
Codex,
|
||||
ConvertGeminiRequestToCodex,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertCodexResponseToGemini,
|
||||
NonStream: ConvertCodexResponseToGeminiNonStream,
|
||||
Stream: ConvertCodexResponseToGemini,
|
||||
NonStream: ConvertCodexResponseToGeminiNonStream,
|
||||
TokenCount: GeminiTokenCount,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
}
|
||||
|
||||
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
||||
// e.g. "modalities": ["image", "text"] -> ["Image", "Text"]
|
||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
||||
var responseMods []string
|
||||
for _, m := range mods.Array() {
|
||||
switch strings.ToLower(m.String()) {
|
||||
case "text":
|
||||
responseMods = append(responseMods, "Text")
|
||||
responseMods = append(responseMods, "TEXT")
|
||||
case "image":
|
||||
responseMods = append(responseMods, "Image")
|
||||
responseMods = append(responseMods, "IMAGE")
|
||||
}
|
||||
}
|
||||
if len(responseMods) > 0 {
|
||||
@@ -250,8 +250,34 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
if t.Get("type").String() == "function" {
|
||||
fn := t.Get("function")
|
||||
if fn.Exists() && fn.IsObject() {
|
||||
parametersJsonSchema, _ := util.RenameKey(fn.Raw, "parameters", "parametersJsonSchema")
|
||||
out, _ = sjson.SetRawBytes(out, fdPath+".-1", []byte(parametersJsonSchema))
|
||||
fnRaw := fn.Raw
|
||||
if fn.Get("parameters").Exists() {
|
||||
renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
|
||||
if errRename != nil {
|
||||
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
|
||||
} else {
|
||||
fnRaw = renamed
|
||||
}
|
||||
} else {
|
||||
var errSet error
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
tmp, errSet := sjson.SetRawBytes(out, fdPath+".-1", []byte(fnRaw))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
out = tmp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
|
||||
// Map OpenAI modalities -> Gemini generationConfig.responseModalities
|
||||
// e.g. "modalities": ["image", "text"] -> ["Image", "Text"]
|
||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
||||
var responseMods []string
|
||||
for _, m := range mods.Array() {
|
||||
switch strings.ToLower(m.String()) {
|
||||
case "text":
|
||||
responseMods = append(responseMods, "Text")
|
||||
responseMods = append(responseMods, "TEXT")
|
||||
case "image":
|
||||
responseMods = append(responseMods, "Image")
|
||||
responseMods = append(responseMods, "IMAGE")
|
||||
}
|
||||
}
|
||||
if len(responseMods) > 0 {
|
||||
|
||||
@@ -12,8 +12,9 @@ func init() {
|
||||
OpenAI,
|
||||
ConvertClaudeRequestToOpenAI,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertOpenAIResponseToClaude,
|
||||
NonStream: ConvertOpenAIResponseToClaudeNonStream,
|
||||
Stream: ConvertOpenAIResponseToClaude,
|
||||
NonStream: ConvertOpenAIResponseToClaudeNonStream,
|
||||
TokenCount: ClaudeTokenCount,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -79,7 +78,9 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
if system.IsArray() {
|
||||
systemResults := system.Array()
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", systemResults[i].Raw)
|
||||
if contentItem, ok := convertClaudeContentPart(systemResults[i]); ok {
|
||||
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", contentItem)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -94,29 +95,16 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
// Handle content
|
||||
if contentResult.Exists() && contentResult.IsArray() {
|
||||
var textParts []string
|
||||
var contentItems []string
|
||||
var toolCalls []interface{}
|
||||
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "text":
|
||||
textParts = append(textParts, part.Get("text").String())
|
||||
|
||||
case "image":
|
||||
// Convert Anthropic image format to OpenAI format
|
||||
if source := part.Get("source"); source.Exists() {
|
||||
sourceType := source.Get("type").String()
|
||||
if sourceType == "base64" {
|
||||
mediaType := source.Get("media_type").String()
|
||||
data := source.Get("data").String()
|
||||
imageURL := "data:" + mediaType + ";base64," + data
|
||||
|
||||
// For now, add as text since OpenAI image handling is complex
|
||||
// In a real implementation, you'd need to handle this properly
|
||||
textParts = append(textParts, "[Image: "+imageURL+"]")
|
||||
}
|
||||
case "text", "image":
|
||||
if contentItem, ok := convertClaudeContentPart(part); ok {
|
||||
contentItems = append(contentItems, contentItem)
|
||||
}
|
||||
|
||||
case "tool_use":
|
||||
@@ -149,13 +137,17 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
})
|
||||
|
||||
// Create main message if there's text content or tool calls
|
||||
if len(textParts) > 0 || len(toolCalls) > 0 {
|
||||
if len(contentItems) > 0 || len(toolCalls) > 0 {
|
||||
msgJSON := `{"role":"","content":""}`
|
||||
msgJSON, _ = sjson.Set(msgJSON, "role", role)
|
||||
|
||||
// Set content
|
||||
if len(textParts) > 0 {
|
||||
msgJSON, _ = sjson.Set(msgJSON, "content", strings.Join(textParts, ""))
|
||||
if len(contentItems) > 0 {
|
||||
contentArrayJSON := "[]"
|
||||
for _, contentItem := range contentItems {
|
||||
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
|
||||
}
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
|
||||
} else {
|
||||
msgJSON, _ = sjson.Set(msgJSON, "content", "")
|
||||
}
|
||||
@@ -166,7 +158,20 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "tool_calls", string(toolCallsJSON))
|
||||
}
|
||||
|
||||
if gjson.Get(msgJSON, "content").String() != "" || len(toolCalls) != 0 {
|
||||
contentValue := gjson.Get(msgJSON, "content")
|
||||
hasContent := false
|
||||
switch {
|
||||
case !contentValue.Exists():
|
||||
hasContent = false
|
||||
case contentValue.Type == gjson.String:
|
||||
hasContent = contentValue.String() != ""
|
||||
case contentValue.IsArray():
|
||||
hasContent = len(contentValue.Array()) > 0
|
||||
default:
|
||||
hasContent = contentValue.Raw != "" && contentValue.Raw != "null"
|
||||
}
|
||||
|
||||
if hasContent || len(toolCalls) != 0 {
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
|
||||
}
|
||||
}
|
||||
@@ -237,3 +242,53 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
func convertClaudeContentPart(part gjson.Result) (string, bool) {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "text":
|
||||
if !part.Get("text").Exists() {
|
||||
return "", false
|
||||
}
|
||||
textContent := `{"type":"text","text":""}`
|
||||
textContent, _ = sjson.Set(textContent, "text", part.Get("text").String())
|
||||
return textContent, true
|
||||
|
||||
case "image":
|
||||
var imageURL string
|
||||
|
||||
if source := part.Get("source"); source.Exists() {
|
||||
sourceType := source.Get("type").String()
|
||||
switch sourceType {
|
||||
case "base64":
|
||||
mediaType := source.Get("media_type").String()
|
||||
if mediaType == "" {
|
||||
mediaType = "application/octet-stream"
|
||||
}
|
||||
data := source.Get("data").String()
|
||||
if data != "" {
|
||||
imageURL = "data:" + mediaType + ";base64," + data
|
||||
}
|
||||
case "url":
|
||||
imageURL = source.Get("url").String()
|
||||
}
|
||||
}
|
||||
|
||||
if imageURL == "" {
|
||||
imageURL = part.Get("url").String()
|
||||
}
|
||||
|
||||
if imageURL == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
imageContent := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imageContent, _ = sjson.Set(imageContent, "image_url.url", imageURL)
|
||||
|
||||
return imageContent, true
|
||||
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
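For reference, a hypothetical sketch (not repository code) of the two shapes `convertClaudeContentPart` produces; the base64 data is a stand-in value:

```go
// demoConvertClaudeContentParts is a hypothetical sketch showing the converted
// part shapes; it is not part of the repository.
func demoConvertClaudeContentParts() []string {
	var out []string
	textPart := gjson.Parse(`{"type":"text","text":"hello"}`)
	if converted, ok := convertClaudeContentPart(textPart); ok {
		out = append(out, converted) // {"type":"text","text":"hello"}
	}
	imagePart := gjson.Parse(`{"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBORw0KGgo="}}`)
	if converted, ok := convertClaudeContentPart(imagePart); ok {
		out = append(out, converted) // {"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgo="}}
	}
	return out
}
```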
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -630,3 +631,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
return string(responseJSON)
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||
}
|
||||
|
||||
@@ -12,8 +12,9 @@ func init() {
|
||||
OpenAI,
|
||||
ConvertGeminiCLIRequestToOpenAI,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertOpenAIResponseToGeminiCLI,
|
||||
NonStream: ConvertOpenAIResponseToGeminiCLINonStream,
|
||||
Stream: ConvertOpenAIResponseToGeminiCLI,
|
||||
NonStream: ConvertOpenAIResponseToGeminiCLINonStream,
|
||||
TokenCount: GeminiCLITokenCount,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ package geminiCLI
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -51,3 +52,7 @@ func ConvertOpenAIResponseToGeminiCLINonStream(ctx context.Context, modelName st
|
||||
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
|
||||
return strJSON
|
||||
}
|
||||
|
||||
func GeminiCLITokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
}
|
||||
|
||||
@@ -12,8 +12,9 @@ func init() {
|
||||
OpenAI,
|
||||
ConvertGeminiRequestToOpenAI,
|
||||
interfaces.TranslateResponse{
|
||||
Stream: ConvertOpenAIResponseToGemini,
|
||||
NonStream: ConvertOpenAIResponseToGeminiNonStream,
|
||||
Stream: ConvertOpenAIResponseToGemini,
|
||||
NonStream: ConvertOpenAIResponseToGeminiNonStream,
|
||||
TokenCount: GeminiTokenCount,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -609,3 +610,7 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
|
||||
@@ -17,5 +18,14 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
return bytes.Clone(inputRawJSON)
|
||||
// Update the "model" field in the JSON payload with the provided modelName
|
||||
// The sjson.SetBytes function returns a new byte slice with the updated JSON.
|
||||
updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
|
||||
if err != nil {
|
||||
// If there's an error, return the original JSON or handle the error appropriately.
|
||||
// For now, we'll return the original, but in a real scenario, logging or a more robust error
|
||||
// handling mechanism would be needed.
|
||||
return bytes.Clone(inputRawJSON)
|
||||
}
|
||||
return updatedJSON
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
package util

import (
	"net/url"
	"strings"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -188,3 +189,56 @@ func MaskSensitiveHeaderValue(key, value string) string {
		return value
	}
}

// MaskSensitiveQuery masks sensitive query parameters, e.g. auth_token, within the raw query string.
func MaskSensitiveQuery(raw string) string {
	if raw == "" {
		return ""
	}
	parts := strings.Split(raw, "&")
	changed := false
	for i, part := range parts {
		if part == "" {
			continue
		}
		keyPart := part
		valuePart := ""
		if idx := strings.Index(part, "="); idx >= 0 {
			keyPart = part[:idx]
			valuePart = part[idx+1:]
		}
		decodedKey, err := url.QueryUnescape(keyPart)
		if err != nil {
			decodedKey = keyPart
		}
		if !shouldMaskQueryParam(decodedKey) {
			continue
		}
		decodedValue, err := url.QueryUnescape(valuePart)
		if err != nil {
			decodedValue = valuePart
		}
		masked := HideAPIKey(strings.TrimSpace(decodedValue))
		parts[i] = keyPart + "=" + url.QueryEscape(masked)
		changed = true
	}
	if !changed {
		return raw
	}
	return strings.Join(parts, "&")
}

func shouldMaskQueryParam(key string) bool {
	key = strings.ToLower(strings.TrimSpace(key))
	if key == "" {
		return false
	}
	key = strings.TrimSuffix(key, "[]")
	if key == "key" || strings.Contains(key, "api-key") || strings.Contains(key, "apikey") || strings.Contains(key, "api_key") {
		return true
	}
	if strings.Contains(key, "token") || strings.Contains(key, "secret") {
		return true
	}
	return false
}

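MaskSensitiveQuery walks the raw query string parameter by parameter, leaves non-sensitive keys untouched, and replaces matching values with the masked form from HideAPIKey before logging. A rough standalone sketch of the same idea (HideAPIKey is not shown in this diff, so a simple truncating mask stands in for it here):

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// hideAPIKey is a stand-in for the project's HideAPIKey helper.
func hideAPIKey(v string) string {
	if len(v) <= 6 {
		return "***"
	}
	return v[:3] + "..." + v[len(v)-2:]
}

func maskSensitiveQuery(raw string) string {
	parts := strings.Split(raw, "&")
	for i, part := range parts {
		key, value, found := strings.Cut(part, "=")
		if !found {
			continue
		}
		k := strings.ToLower(key)
		sensitive := k == "key" || strings.Contains(k, "token") || strings.Contains(k, "secret") ||
			strings.Contains(k, "apikey") || strings.Contains(k, "api_key") || strings.Contains(k, "api-key")
		if !sensitive {
			continue
		}
		decoded, err := url.QueryUnescape(value)
		if err != nil {
			decoded = value
		}
		parts[i] = key + "=" + url.QueryEscape(hideAPIKey(decoded))
	}
	return strings.Join(parts, "&")
}

func main() {
	fmt.Println(maskSensitiveQuery("alt=sse&key=AIzaSyExampleExample123&page=2"))
	// -> alt=sse&key=AIz...23&page=2
}
```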
@@ -423,6 +423,19 @@ func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str
	return hex.EncodeToString(sum[:])
}

// computeClaudeModelsHash returns a stable hash for Claude model aliases.
func computeClaudeModelsHash(models []config.ClaudeModel) string {
	if len(models) == 0 {
		return ""
	}
	data, err := json.Marshal(models)
	if err != nil || len(data) == 0 {
		return ""
	}
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}

// SetClients sets the file-based clients.
// SetClients removed
// SetAPIKeyClients removed
@@ -760,13 +773,17 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
		if key == "" {
			continue
		}
		id, token := idGen.next("claude:apikey", key, ck.BaseURL)
		base := strings.TrimSpace(ck.BaseURL)
		id, token := idGen.next("claude:apikey", key, base)
		attrs := map[string]string{
			"source": fmt.Sprintf("config:claude[%s]", token),
			"api_key": key,
		}
		if ck.BaseURL != "" {
			attrs["base_url"] = ck.BaseURL
		if base != "" {
			attrs["base_url"] = base
		}
		if hash := computeClaudeModelsHash(ck.Models); hash != "" {
			attrs["models_hash"] = hash
		}
		proxyURL := strings.TrimSpace(ck.ProxyURL)
		a := &coreauth.Auth{
@@ -1204,6 +1221,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
	if oldCfg.ProxyURL != newCfg.ProxyURL {
		changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
	}
	if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
		changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
	}

	// Quota-exceeded behavior
	if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {

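computeClaudeModelsHash gives the watcher a cheap way to detect when a Claude key's model list changes: the slice is JSON-encoded and hashed, and the hex digest is stored as a models_hash attribute, so a config reload only has to touch auths whose attributes differ. A standalone sketch of the same hashing (the ClaudeModel struct here is a simplified stand-in for the config type):

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
)

// claudeModel stands in for config.ClaudeModel (name + alias).
type claudeModel struct {
	Name  string `json:"name"`
	Alias string `json:"alias"`
}

func modelsHash(models []claudeModel) string {
	if len(models) == 0 {
		return ""
	}
	data, err := json.Marshal(models)
	if err != nil {
		return ""
	}
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}

func main() {
	a := []claudeModel{{Name: "claude-3-5-sonnet-20241022", Alias: "claude-sonnet-latest"}}
	b := []claudeModel{{Name: "claude-3-5-sonnet-20241022", Alias: "sonnet"}}
	fmt.Println(modelsHash(a) == modelsHash(b)) // false: changing an alias changes the hash
}
```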
internal/wsrelay/http.go (new file, 233 lines)
@@ -0,0 +1,233 @@
|
||||
package wsrelay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// HTTPRequest represents a proxied HTTP request delivered to websocket clients.
|
||||
type HTTPRequest struct {
|
||||
Method string
|
||||
URL string
|
||||
Headers http.Header
|
||||
Body []byte
|
||||
}
|
||||
|
||||
// HTTPResponse captures the response relayed back from websocket clients.
|
||||
type HTTPResponse struct {
|
||||
Status int
|
||||
Headers http.Header
|
||||
Body []byte
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming response event from clients.
|
||||
type StreamEvent struct {
|
||||
Type string
|
||||
Payload []byte
|
||||
Status int
|
||||
Headers http.Header
|
||||
Err error
|
||||
}
|
||||
|
||||
// NonStream executes a non-streaming HTTP request using the websocket provider.
|
||||
func (m *Manager) NonStream(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("wsrelay: request is nil")
|
||||
}
|
||||
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
||||
respCh, err := m.Send(ctx, provider, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var (
|
||||
streamMode bool
|
||||
streamResp *HTTPResponse
|
||||
streamBody bytes.Buffer
|
||||
)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case msg, ok := <-respCh:
|
||||
if !ok {
|
||||
if streamMode {
|
||||
if streamResp == nil {
|
||||
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||
} else if streamResp.Headers == nil {
|
||||
streamResp.Headers = make(http.Header)
|
||||
}
|
||||
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
|
||||
return streamResp, nil
|
||||
}
|
||||
return nil, errors.New("wsrelay: connection closed during response")
|
||||
}
|
||||
switch msg.Type {
|
||||
case MessageTypeHTTPResp:
|
||||
resp := decodeResponse(msg.Payload)
|
||||
if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 {
|
||||
resp.Body = append(resp.Body[:0], streamBody.Bytes()...)
|
||||
}
|
||||
return resp, nil
|
||||
case MessageTypeError:
|
||||
return nil, decodeError(msg.Payload)
|
||||
case MessageTypeStreamStart, MessageTypeStreamChunk:
|
||||
if msg.Type == MessageTypeStreamStart {
|
||||
streamMode = true
|
||||
streamResp = decodeResponse(msg.Payload)
|
||||
if streamResp.Headers == nil {
|
||||
streamResp.Headers = make(http.Header)
|
||||
}
|
||||
streamBody.Reset()
|
||||
continue
|
||||
}
|
||||
if !streamMode {
|
||||
streamMode = true
|
||||
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||
}
|
||||
chunk := decodeChunk(msg.Payload)
|
||||
if len(chunk) > 0 {
|
||||
streamBody.Write(chunk)
|
||||
}
|
||||
case MessageTypeStreamEnd:
|
||||
if !streamMode {
|
||||
return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil
|
||||
}
|
||||
if streamResp == nil {
|
||||
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||
} else if streamResp.Headers == nil {
|
||||
streamResp.Headers = make(http.Header)
|
||||
}
|
||||
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
|
||||
return streamResp, nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream executes a streaming HTTP request and returns channel with stream events.
|
||||
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("wsrelay: request is nil")
|
||||
}
|
||||
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
||||
respCh, err := m.Send(ctx, provider, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan StreamEvent)
|
||||
go func() {
|
||||
defer close(out)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
out <- StreamEvent{Err: ctx.Err()}
|
||||
return
|
||||
case msg, ok := <-respCh:
|
||||
if !ok {
|
||||
out <- StreamEvent{Err: errors.New("wsrelay: stream closed")}
|
||||
return
|
||||
}
|
||||
switch msg.Type {
|
||||
case MessageTypeStreamStart:
|
||||
resp := decodeResponse(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}
|
||||
case MessageTypeStreamChunk:
|
||||
chunk := decodeChunk(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}
|
||||
case MessageTypeStreamEnd:
|
||||
out <- StreamEvent{Type: MessageTypeStreamEnd}
|
||||
return
|
||||
case MessageTypeError:
|
||||
out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)}
|
||||
return
|
||||
case MessageTypeHTTPResp:
|
||||
resp := decodeResponse(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body}
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func encodeRequest(req *HTTPRequest) map[string]any {
|
||||
headers := make(map[string]any, len(req.Headers))
|
||||
for key, values := range req.Headers {
|
||||
copyValues := make([]string, len(values))
|
||||
copy(copyValues, values)
|
||||
headers[key] = copyValues
|
||||
}
|
||||
return map[string]any{
|
||||
"method": req.Method,
|
||||
"url": req.URL,
|
||||
"headers": headers,
|
||||
"body": string(req.Body),
|
||||
"sent_at": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
}
|
||||
}
|
||||
|
||||
func decodeResponse(payload map[string]any) *HTTPResponse {
|
||||
if payload == nil {
|
||||
return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)}
|
||||
}
|
||||
resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||
if status, ok := payload["status"].(float64); ok {
|
||||
resp.Status = int(status)
|
||||
}
|
||||
if headers, ok := payload["headers"].(map[string]any); ok {
|
||||
for key, raw := range headers {
|
||||
switch v := raw.(type) {
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
resp.Headers.Add(key, str)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, str := range v {
|
||||
resp.Headers.Add(key, str)
|
||||
}
|
||||
case string:
|
||||
resp.Headers.Set(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if body, ok := payload["body"].(string); ok {
|
||||
resp.Body = []byte(body)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func decodeChunk(payload map[string]any) []byte {
|
||||
if payload == nil {
|
||||
return nil
|
||||
}
|
||||
if data, ok := payload["data"].(string); ok {
|
||||
return []byte(data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeError(payload map[string]any) error {
|
||||
if payload == nil {
|
||||
return errors.New("wsrelay: unknown error")
|
||||
}
|
||||
message, _ := payload["error"].(string)
|
||||
status := 0
|
||||
if v, ok := payload["status"].(float64); ok {
|
||||
status = int(v)
|
||||
}
|
||||
if message == "" {
|
||||
message = "wsrelay: upstream error"
|
||||
}
|
||||
return fmt.Errorf("%s (status=%d)", message, status)
|
||||
}
|
||||
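NonStream above also tolerates clients that answer a request with stream_start/stream_chunk/stream_end frames: it buffers the chunks and returns them as a single HTTPResponse body. A rough sketch of how a caller might consume the streaming variant (the wsrelay types are stubbed locally here; in the real code the channel comes from Manager.Stream):

```go
package main

import (
	"fmt"
	"net/http"
)

// streamEvent is a local stand-in for the wsrelay StreamEvent shape above.
type streamEvent struct {
	Type    string
	Payload []byte
	Status  int
	Headers http.Header
	Err     error
}

func consume(events <-chan streamEvent) error {
	for ev := range events {
		switch ev.Type {
		case "stream_start":
			fmt.Println("status:", ev.Status)
		case "stream_chunk":
			fmt.Print(string(ev.Payload))
		case "stream_end":
			return nil
		case "error":
			return ev.Err
		}
	}
	return nil
}

func main() {
	events := make(chan streamEvent, 4)
	events <- streamEvent{Type: "stream_start", Status: 200, Headers: http.Header{}}
	events <- streamEvent{Type: "stream_chunk", Payload: []byte("data: {\"text\":\"hi\"}\n\n")}
	events <- streamEvent{Type: "stream_end"}
	close(events)
	_ = consume(events)
}
```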
internal/wsrelay/manager.go (new file, 205 lines)
@@ -0,0 +1,205 @@
|
||||
package wsrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Manager exposes a websocket endpoint that proxies Gemini requests to
|
||||
// connected clients.
|
||||
type Manager struct {
|
||||
path string
|
||||
upgrader websocket.Upgrader
|
||||
sessions map[string]*session
|
||||
sessMutex sync.RWMutex
|
||||
|
||||
providerFactory func(*http.Request) (string, error)
|
||||
onConnected func(string)
|
||||
onDisconnected func(string, error)
|
||||
|
||||
logDebugf func(string, ...any)
|
||||
logInfof func(string, ...any)
|
||||
logWarnf func(string, ...any)
|
||||
}
|
||||
|
||||
// Options configures a Manager instance.
|
||||
type Options struct {
|
||||
Path string
|
||||
ProviderFactory func(*http.Request) (string, error)
|
||||
OnConnected func(string)
|
||||
OnDisconnected func(string, error)
|
||||
LogDebugf func(string, ...any)
|
||||
LogInfof func(string, ...any)
|
||||
LogWarnf func(string, ...any)
|
||||
}
|
||||
|
||||
// NewManager builds a websocket relay manager with the supplied options.
|
||||
func NewManager(opts Options) *Manager {
|
||||
path := strings.TrimSpace(opts.Path)
|
||||
if path == "" {
|
||||
path = "/v1/ws"
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
mgr := &Manager{
|
||||
path: path,
|
||||
sessions: make(map[string]*session),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
providerFactory: opts.ProviderFactory,
|
||||
onConnected: opts.OnConnected,
|
||||
onDisconnected: opts.OnDisconnected,
|
||||
logDebugf: opts.LogDebugf,
|
||||
logInfof: opts.LogInfof,
|
||||
logWarnf: opts.LogWarnf,
|
||||
}
|
||||
if mgr.logDebugf == nil {
|
||||
mgr.logDebugf = func(string, ...any) {}
|
||||
}
|
||||
if mgr.logInfof == nil {
|
||||
mgr.logInfof = func(string, ...any) {}
|
||||
}
|
||||
if mgr.logWarnf == nil {
|
||||
mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) }
|
||||
}
|
||||
return mgr
|
||||
}
|
||||
|
||||
// Path returns the HTTP path the manager expects for websocket upgrades.
|
||||
func (m *Manager) Path() string {
|
||||
if m == nil {
|
||||
return "/v1/ws"
|
||||
}
|
||||
return m.path
|
||||
}
|
||||
|
||||
// Handler exposes an http.Handler that upgrades connections to websocket sessions.
|
||||
func (m *Manager) Handler() http.Handler {
|
||||
return http.HandlerFunc(m.handleWebsocket)
|
||||
}
|
||||
|
||||
// Stop gracefully closes all active websocket sessions.
|
||||
func (m *Manager) Stop(_ context.Context) error {
|
||||
m.sessMutex.Lock()
|
||||
sessions := make([]*session, 0, len(m.sessions))
|
||||
for _, sess := range m.sessions {
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
m.sessions = make(map[string]*session)
|
||||
m.sessMutex.Unlock()
|
||||
|
||||
for _, sess := range sessions {
|
||||
if sess != nil {
|
||||
sess.cleanup(errors.New("wsrelay: manager stopped"))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleWebsocket upgrades the connection and wires the session into the pool.
|
||||
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
expectedPath := m.Path()
|
||||
if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(r.Method, http.MethodGet) {
|
||||
w.Header().Set("Allow", http.MethodGet)
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
conn, err := m.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
m.logWarnf("wsrelay: upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
s := newSession(conn, m, randomProviderName())
|
||||
if m.providerFactory != nil {
|
||||
name, err := m.providerFactory(r)
|
||||
if err != nil {
|
||||
s.cleanup(err)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(name) != "" {
|
||||
s.provider = strings.ToLower(name)
|
||||
}
|
||||
}
|
||||
if s.provider == "" {
|
||||
s.provider = strings.ToLower(s.id)
|
||||
}
|
||||
m.sessMutex.Lock()
|
||||
var replaced *session
|
||||
if existing, ok := m.sessions[s.provider]; ok {
|
||||
replaced = existing
|
||||
}
|
||||
m.sessions[s.provider] = s
|
||||
m.sessMutex.Unlock()
|
||||
|
||||
if replaced != nil {
|
||||
replaced.cleanup(errors.New("replaced by new connection"))
|
||||
}
|
||||
if m.onConnected != nil {
|
||||
m.onConnected(s.provider)
|
||||
}
|
||||
|
||||
go s.run(context.Background())
|
||||
}
|
||||
|
||||
// Send forwards the message to the specific provider connection and returns a channel
|
||||
// yielding response messages.
|
||||
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
|
||||
s := m.session(provider)
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
|
||||
}
|
||||
return s.request(ctx, msg)
|
||||
}
|
||||
|
||||
func (m *Manager) session(provider string) *session {
|
||||
key := strings.ToLower(strings.TrimSpace(provider))
|
||||
m.sessMutex.RLock()
|
||||
s := m.sessions[key]
|
||||
m.sessMutex.RUnlock()
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *Manager) handleSessionClosed(s *session, cause error) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(s.provider))
|
||||
m.sessMutex.Lock()
|
||||
if cur, ok := m.sessions[key]; ok && cur == s {
|
||||
delete(m.sessions, key)
|
||||
}
|
||||
m.sessMutex.Unlock()
|
||||
if m.onDisconnected != nil {
|
||||
m.onDisconnected(s.provider, cause)
|
||||
}
|
||||
}
|
||||
|
||||
func randomProviderName() string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
|
||||
}
|
||||
for i := range buf {
|
||||
buf[i] = alphabet[int(buf[i])%len(alphabet)]
|
||||
}
|
||||
return "aistudio-" + string(buf)
|
||||
}
|
||||
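The manager registers each upgraded connection under a provider name (supplied by ProviderFactory, or a random aistudio-* identifier) and replaces any existing session with the same name. From the client side, a relay worker simply dials the /v1/ws path and starts answering envelopes; a minimal sketch using gorilla/websocket (the address is a hypothetical local deployment, and a real client must also handle http_request frames):

```go
package main

import (
	"log"

	"github.com/gorilla/websocket"
)

// message mirrors the wsrelay Message envelope.
type message struct {
	ID      string         `json:"id"`
	Type    string         `json:"type"`
	Payload map[string]any `json:"payload,omitempty"`
}

func main() {
	// Hypothetical local proxy address; adjust to the real deployment.
	conn, _, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:8317/v1/ws", nil)
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	// Application-level heartbeat: session.dispatch answers a "ping" envelope
	// with a "pong" carrying the same id.
	if err := conn.WriteJSON(message{ID: "hb-1", Type: "ping"}); err != nil {
		log.Fatalf("write: %v", err)
	}
	for {
		var msg message
		if err := conn.ReadJSON(&msg); err != nil {
			log.Fatalf("read: %v", err)
		}
		log.Printf("received %s envelope id=%s", msg.Type, msg.ID)
		// A real relay client would handle "http_request" envelopes here and
		// reply with http_response or stream_start/stream_chunk/stream_end frames.
	}
}
```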
internal/wsrelay/message.go (new file, 27 lines)
@@ -0,0 +1,27 @@
package wsrelay

// Message represents the JSON payload exchanged with websocket clients.
type Message struct {
	ID      string         `json:"id"`
	Type    string         `json:"type"`
	Payload map[string]any `json:"payload,omitempty"`
}

const (
	// MessageTypeHTTPReq identifies an HTTP-style request envelope.
	MessageTypeHTTPReq = "http_request"
	// MessageTypeHTTPResp identifies a non-streaming HTTP response envelope.
	MessageTypeHTTPResp = "http_response"
	// MessageTypeStreamStart marks the beginning of a streaming response.
	MessageTypeStreamStart = "stream_start"
	// MessageTypeStreamChunk carries a streaming response chunk.
	MessageTypeStreamChunk = "stream_chunk"
	// MessageTypeStreamEnd marks the completion of a streaming response.
	MessageTypeStreamEnd = "stream_end"
	// MessageTypeError carries an error response.
	MessageTypeError = "error"
	// MessageTypePing represents ping messages from clients.
	MessageTypePing = "ping"
	// MessageTypePong represents pong responses back to clients.
	MessageTypePong = "pong"
)
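Putting message.go together with the encode/decode helpers in http.go: a proxied call travels as an http_request envelope and comes back as an http_response (or a stream_start/chunk/end sequence). A sketch of roughly what one request envelope looks like on the wire (the id, URL, and body values are invented for illustration):

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Mirrors the map built by encodeRequest in http.go.
	envelope := map[string]any{
		"id":   "5f0c2c9e-0000-4000-8000-000000000000",
		"type": "http_request",
		"payload": map[string]any{
			"method":  "POST",
			"url":     "https://example.upstream/v1beta/models/gemini-2.0-flash:generateContent",
			"headers": map[string][]string{"Content-Type": {"application/json"}},
			"body":    `{"contents":[{"parts":[{"text":"hello"}]}]}`,
			"sent_at": "2025-01-01T00:00:00Z",
		},
	}
	out, _ := json.MarshalIndent(envelope, "", "  ")
	fmt.Println(string(out))
}
```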
internal/wsrelay/session.go (new file, 188 lines)
@@ -0,0 +1,188 @@
|
||||
package wsrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
readTimeout = 60 * time.Second
|
||||
writeTimeout = 10 * time.Second
|
||||
maxInboundMessageLen = 64 << 20 // 64 MiB
|
||||
heartbeatInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
var errClosed = errors.New("websocket session closed")
|
||||
|
||||
type pendingRequest struct {
|
||||
ch chan Message
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (pr *pendingRequest) close() {
|
||||
if pr == nil {
|
||||
return
|
||||
}
|
||||
pr.closeOnce.Do(func() {
|
||||
close(pr.ch)
|
||||
})
|
||||
}
|
||||
|
||||
type session struct {
|
||||
conn *websocket.Conn
|
||||
manager *Manager
|
||||
provider string
|
||||
id string
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
writeMutex sync.Mutex
|
||||
pending sync.Map // map[string]*pendingRequest
|
||||
}
|
||||
|
||||
func newSession(conn *websocket.Conn, mgr *Manager, id string) *session {
|
||||
s := &session{
|
||||
conn: conn,
|
||||
manager: mgr,
|
||||
provider: "",
|
||||
id: id,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
conn.SetReadLimit(maxInboundMessageLen)
|
||||
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
return nil
|
||||
})
|
||||
s.startHeartbeat()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) startHeartbeat() {
|
||||
if s == nil || s.conn == nil {
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(heartbeatInterval)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.writeMutex.Lock()
|
||||
err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout))
|
||||
s.writeMutex.Unlock()
|
||||
if err != nil {
|
||||
s.cleanup(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *session) run(ctx context.Context) {
|
||||
defer s.cleanup(errClosed)
|
||||
for {
|
||||
var msg Message
|
||||
if err := s.conn.ReadJSON(&msg); err != nil {
|
||||
s.cleanup(err)
|
||||
return
|
||||
}
|
||||
s.dispatch(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) dispatch(msg Message) {
|
||||
if msg.Type == MessageTypePing {
|
||||
_ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong})
|
||||
return
|
||||
}
|
||||
if value, ok := s.pending.Load(msg.ID); ok {
|
||||
req := value.(*pendingRequest)
|
||||
select {
|
||||
case req.ch <- msg:
|
||||
default:
|
||||
}
|
||||
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
||||
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||
actual.(*pendingRequest).close()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
||||
s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) send(ctx context.Context, msg Message) error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return errClosed
|
||||
default:
|
||||
}
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
if err := s.conn.WriteJSON(msg); err != nil {
|
||||
return fmt.Errorf("write json: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) {
|
||||
if msg.ID == "" {
|
||||
return nil, fmt.Errorf("wsrelay: message id is required")
|
||||
}
|
||||
if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded {
|
||||
return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID)
|
||||
}
|
||||
value, _ := s.pending.Load(msg.ID)
|
||||
req := value.(*pendingRequest)
|
||||
if err := s.send(ctx, msg); err != nil {
|
||||
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||
req := actual.(*pendingRequest)
|
||||
req.close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||
actual.(*pendingRequest).close()
|
||||
}
|
||||
case <-s.closed:
|
||||
}
|
||||
}()
|
||||
return req.ch, nil
|
||||
}
|
||||
|
||||
func (s *session) cleanup(cause error) {
|
||||
s.closeOnce.Do(func() {
|
||||
close(s.closed)
|
||||
s.pending.Range(func(key, value any) bool {
|
||||
req := value.(*pendingRequest)
|
||||
msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}}
|
||||
select {
|
||||
case req.ch <- msg:
|
||||
default:
|
||||
}
|
||||
req.close()
|
||||
return true
|
||||
})
|
||||
s.pending = sync.Map{}
|
||||
_ = s.conn.Close()
|
||||
if s.manager != nil {
|
||||
s.manager.handleSessionClosed(s, cause)
|
||||
}
|
||||
})
|
||||
}
|
||||
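session.request pairs each outbound envelope with a channel stored in a sync.Map keyed by message id; dispatch routes replies to that channel and tears the entry down once a terminal frame (http_response, error, or stream_end) arrives. A condensed standalone sketch of that correlation pattern (simplified: no websocket, a single goroutine supplies the reply):

```go
package main

import (
	"fmt"
	"sync"
)

type reply struct {
	ID   string
	Type string
	Body string
}

type correlator struct {
	pending sync.Map // map[string]chan reply
}

func (c *correlator) request(id string) <-chan reply {
	ch := make(chan reply, 8)
	c.pending.Store(id, ch)
	return ch
}

func (c *correlator) dispatch(r reply) {
	if v, ok := c.pending.Load(r.ID); ok {
		ch := v.(chan reply)
		ch <- r
		// Terminal frame: stop tracking the id and close the channel.
		if r.Type == "http_response" || r.Type == "error" || r.Type == "stream_end" {
			c.pending.Delete(r.ID)
			close(ch)
		}
	}
}

func main() {
	var c correlator
	ch := c.request("req-1")
	go c.dispatch(reply{ID: "req-1", Type: "http_response", Body: `{"ok":true}`})
	for r := range ch {
		fmt.Println(r.Type, r.Body)
	}
}
```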
@@ -6,6 +6,7 @@ package handlers
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
|
||||
|
||||
// Cfg holds the current application configuration.
|
||||
Cfg *config.SDKConfig
|
||||
|
||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||
OpenAICompatProviders []string
|
||||
}
|
||||
|
||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
|
||||
//
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||
return &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
OpenAICompatProviders: openAICompatProviders,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 {
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 {
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 {
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
errChan <- errMsg
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
}
|
||||
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
|
||||
|
||||
// First, normalize the model name to handle suffixes like "-thinking-128"
|
||||
// This needs to happen before determining the provider for non-dynamic models.
|
||||
normalizedModel, metadata = normalizeModelMetadata(modelName)
|
||||
|
||||
if isDynamic {
|
||||
providers = []string{providerName}
|
||||
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
|
||||
// so we use it as the final normalizedModel.
|
||||
normalizedModel = extractedModelName
|
||||
} else {
|
||||
// For non-dynamic models, use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
}
|
||||
|
||||
if len(providers) == 0 {
|
||||
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
}
|
||||
|
||||
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
|
||||
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
|
||||
// So, normalizedModel is already correctly set at this point.
|
||||
|
||||
return providers, normalizedModel, metadata, nil
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
||||
var providerPart, modelPart string
|
||||
for _, sep := range []string{"://"} {
|
||||
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
||||
providerPart = parts[0]
|
||||
modelPart = parts[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if providerPart == "" {
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
// Check if the provider is a configured openai-compatibility provider
|
||||
for _, pName := range h.OpenAICompatProviders {
|
||||
if pName == providerPart {
|
||||
return providerPart, modelPart, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
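parseDynamicModel lets callers pin a request to a specific openai-compatibility provider by prefixing the model with provider:// (for example openrouter://moonshotai/kimi-k2:free); anything else falls through to the normal provider lookup. A standalone sketch of the split (the provider allow-list here is an invented example standing in for h.OpenAICompatProviders):

```go
package main

import (
	"fmt"
	"strings"
)

// parseDynamicModel splits "provider://model" and accepts the provider only if
// it appears in the configured openai-compatibility provider list.
func parseDynamicModel(modelName string, compatProviders []string) (provider, model string, isDynamic bool) {
	providerPart, modelPart, found := strings.Cut(modelName, "://")
	if !found || providerPart == "" {
		return "", modelName, false
	}
	for _, p := range compatProviders {
		if p == providerPart {
			return providerPart, modelPart, true
		}
	}
	return "", modelName, false
}

func main() {
	providers := []string{"openrouter", "iflow"} // example allow-list
	fmt.Println(parseDynamicModel("openrouter://moonshotai/kimi-k2:free", providers))
	// -> openrouter moonshotai/kimi-k2:free true
	fmt.Println(parseDynamicModel("gemini-2.0-flash", providers))
	// -> "" gemini-2.0-flash false
}
```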
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -153,6 +153,17 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
||||
m.executors[executor.Identifier()] = executor
|
||||
}
|
||||
|
||||
// UnregisterExecutor removes the executor associated with the provider key.
|
||||
func (m *Manager) UnregisterExecutor(provider string) {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
delete(m.executors, provider)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// Register inserts a new auth entry into the manager.
|
||||
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if auth == nil {
|
||||
|
||||
@@ -156,7 +156,8 @@ func (a *Auth) AccountInfo() (string, string) {
|
||||
if v, ok := a.Metadata["email"].(string); ok {
|
||||
return "oauth", v
|
||||
}
|
||||
} else if a.Attributes != nil {
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
if v := a.Attributes["api_key"]; v != "" {
|
||||
return "api_key", v
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -82,6 +83,9 @@ type Service struct {
|
||||
|
||||
// shutdownOnce ensures shutdown is called only once.
|
||||
shutdownOnce sync.Once
|
||||
|
||||
// wsGateway manages websocket Gemini providers.
|
||||
wsGateway *wsrelay.Manager
|
||||
}
|
||||
|
||||
// RegisterUsagePlugin registers a usage plugin on the global usage manager.
|
||||
@@ -172,6 +176,69 @@ func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdat
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ensureWebsocketGateway() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if s.wsGateway != nil {
|
||||
return
|
||||
}
|
||||
opts := wsrelay.Options{
|
||||
Path: "/v1/ws",
|
||||
OnConnected: s.wsOnConnected,
|
||||
OnDisconnected: s.wsOnDisconnected,
|
||||
LogDebugf: log.Debugf,
|
||||
LogInfof: log.Infof,
|
||||
LogWarnf: log.Warnf,
|
||||
}
|
||||
s.wsGateway = wsrelay.NewManager(opts)
|
||||
}
|
||||
|
||||
func (s *Service) wsOnConnected(channelID string) {
|
||||
if s == nil || channelID == "" {
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(channelID), "aistudio-") {
|
||||
return
|
||||
}
|
||||
if s.coreManager != nil {
|
||||
if existing, ok := s.coreManager.GetByID(channelID); ok && existing != nil {
|
||||
if !existing.Disabled && existing.Status == coreauth.StatusActive {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
auth := &coreauth.Auth{
|
||||
ID: channelID, // keep channel identifier as ID
|
||||
Provider: "aistudio", // logical provider for switch routing
|
||||
Label: channelID, // display original channel id
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{"email": channelID}, // inject email inline
|
||||
}
|
||||
log.Infof("websocket provider connected: %s", channelID)
|
||||
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
|
||||
}
|
||||
|
||||
func (s *Service) wsOnDisconnected(channelID string, reason error) {
|
||||
if s == nil || channelID == "" {
|
||||
return
|
||||
}
|
||||
if reason != nil {
|
||||
if strings.Contains(reason.Error(), "replaced by new connection") {
|
||||
log.Infof("websocket provider replaced: %s", channelID)
|
||||
return
|
||||
}
|
||||
log.Warnf("websocket provider disconnected: %s (%v)", channelID, reason)
|
||||
} else {
|
||||
log.Infof("websocket provider disconnected: %s", channelID)
|
||||
}
|
||||
ctx := context.Background()
|
||||
s.applyCoreAuthRemoval(ctx, channelID)
|
||||
}
|
||||
|
||||
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
|
||||
if s == nil || auth == nil || auth.ID == "" {
|
||||
return
|
||||
@@ -252,6 +319,11 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
|
||||
case "gemini-cli":
|
||||
s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
|
||||
case "aistudio":
|
||||
if s.wsGateway != nil {
|
||||
s.coreManager.RegisterExecutor(executor.NewAIStudioExecutor(s.cfg, a.ID, s.wsGateway))
|
||||
}
|
||||
return
|
||||
case "claude":
|
||||
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
||||
case "codex":
|
||||
@@ -342,6 +414,27 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.authManager = newDefaultAuthManager()
|
||||
}
|
||||
|
||||
s.ensureWebsocketGateway()
|
||||
if s.server != nil && s.wsGateway != nil {
|
||||
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
|
||||
s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) {
|
||||
if oldEnabled == newEnabled {
|
||||
return
|
||||
}
|
||||
if !oldEnabled && newEnabled {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if errStop := s.wsGateway.Stop(ctx); errStop != nil {
|
||||
log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop)
|
||||
return
|
||||
}
|
||||
log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication")
|
||||
return
|
||||
}
|
||||
log.Debugf("ws-auth disabled; existing websocket sessions remain connected")
|
||||
})
|
||||
}
|
||||
|
||||
if s.hooks.OnBeforeStart != nil {
|
||||
s.hooks.OnBeforeStart(s.cfg)
|
||||
}
|
||||
@@ -379,7 +472,6 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.cfg = newCfg
|
||||
s.cfgMu.Unlock()
|
||||
s.rebindExecutors()
|
||||
|
||||
}
|
||||
|
||||
watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)
|
||||
@@ -449,6 +541,14 @@ func (s *Service) Shutdown(ctx context.Context) error {
|
||||
shutdownErr = err
|
||||
}
|
||||
}
|
||||
if s.wsGateway != nil {
|
||||
if err := s.wsGateway.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to stop websocket gateway: %v", err)
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.authQueueStop != nil {
|
||||
s.authQueueStop()
|
||||
s.authQueueStop = nil
|
||||
@@ -514,8 +614,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
models = registry.GetGeminiModels()
|
||||
case "gemini-cli":
|
||||
models = registry.GetGeminiCLIModels()
|
||||
case "aistudio":
|
||||
models = registry.GetAIStudioModels()
|
||||
case "claude":
|
||||
models = registry.GetClaudeModels()
|
||||
if entry := s.resolveConfigClaudeKey(a); entry != nil && len(entry.Models) > 0 {
|
||||
models = buildClaudeConfigModels(entry)
|
||||
}
|
||||
case "codex":
|
||||
models = registry.GetOpenAIModels()
|
||||
case "qwen":
|
||||
@@ -611,3 +716,80 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
|
||||
if auth == nil || s.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range s.cfg.ClaudeKey {
|
||||
entry := &s.cfg.ClaudeKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if attrBase == "" || cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range s.cfg.ClaudeKey {
|
||||
entry := &s.cfg.ClaudeKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
out := make([]*ModelInfo, 0, len(entry.Models))
|
||||
seen := make(map[string]struct{}, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = name
|
||||
}
|
||||
if alias == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(alias)
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
display := name
|
||||
if display == "" {
|
||||
display = alias
|
||||
}
|
||||
out = append(out, &ModelInfo{
|
||||
ID: alias,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "claude",
|
||||
Type: "claude",
|
||||
DisplayName: display,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
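buildClaudeConfigModels turns the per-key models list from the config into registry entries: the alias (falling back to the name) becomes the exposed model ID, and duplicate aliases are dropped case-insensitively. A standalone sketch of that aliasing step (reduced to the ID/dedup logic; the real function also fills Object, OwnedBy, Type, and DisplayName):

```go
package main

import (
	"fmt"
	"strings"
)

type claudeModel struct{ Name, Alias string }

// buildModels mirrors the aliasing/dedup step: alias (or name) becomes the
// model ID, and duplicate aliases are dropped case-insensitively.
func buildModels(models []claudeModel) []string {
	seen := make(map[string]struct{}, len(models))
	out := make([]string, 0, len(models))
	for _, m := range models {
		alias := strings.TrimSpace(m.Alias)
		if alias == "" {
			alias = strings.TrimSpace(m.Name)
		}
		if alias == "" {
			continue
		}
		key := strings.ToLower(alias)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		out = append(out, alias)
	}
	return out
}

func main() {
	fmt.Println(buildModels([]claudeModel{
		{Name: "claude-3-5-sonnet-20241022", Alias: "claude-sonnet-latest"},
		{Name: "claude-3-5-sonnet-20241022", Alias: "CLAUDE-SONNET-LATEST"}, // duplicate alias, dropped
		{Name: "claude-3-opus-20240229"},                                    // no alias: falls back to name
	}))
}
```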