mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-10 16:30:51 +08:00
Compare commits
72 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0040d78496 | ||
|
|
896de027cc | ||
|
|
fc329ebf37 | ||
|
|
eaab1d6824 | ||
|
|
0cfe310df6 | ||
|
|
918b6955e4 | ||
|
|
5a3eb08739 | ||
|
|
0dff329162 | ||
|
|
49c1740b47 | ||
|
|
3fbee51e9f | ||
|
|
63643c44a1 | ||
|
|
3b34521ad9 | ||
|
|
7197fb350b | ||
|
|
6e349bfcc7 | ||
|
|
234056072d | ||
|
|
7e9d0db6aa | ||
|
|
2f1874ede5 | ||
|
|
78ef04fcf1 | ||
|
|
b7e4f00c5f | ||
|
|
f7d0019df7 | ||
|
|
52364af5bf | ||
|
|
f410dd0440 | ||
|
|
eb5582c17c | ||
|
|
1c6cb2bec3 | ||
|
|
80b5e79e75 | ||
|
|
394497fb2f | ||
|
|
fc7b6ef086 | ||
|
|
1187aa8222 | ||
|
|
dc9b4dd017 | ||
|
|
68cb81a258 | ||
|
|
c874f19f2a | ||
|
|
f5f26f0cbe | ||
|
|
4b00312fef | ||
|
|
c5fd3db01e | ||
|
|
f870a9d2a7 | ||
|
|
b4e034be1c | ||
|
|
a5a25dec57 | ||
|
|
c71905e5e8 | ||
|
|
bc78d668ac | ||
|
|
5bd0896ad7 | ||
|
|
09ecfbcaed | ||
|
|
f0bd14b64f | ||
|
|
f7d82fda3f | ||
|
|
706590c62a | ||
|
|
25c6b479c7 | ||
|
|
7cf9ff0345 | ||
|
|
209d74062a | ||
|
|
d86b13c9cb | ||
|
|
075e3ab69e | ||
|
|
c1c9483752 | ||
|
|
6c65fdf54b | ||
|
|
4874253d1e | ||
|
|
b72250349f | ||
|
|
116573311f | ||
|
|
4af712544d | ||
|
|
3f9c9591bd | ||
|
|
1548c567ab | ||
|
|
5b23fc570c | ||
|
|
04e1c7a05a | ||
|
|
9181e72204 | ||
|
|
4939865f6d | ||
|
|
3da7f7482e | ||
|
|
9072b029b2 | ||
|
|
c296cfb8c0 | ||
|
|
2707377fcb | ||
|
|
259f586ff7 | ||
|
|
d885b81f23 | ||
|
|
fe6bffd080 | ||
|
|
a275db3fdb | ||
|
|
233be6272a | ||
|
|
47cb52385e | ||
|
|
a406ca2d5a |
@@ -27,8 +27,8 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the "cliproxyapi" promo code during recharge to get 10% off.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
|
||||
<td>Thanks to Cubence for sponsoring this project! Cubence is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. Cubence provides special discounts for our software users: register using <a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">this link</a> and enter the "CLIPROXYAPI" promo code during recharge to get 10% off.</td>
|
||||
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
||||
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
@@ -142,6 +142,10 @@ A lightweight web admin panel for CLIProxyAPI with health checks, resource monit
|
||||
|
||||
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
|
||||
|
||||
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
||||
|
||||
霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux systems. Unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with local proxy for multi-account quota tracking and one-click configuration.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
16
README_CN.md
16
README_CN.md
@@ -27,8 +27,8 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元
|
||||
<td>感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用<a href="https://www.packyapi.com/register?aff=cliproxyapi">此链接</a>注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
|
||||
<td>感谢 Cubence 对本项目的赞助!Cubence 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。Cubence 为本软件用户提供了特别优惠:使用<a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">此链接</a>注册,并在充值时输入 "CLIPROXYAPI" 优惠码即可享受九折优惠。</td>
|
||||
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
||||
<td>感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">此链接</a>注册的用户,可享受首充8折,企业客户最高可享 7.5 折!</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
@@ -137,6 +137,14 @@ Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI
|
||||
|
||||
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
|
||||
|
||||
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||
|
||||
Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus)以及自动下载与更新。
|
||||
|
||||
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
||||
|
||||
霖君是一款用于管理AI编程助手的跨平台桌面应用,支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具,本地代理实现多账户配额跟踪和一键配置。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
@@ -148,10 +156,6 @@ Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI
|
||||
|
||||
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
|
||||
|
||||
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||
|
||||
Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus)以及自动下载与更新。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
|
||||
|
||||
|
||||
BIN
assets/aicodemirror.png
Normal file
BIN
assets/aicodemirror.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 51 KiB |
@@ -63,6 +63,7 @@ func main() {
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kimiLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
@@ -78,6 +79,7 @@ func main() {
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
@@ -443,7 +445,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Register built-in access providers before constructing services.
|
||||
configaccess.Register()
|
||||
configaccess.Register(&cfg.SDKConfig)
|
||||
|
||||
// Handle different command modes based on the provided flags.
|
||||
|
||||
@@ -468,6 +470,8 @@ func main() {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else {
|
||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||
if isCloudDeploy && !configFileExists {
|
||||
|
||||
@@ -40,6 +40,11 @@ api-keys:
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
|
||||
pprof:
|
||||
enable: false
|
||||
addr: "127.0.0.1:8316"
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
|
||||
@@ -216,25 +221,10 @@ nonstream-keepalive-interval: 0
|
||||
|
||||
# Global OAuth model name aliases (per channel)
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
oauth-model-alias:
|
||||
antigravity:
|
||||
- name: "rev19-uic3-1p"
|
||||
alias: "gemini-2.5-computer-use-preview-10-2025"
|
||||
- name: "gemini-3-pro-image"
|
||||
alias: "gemini-3-pro-image-preview"
|
||||
- name: "gemini-3-pro-high"
|
||||
alias: "gemini-3-pro-preview"
|
||||
- name: "gemini-3-flash"
|
||||
alias: "gemini-3-flash-preview"
|
||||
- name: "claude-sonnet-4-5"
|
||||
alias: "gemini-claude-sonnet-4-5"
|
||||
- name: "claude-sonnet-4-5-thinking"
|
||||
alias: "gemini-claude-sonnet-4-5-thinking"
|
||||
- name: "claude-opus-4-5-thinking"
|
||||
alias: "gemini-claude-opus-4-5-thinking"
|
||||
# oauth-model-alias:
|
||||
# gemini-cli:
|
||||
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||
# alias: "g2.5p" # client-visible alias
|
||||
@@ -245,6 +235,9 @@ oauth-model-alias:
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# antigravity:
|
||||
# - name: "gemini-3-pro-high"
|
||||
# alias: "gemini-3-pro-preview"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
@@ -257,6 +250,9 @@ oauth-model-alias:
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
# kimi:
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "k2.5"
|
||||
|
||||
# OAuth provider excluded models
|
||||
# oauth-excluded-models:
|
||||
@@ -279,6 +275,8 @@ oauth-model-alias:
|
||||
# - "vision-model"
|
||||
# iflow:
|
||||
# - "tstars2.0"
|
||||
# kimi:
|
||||
# - "kimi-k2-thinking"
|
||||
|
||||
# Optional payload configuration
|
||||
# payload:
|
||||
|
||||
@@ -7,80 +7,71 @@ The `github.com/router-for-me/CLIProxyAPI/v6/sdk/access` package centralizes inb
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
|
||||
|
||||
## Provider Registry
|
||||
|
||||
Providers are registered globally and then attached to a `Manager` as a snapshot:
|
||||
|
||||
- `RegisterProvider(type, provider)` installs a pre-initialized provider instance.
|
||||
- Registration order is preserved the first time each `type` is seen.
|
||||
- `RegisteredProviders()` returns the providers in that order.
|
||||
|
||||
## Manager Lifecycle
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
* `NewManager` constructs an empty manager.
|
||||
* `SetProviders` replaces the provider slice using a defensive copy.
|
||||
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
|
||||
* `BuildProviders` translates `config.Config` access declarations into runnable providers. When the config omits explicit providers but defines inline API keys, the helper auto-installs the built-in `config-api-key` provider.
|
||||
|
||||
If the manager itself is `nil` or no providers are configured, the call returns `nil, nil`, allowing callers to treat access control as disabled.
|
||||
|
||||
## Authenticating Requests
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result describes the provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Supplied credentials were present but rejected.
|
||||
default:
|
||||
// Transport-level failure was returned by a provider.
|
||||
// Internal/transport failure was returned by a provider.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that surface `ErrNotHandled`, and tracks whether any provider reported `ErrNoCredentials` or `ErrInvalidCredential` for downstream error reporting.
|
||||
|
||||
If the manager itself is `nil` or no providers are registered, the call returns `nil, nil`, allowing callers to treat access control as disabled without branching on errors.
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that return `AuthErrorCodeNotHandled`, and aggregates `AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` for a final result.
|
||||
|
||||
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
|
||||
|
||||
## Configuration Layout
|
||||
## Built-in `config-api-key` Provider
|
||||
|
||||
The manager expects access providers under the `auth.providers` key inside `config.yaml`:
|
||||
The proxy includes one built-in access provider:
|
||||
|
||||
- `config-api-key`: Validates API keys declared under top-level `api-keys`.
|
||||
- Credential sources: `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, `?key=`, `?auth_token=`
|
||||
- Metadata: `Result.Metadata["source"]` is set to the matched source label.
|
||||
|
||||
In the CLI server and `sdk/cliproxy`, this provider is registered automatically based on the loaded configuration.
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
Fields map directly to `config.AccessProvider`: `name` labels the provider, `type` selects the registered factory, `sdk` can name an external module, `api-keys` seeds inline credentials, and `config` passes provider-specific options.
|
||||
## Loading Providers from External Go Modules
|
||||
|
||||
### Loading providers from external SDK modules
|
||||
|
||||
To consume a provider shipped in another Go module, point the `sdk` field at the module path and import it for its registration side effect:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
To consume a provider shipped in another Go module, import it for its registration side effect:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before `BuildProviders` is called.
|
||||
|
||||
## Built-in Providers
|
||||
|
||||
The SDK ships with one provider out of the box:
|
||||
|
||||
- `config-api-key`: Validates API keys declared inline or under top-level `api-keys`. It accepts the key from `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, or the `?key=` query string and reports `ErrInvalidCredential` when no match is found.
|
||||
|
||||
Additional providers can be delivered by third-party packages. When a provider package is imported, it registers itself with `sdkaccess.RegisterProvider`.
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before you call `RegisteredProviders()` (or before `cliproxy.NewBuilder().Build()`).
|
||||
|
||||
### Metadata and auditing
|
||||
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, or `query-key`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, `query-key`, `query-auth-token`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
|
||||
## Writing Custom Providers
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To expose it to configuration, call `RegisterProvider` inside `init`. Provider factories receive the specific `AccessProvider` block plus the full root configuration for contextual needs.
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To make it available to the access manager, call `RegisterProvider` inside `init` with an initialized provider instance.
|
||||
|
||||
## Error Semantics
|
||||
|
||||
- `ErrNoCredentials`: no credentials were present or recognized by any provider.
|
||||
- `ErrInvalidCredential`: at least one provider processed the credentials but rejected them.
|
||||
- `ErrNotHandled`: instructs the manager to fall through to the next provider without affecting aggregate error reporting.
|
||||
- `NewNoCredentialsError()` (`AuthErrorCodeNoCredentials`): no credentials were present or recognized. (HTTP 401)
|
||||
- `NewInvalidCredentialError()` (`AuthErrorCodeInvalidCredential`): credentials were present but rejected. (HTTP 401)
|
||||
- `NewNotHandledError()` (`AuthErrorCodeNotHandled`): fall through to the next provider.
|
||||
- `NewInternalAuthError(message, cause)` (`AuthErrorCodeInternal`): transport/system failure. (HTTP 500)
|
||||
|
||||
Return custom errors to surface transport failures; they propagate immediately to the caller instead of being masked.
|
||||
Errors propagate immediately to the caller unless they are classified as `not_handled` / `no_credentials` / `invalid_credential` and can be aggregated by the manager.
|
||||
|
||||
## Integration with cliproxy Service
|
||||
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a preconfigured manager allows you to extend or override the default providers:
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a manager lets you reuse the same instance in your host process:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
The service reuses the manager for every inbound request, ensuring consistent authentication across embedded deployments and the canonical CLI binary.
|
||||
Register any custom providers (typically via blank imports) before calling `Build()` so they are present in the global registry snapshot.
|
||||
|
||||
### Hot reloading providers
|
||||
### Hot reloading
|
||||
|
||||
When configuration changes, rebuild providers and swap them into the manager:
|
||||
When configuration changes, refresh any config-backed providers and then reset the manager's provider chain:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
This mirrors the behaviour in `cliproxy.Service.refreshAccessProviders` and `api.Server.applyAccessConfig`, enabling runtime updates without restarting the process.
|
||||
This mirrors the behaviour in `internal/access.ApplyAccessProviders`, enabling runtime updates without restarting the process.
|
||||
|
||||
@@ -7,80 +7,71 @@
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
|
||||
|
||||
## Provider Registry
|
||||
|
||||
访问提供者是全局注册,然后以快照形式挂到 `Manager` 上:
|
||||
|
||||
- `RegisterProvider(type, provider)` 注册一个已经初始化好的 provider 实例。
|
||||
- 每个 `type` 第一次出现时会记录其注册顺序。
|
||||
- `RegisteredProviders()` 会按该顺序返回 provider 列表。
|
||||
|
||||
## 管理器生命周期
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
- `NewManager` 创建空管理器。
|
||||
- `SetProviders` 替换提供者切片并做防御性拷贝。
|
||||
- `Providers` 返回适合并发读取的快照。
|
||||
- `BuildProviders` 将 `config.Config` 中的访问配置转换成可运行的提供者。当配置没有显式声明但包含顶层 `api-keys` 时,会自动挂载内建的 `config-api-key` 提供者。
|
||||
|
||||
如果管理器本身为 `nil` 或未配置任何 provider,调用会返回 `nil, nil`,可视为关闭访问控制。
|
||||
|
||||
## 认证请求
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result carries provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Credentials were present but rejected.
|
||||
default:
|
||||
// Provider surfaced a transport-level failure.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` 按配置顺序遍历提供者。遇到成功立即返回,`ErrNotHandled` 会继续尝试下一个;若发现 `ErrNoCredentials` 或 `ErrInvalidCredential`,会在遍历结束后汇总给调用方。
|
||||
|
||||
若管理器本身为 `nil` 或尚未注册提供者,调用会返回 `nil, nil`,让调用方无需针对错误做额外分支即可关闭访问控制。
|
||||
`Manager.Authenticate` 会按顺序遍历 provider:遇到成功立即返回,`AuthErrorCodeNotHandled` 会继续尝试下一个;`AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` 会在遍历结束后汇总给调用方。
|
||||
|
||||
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
|
||||
|
||||
## 配置结构
|
||||
## 内建 `config-api-key` Provider
|
||||
|
||||
在 `config.yaml` 的 `auth.providers` 下定义访问提供者:
|
||||
代理内置一个访问提供者:
|
||||
|
||||
- `config-api-key`:校验 `config.yaml` 顶层的 `api-keys`。
|
||||
- 凭证来源:`Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key`、`?key=`、`?auth_token=`
|
||||
- 元数据:`Result.Metadata["source"]` 会写入匹配到的来源标识
|
||||
|
||||
在 CLI 服务端与 `sdk/cliproxy` 中,该 provider 会根据加载到的配置自动注册。
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
条目映射到 `config.AccessProvider`:`name` 指定实例名,`type` 选择注册的工厂,`sdk` 可引用第三方模块,`api-keys` 提供内联凭证,`config` 用于传递特定选项。
|
||||
## 引入外部 Go 模块提供者
|
||||
|
||||
### 引入外部 SDK 提供者
|
||||
|
||||
若要消费其它 Go 模块输出的访问提供者,可在配置里填写 `sdk` 字段并在代码中引入该包,利用其 `init` 注册过程:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
若要消费其它 Go 模块输出的访问提供者,直接用空白标识符导入以触发其 `init` 注册即可:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
通过空白标识符导入即可确保 `init` 调用,先于 `BuildProviders` 完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
## 内建提供者
|
||||
|
||||
当前 SDK 默认内置:
|
||||
|
||||
- `config-api-key`:校验配置中的 API Key。它从 `Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key` 以及查询参数 `?key=` 提取凭证,不匹配时抛出 `ErrInvalidCredential`。
|
||||
|
||||
导入第三方包即可通过 `sdkaccess.RegisterProvider` 注册更多类型。
|
||||
空白导入可确保 `init` 先执行,从而在你调用 `RegisteredProviders()`(或 `cliproxy.NewBuilder().Build()`)之前完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
### 元数据与审计
|
||||
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key` 或 `query-key`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key`、`query-key`、`query-auth-token`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
|
||||
## 编写自定义提供者
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中调用 `RegisterProvider` 暴露给配置层,工厂函数既能读取当前条目,也能访问完整根配置。
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中用已初始化实例调用 `RegisterProvider` 注册到全局 registry。
|
||||
|
||||
## 错误语义
|
||||
|
||||
- `ErrNoCredentials`:任何提供者都未识别到凭证。
|
||||
- `ErrInvalidCredential`:至少一个提供者处理了凭证但判定无效。
|
||||
- `ErrNotHandled`:告诉管理器跳到下一个提供者,不影响最终错误统计。
|
||||
- `NewNoCredentialsError()`(`AuthErrorCodeNoCredentials`):未提供或未识别到凭证。(HTTP 401)
|
||||
- `NewInvalidCredentialError()`(`AuthErrorCodeInvalidCredential`):凭证存在但校验失败。(HTTP 401)
|
||||
- `NewNotHandledError()`(`AuthErrorCodeNotHandled`):告诉管理器跳到下一个 provider。
|
||||
- `NewInternalAuthError(message, cause)`(`AuthErrorCodeInternal`):网络/系统错误。(HTTP 500)
|
||||
|
||||
自定义错误(例如网络异常)会马上冒泡返回。
|
||||
除可汇总的 `not_handled` / `no_credentials` / `invalid_credential` 外,其它错误会立即冒泡返回。
|
||||
|
||||
## 与 cliproxy 集成
|
||||
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果需要扩展内置行为,可传入自定义管理器:
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果希望在宿主进程里复用同一个 `Manager` 实例,可传入自定义管理器:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
服务会复用该管理器处理每一个入站请求,实现与 CLI 二进制一致的访问控制体验。
|
||||
请在调用 `Build()` 之前完成自定义 provider 的注册(通常通过空白导入触发 `init`),以确保它们被包含在全局 registry 的快照中。
|
||||
|
||||
### 动态热更新提供者
|
||||
|
||||
当配置发生变化时,可以重新构建提供者并替换当前列表:
|
||||
当配置发生变化时,刷新依赖配置的 provider,然后重置 manager 的 provider 链:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
这一流程与 `cliproxy.Service.refreshAccessProviders` 和 `api.Server.applyAccessConfig` 保持一致,避免为更新访问策略而重启进程。
|
||||
这一流程与 `internal/access.ApplyAccessProviders` 保持一致,避免为更新访问策略而重启进程。
|
||||
|
||||
2
go.mod
2
go.mod
@@ -22,6 +22,7 @@ require (
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.18.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -69,7 +70,6 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
|
||||
@@ -4,19 +4,28 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
var registerOnce sync.Once
|
||||
|
||||
// Register ensures the config-access provider is available to the access manager.
|
||||
func Register() {
|
||||
registerOnce.Do(func() {
|
||||
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
|
||||
})
|
||||
func Register(cfg *sdkconfig.SDKConfig) {
|
||||
if cfg == nil {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
keys := normalizeKeys(cfg.APIKeys)
|
||||
if len(keys) == 0 {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
sdkaccess.RegisterProvider(
|
||||
sdkaccess.AccessProviderTypeConfigAPIKey,
|
||||
newProvider(sdkaccess.DefaultAccessProviderName, keys),
|
||||
)
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
@@ -24,34 +33,31 @@ type provider struct {
|
||||
keys map[string]struct{}
|
||||
}
|
||||
|
||||
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = sdkconfig.DefaultAccessProviderName
|
||||
func newProvider(name string, keys []string) *provider {
|
||||
providerName := strings.TrimSpace(name)
|
||||
if providerName == "" {
|
||||
providerName = sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
keys := make(map[string]struct{}, len(cfg.APIKeys))
|
||||
for _, key := range cfg.APIKeys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
keys[key] = struct{}{}
|
||||
keySet := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
return &provider{name: name, keys: keys}, nil
|
||||
return &provider{name: providerName, keys: keySet}
|
||||
}
|
||||
|
||||
func (p *provider) Identifier() string {
|
||||
if p == nil || p.name == "" {
|
||||
return sdkconfig.DefaultAccessProviderName
|
||||
return sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
if p == nil {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if len(p.keys) == 0 {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||
@@ -63,7 +69,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||
}
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNoCredentialsError()
|
||||
}
|
||||
|
||||
apiKey := extractBearerToken(authHeader)
|
||||
@@ -94,7 +100,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
}
|
||||
}
|
||||
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
|
||||
func extractBearerToken(header string) string {
|
||||
@@ -110,3 +116,26 @@ func extractBearerToken(header string) string {
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
trimmedKey := strings.TrimSpace(key)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[trimmedKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmedKey] = struct{}{}
|
||||
normalized = append(normalized, trimmedKey)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -17,26 +17,26 @@ import (
|
||||
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||
// removed compared to the previous configuration.
|
||||
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
||||
_ = oldCfg
|
||||
if newCfg == nil {
|
||||
return nil, nil, nil, nil, nil
|
||||
}
|
||||
|
||||
result = sdkaccess.RegisteredProviders()
|
||||
|
||||
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
||||
for _, provider := range existing {
|
||||
if provider == nil {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[provider.Identifier()] = provider
|
||||
existingMap[providerID] = provider
|
||||
}
|
||||
|
||||
oldCfgMap := accessProviderMap(oldCfg)
|
||||
newEntries := collectProviderEntries(newCfg)
|
||||
|
||||
result = make([]sdkaccess.Provider, 0, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(result))
|
||||
|
||||
isInlineProvider := func(id string) bool {
|
||||
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
|
||||
return strings.EqualFold(id, sdkaccess.DefaultAccessProviderName)
|
||||
}
|
||||
appendChange := func(list *[]string, id string) {
|
||||
if isInlineProvider(id) {
|
||||
@@ -45,85 +45,28 @@ func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Prov
|
||||
*list = append(*list, id)
|
||||
}
|
||||
|
||||
for _, providerCfg := range newEntries {
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
for _, provider := range result {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
finalIDs[providerID] = struct{}{}
|
||||
|
||||
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
isAliased := oldCfgProvider == providerCfg
|
||||
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
existingProvider, exists := existingMap[providerID]
|
||||
if !exists {
|
||||
appendChange(&added, providerID)
|
||||
continue
|
||||
}
|
||||
|
||||
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, ok := oldCfgMap[key]; ok {
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
|
||||
key := providerIdentifier(inline)
|
||||
if key != "" {
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
if providerConfigEqual(oldCfgProvider, inline) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
goto inlineDone
|
||||
}
|
||||
}
|
||||
}
|
||||
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else if _, hadOld := oldCfgMap[key]; hadOld {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
inlineDone:
|
||||
}
|
||||
|
||||
removedSet := make(map[string]struct{})
|
||||
for id := range existingMap {
|
||||
if _, ok := finalIDs[id]; !ok {
|
||||
if isInlineProvider(id) {
|
||||
continue
|
||||
}
|
||||
removedSet[id] = struct{}{}
|
||||
if !providerInstanceEqual(existingProvider, provider) {
|
||||
appendChange(&updated, providerID)
|
||||
}
|
||||
}
|
||||
|
||||
removed = make([]string, 0, len(removedSet))
|
||||
for id := range removedSet {
|
||||
removed = append(removed, id)
|
||||
for providerID := range existingMap {
|
||||
if _, exists := finalIDs[providerID]; exists {
|
||||
continue
|
||||
}
|
||||
appendChange(&removed, providerID)
|
||||
}
|
||||
|
||||
sort.Strings(added)
|
||||
@@ -142,6 +85,7 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
}
|
||||
|
||||
existing := manager.Providers()
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||
@@ -160,111 +104,24 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
|
||||
result := make(map[string]*sdkConfig.AccessProvider)
|
||||
if cfg == nil {
|
||||
return result
|
||||
}
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
result[key] = providerCfg
|
||||
}
|
||||
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
|
||||
if key := providerIdentifier(provider); key != "" {
|
||||
result[key] = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
|
||||
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
if key := providerIdentifier(providerCfg); key != "" {
|
||||
entries = append(entries, providerCfg)
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
|
||||
entries = append(entries, inline)
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
|
||||
func identifierFromProvider(provider sdkaccess.Provider) string {
|
||||
if provider == nil {
|
||||
return ""
|
||||
}
|
||||
if name := strings.TrimSpace(provider.Name); name != "" {
|
||||
return name
|
||||
}
|
||||
typ := strings.TrimSpace(provider.Type)
|
||||
if typ == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
|
||||
return sdkConfig.DefaultAccessProviderName
|
||||
}
|
||||
return typ
|
||||
return strings.TrimSpace(provider.Identifier())
|
||||
}
|
||||
|
||||
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
|
||||
func providerInstanceEqual(a, b sdkaccess.Provider) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == nil && b == nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
||||
if reflect.TypeOf(a) != reflect.TypeOf(b) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
||||
return false
|
||||
valueA := reflect.ValueOf(a)
|
||||
valueB := reflect.ValueOf(b)
|
||||
if valueA.Kind() == reflect.Pointer && valueB.Kind() == reflect.Pointer {
|
||||
return valueA.Pointer() == valueB.Pointer()
|
||||
}
|
||||
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) != len(b.Config) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSetEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
seen := make(map[string]int, len(a))
|
||||
for _, val := range a {
|
||||
seen[val]++
|
||||
}
|
||||
for _, val := range b {
|
||||
count := seen[val]
|
||||
if count == 0 {
|
||||
return false
|
||||
}
|
||||
if count == 1 {
|
||||
delete(seen, val)
|
||||
} else {
|
||||
seen[val] = count - 1
|
||||
}
|
||||
}
|
||||
return len(seen) == 0
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
@@ -1608,6 +1609,82 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
state := fmt.Sprintf("kmi-%d", time.Now().UnixNano())
|
||||
// Initialize Kimi auth service
|
||||
kimiAuth := kimi.NewKimiAuth(h.cfg)
|
||||
|
||||
// Generate authorization URL
|
||||
deviceFlow, errStartDeviceFlow := kimiAuth.StartDeviceFlow(ctx)
|
||||
if errStartDeviceFlow != nil {
|
||||
log.Errorf("Failed to generate authorization URL: %v", errStartDeviceFlow)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
if authURL == "" {
|
||||
authURL = deviceFlow.VerificationURI
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "kimi")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
authBundle, errWaitForAuthorization := kimiAuth.WaitForAuthorization(ctx, deviceFlow)
|
||||
if errWaitForAuthorization != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errWaitForAuthorization)
|
||||
return
|
||||
}
|
||||
|
||||
// Create token storage
|
||||
tokenStorage := kimiAuth.CreateTokenStorage(authBundle)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kimi",
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"refresh_token": authBundle.TokenData.RefreshToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
if authBundle.TokenData.ExpiresAt > 0 {
|
||||
expired := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
metadata["expired"] = expired
|
||||
}
|
||||
if strings.TrimSpace(authBundle.DeviceID) != "" {
|
||||
metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kimi",
|
||||
FileName: fileName,
|
||||
Label: "Kimi User",
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Kimi services through this CLI")
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("kimi")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -109,14 +109,13 @@ func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.c
|
||||
func (h *Handler) PutAPIKeys(c *gin.Context) {
|
||||
h.putStringList(c, func(v []string) {
|
||||
h.cfg.APIKeys = append([]string(nil), v...)
|
||||
h.cfg.Access.Providers = nil
|
||||
}, nil)
|
||||
}
|
||||
func (h *Handler) PatchAPIKeys(c *gin.Context) {
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() {})
|
||||
}
|
||||
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() {})
|
||||
}
|
||||
|
||||
// gemini-api-key: []GeminiKey
|
||||
|
||||
@@ -66,7 +66,7 @@ func (rw *ResponseRewriter) Flush() {
|
||||
}
|
||||
|
||||
// modelFieldPaths lists all JSON paths where model name may appear
|
||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||
|
||||
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewriteModelInResponse_TopLevel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseCreated(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_NoModelField(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: ""}
|
||||
|
||||
input := []byte(`{"model":"gpt-5.3-codex"}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification when originalModel is empty, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MultipleEvents(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
if string(result) == string(chunk) {
|
||||
t.Error("expected response.model to be rewritten in SSE stream")
|
||||
}
|
||||
if !contains(result, []byte(`"model":"gpt-5.2-codex"`)) {
|
||||
t.Errorf("expected rewritten model in output, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "claude-opus-4.5"}
|
||||
|
||||
chunk := []byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -623,6 +623,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
@@ -654,14 +655,17 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
// Synchronously ensure management.html is available with a detached context.
|
||||
// Control panel bootstrap should not be canceled by client disconnects.
|
||||
if !managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.File(filePath)
|
||||
@@ -951,10 +955,6 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
@@ -1033,14 +1033,10 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
||||
default:
|
||||
statusCode := err.HTTPStatusCode()
|
||||
if statusCode >= http.StatusInternalServerError {
|
||||
log.Errorf("authentication middleware error: %v", err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
||||
}
|
||||
}
|
||||
|
||||
396
internal/auth/kimi/kimi.go
Normal file
396
internal/auth/kimi/kimi.go
Normal file
@@ -0,0 +1,396 @@
|
||||
// Package kimi provides authentication and token management for Kimi (Moonshot AI) API.
|
||||
// It handles the RFC 8628 OAuth2 Device Authorization Grant flow for secure authentication.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// kimiClientID is Kimi Code's OAuth client ID.
|
||||
kimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098"
|
||||
// kimiOAuthHost is the OAuth server endpoint.
|
||||
kimiOAuthHost = "https://auth.kimi.com"
|
||||
// kimiDeviceCodeURL is the endpoint for requesting device codes.
|
||||
kimiDeviceCodeURL = kimiOAuthHost + "/api/oauth/device_authorization"
|
||||
// kimiTokenURL is the endpoint for exchanging device codes for tokens.
|
||||
kimiTokenURL = kimiOAuthHost + "/api/oauth/token"
|
||||
// KimiAPIBaseURL is the base URL for Kimi API requests.
|
||||
KimiAPIBaseURL = "https://api.kimi.com/coding"
|
||||
// defaultPollInterval is the default interval for polling token endpoint.
|
||||
defaultPollInterval = 5 * time.Second
|
||||
// maxPollDuration is the maximum time to wait for user authorization.
|
||||
maxPollDuration = 15 * time.Minute
|
||||
// refreshThresholdSeconds is when to refresh token before expiry (5 minutes).
|
||||
refreshThresholdSeconds = 300
|
||||
)
|
||||
|
||||
// KimiAuth handles Kimi authentication flow.
|
||||
type KimiAuth struct {
|
||||
deviceClient *DeviceFlowClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiAuth creates a new KimiAuth service instance.
|
||||
func NewKimiAuth(cfg *config.Config) *KimiAuth {
|
||||
return &KimiAuth{
|
||||
deviceClient: NewDeviceFlowClient(cfg),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// StartDeviceFlow initiates the device flow authentication.
|
||||
func (k *KimiAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
return k.deviceClient.RequestDeviceCode(ctx)
|
||||
}
|
||||
|
||||
// WaitForAuthorization polls for user authorization and returns the auth bundle.
|
||||
func (k *KimiAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiAuthBundle, error) {
|
||||
tokenData, err := k.deviceClient.PollForToken(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KimiAuthBundle{
|
||||
TokenData: tokenData,
|
||||
DeviceID: k.deviceClient.deviceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new KimiTokenStorage from auth bundle.
|
||||
func (k *KimiAuth) CreateTokenStorage(bundle *KimiAuthBundle) *KimiTokenStorage {
|
||||
expired := ""
|
||||
if bundle.TokenData.ExpiresAt > 0 {
|
||||
expired = time.Unix(bundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
return &KimiTokenStorage{
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
DeviceID: strings.TrimSpace(bundle.DeviceID),
|
||||
Expired: expired,
|
||||
Type: "kimi",
|
||||
}
|
||||
}
|
||||
|
||||
// DeviceFlowClient handles the OAuth2 device flow for Kimi.
|
||||
type DeviceFlowClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
deviceID string
|
||||
}
|
||||
|
||||
// NewDeviceFlowClient creates a new device flow client.
|
||||
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
return NewDeviceFlowClientWithDeviceID(cfg, "")
|
||||
}
|
||||
|
||||
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
||||
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
resolvedDeviceID := strings.TrimSpace(deviceID)
|
||||
if resolvedDeviceID == "" {
|
||||
resolvedDeviceID = getOrCreateDeviceID()
|
||||
}
|
||||
return &DeviceFlowClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
deviceID: resolvedDeviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateDeviceID returns an in-memory device ID for the current authentication flow.
|
||||
func getOrCreateDeviceID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// getDeviceModel returns a device model string.
|
||||
func getDeviceModel() string {
|
||||
osName := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
|
||||
switch osName {
|
||||
case "darwin":
|
||||
return fmt.Sprintf("macOS %s", arch)
|
||||
case "windows":
|
||||
return fmt.Sprintf("Windows %s", arch)
|
||||
case "linux":
|
||||
return fmt.Sprintf("Linux %s", arch)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s", osName, arch)
|
||||
}
|
||||
}
|
||||
|
||||
// getHostname returns the machine hostname.
|
||||
func getHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// commonHeaders returns headers required for Kimi API requests.
|
||||
func (c *DeviceFlowClient) commonHeaders() map[string]string {
|
||||
return map[string]string{
|
||||
"X-Msh-Platform": "cli-proxy-api",
|
||||
"X-Msh-Version": "1.0.0",
|
||||
"X-Msh-Device-Name": getHostname(),
|
||||
"X-Msh-Device-Model": getDeviceModel(),
|
||||
"X-Msh-Device-Id": c.deviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestDeviceCode initiates the device flow by requesting a device code from Kimi.
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create device code request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: device code request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi device code: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read device code response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: device code request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var deviceCode DeviceCodeResponse
|
||||
if err = json.Unmarshal(bodyBytes, &deviceCode); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse device code response: %w", err)
|
||||
}
|
||||
|
||||
return &deviceCode, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
|
||||
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiTokenData, error) {
|
||||
if deviceCode == nil {
|
||||
return nil, fmt.Errorf("kimi: device code is nil")
|
||||
}
|
||||
|
||||
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||
if interval < defaultPollInterval {
|
||||
interval = defaultPollInterval
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(maxPollDuration)
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||
if codeDeadline.Before(deadline) {
|
||||
deadline = codeDeadline
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("kimi: context cancelled: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("kimi: device code expired")
|
||||
}
|
||||
|
||||
token, pollErr, shouldContinue := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
|
||||
if token != nil {
|
||||
return token, nil
|
||||
}
|
||||
if !shouldContinue {
|
||||
return nil, pollErr
|
||||
}
|
||||
// Continue polling
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeDeviceCode attempts to exchange the device code for an access token.
|
||||
// Returns (token, error, shouldContinue).
|
||||
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*KimiTokenData, error, bool) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create token request: %w", err), false
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: token request failed: %w", err), false
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read token response: %w", err), false
|
||||
}
|
||||
|
||||
// Parse response - Kimi returns 200 for both success and pending states
|
||||
var oauthResp struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse token response: %w", err), false
|
||||
}
|
||||
|
||||
if oauthResp.Error != "" {
|
||||
switch oauthResp.Error {
|
||||
case "authorization_pending":
|
||||
return nil, nil, true // Continue polling
|
||||
case "slow_down":
|
||||
return nil, nil, true // Continue polling (with increased interval handled by caller)
|
||||
case "expired_token":
|
||||
return nil, fmt.Errorf("kimi: device code expired"), false
|
||||
case "access_denied":
|
||||
return nil, fmt.Errorf("kimi: access denied by user"), false
|
||||
default:
|
||||
return nil, fmt.Errorf("kimi: OAuth error: %s - %s", oauthResp.Error, oauthResp.ErrorDescription), false
|
||||
}
|
||||
}
|
||||
|
||||
if oauthResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in response"), false
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if oauthResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(oauthResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: oauthResp.AccessToken,
|
||||
RefreshToken: oauthResp.RefreshToken,
|
||||
TokenType: oauthResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: oauthResp.Scope,
|
||||
}, nil, false
|
||||
}
|
||||
|
||||
// RefreshToken exchanges a refresh token for a new access token.
|
||||
func (c *DeviceFlowClient) RefreshToken(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", refreshToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi refresh token: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return nil, fmt.Errorf("kimi: refresh token rejected (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in refresh response")
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if tokenResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(tokenResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
}, nil
|
||||
}
|
||||
116
internal/auth/kimi/token.go
Normal file
116
internal/auth/kimi/token.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Package kimi provides authentication and token management functionality
|
||||
// for Kimi (Moonshot AI) services. It handles OAuth2 device flow token storage,
|
||||
// serialization, and retrieval for maintaining authenticated sessions with the Kimi API.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// KimiTokenStorage stores OAuth2 token information for Kimi API authentication.
|
||||
type KimiTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// DeviceID is the OAuth device flow identifier used for Kimi requests.
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
// Expired is the RFC3339 timestamp when the access token expires.
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
type KimiTokenData struct {
|
||||
// AccessToken is the OAuth2 access token.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ExpiresAt is the Unix timestamp when the token expires.
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// KimiAuthBundle bundles authentication data for storage.
|
||||
type KimiAuthBundle struct {
|
||||
// TokenData contains the OAuth token information.
|
||||
TokenData *KimiTokenData
|
||||
// DeviceID is the device identifier used during OAuth device flow.
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents Kimi's device code response.
|
||||
type DeviceCodeResponse struct {
|
||||
// DeviceCode is the device verification code.
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code the user must enter at the verification URI.
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is the URL where the user should enter the code.
|
||||
VerificationURI string `json:"verification_uri,omitempty"`
|
||||
// VerificationURIComplete is the URL with the code pre-filled.
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
// ExpiresIn is the number of seconds until the device code expires.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
// Interval is the minimum number of seconds to wait between polling requests.
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kimi token storage to a JSON file.
|
||||
func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kimi"
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if the token has expired.
|
||||
func (ts *KimiTokenStorage) IsExpired() bool {
|
||||
if ts.Expired == "" {
|
||||
return false // No expiry set, assume valid
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, ts.Expired)
|
||||
if err != nil {
|
||||
return true // Has expiry string but can't parse
|
||||
}
|
||||
// Consider expired if within refresh threshold
|
||||
return time.Now().Add(time.Duration(refreshThresholdSeconds) * time.Second).After(t)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the token should be refreshed.
|
||||
func (ts *KimiTokenStorage) NeedsRefresh() bool {
|
||||
if ts.RefreshToken == "" {
|
||||
return false // Can't refresh without refresh token
|
||||
}
|
||||
return ts.IsExpired()
|
||||
}
|
||||
@@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewIFlowAuthenticator(),
|
||||
sdkAuth.NewAntigravityAuthenticator(),
|
||||
sdkAuth.NewKimiAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
44
internal/cmd/kimi_login.go
Normal file
44
internal/cmd/kimi_login.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoKimiLogin triggers the OAuth device flow for Kimi (Moonshot AI) and saves tokens.
|
||||
// It initiates the device flow authentication, displays the verification URL for the user,
|
||||
// and waits for authorization before saving the tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy and auth directory settings
|
||||
// - options: Login options including browser behavior settings
|
||||
func DoKimiLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "kimi", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("Kimi authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("Kimi authentication successful!")
|
||||
}
|
||||
@@ -18,7 +18,10 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
const (
|
||||
DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
DefaultPprofAddr = "127.0.0.1:8316"
|
||||
)
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
@@ -41,6 +44,9 @@ type Config struct {
|
||||
// Debug enables or disables debug-level logging and other debug features.
|
||||
Debug bool `yaml:"debug" json:"debug"`
|
||||
|
||||
// Pprof config controls the optional pprof HTTP debug server.
|
||||
Pprof PprofConfig `yaml:"pprof" json:"pprof"`
|
||||
|
||||
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
|
||||
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
|
||||
|
||||
@@ -121,6 +127,14 @@ type TLSConfig struct {
|
||||
Key string `yaml:"key" json:"key"`
|
||||
}
|
||||
|
||||
// PprofConfig holds pprof HTTP server settings.
|
||||
type PprofConfig struct {
|
||||
// Enable toggles the pprof HTTP debug server.
|
||||
Enable bool `yaml:"enable" json:"enable"`
|
||||
// Addr is the host:port address for the pprof HTTP server.
|
||||
Addr string `yaml:"addr" json:"addr"`
|
||||
}
|
||||
|
||||
// RemoteManagement holds management API configuration under 'remote-management'.
|
||||
type RemoteManagement struct {
|
||||
// AllowRemote toggles remote (non-localhost) access to management API.
|
||||
@@ -479,14 +493,15 @@ func LoadConfig(configFile string) (*Config, error) {
|
||||
// If optional is true and the file is missing, it returns an empty Config.
|
||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Perform oauth-model-alias migration before loading config.
|
||||
// This migrates oauth-model-mappings to oauth-model-alias if needed.
|
||||
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// Log warning but don't fail - config loading should still work
|
||||
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
} else if migrated {
|
||||
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
}
|
||||
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// // Log warning but don't fail - config loading should still work
|
||||
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
// } else if migrated {
|
||||
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
// }
|
||||
|
||||
// Read the entire configuration file into memory.
|
||||
data, err := os.ReadFile(configFile)
|
||||
@@ -514,6 +529,8 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.Pprof.Enable = false
|
||||
cfg.Pprof.Addr = DefaultPprofAddr
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||
@@ -524,18 +541,21 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
var legacy legacyConfigData
|
||||
if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||
if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
}
|
||||
// NOTE: Startup legacy key migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// var legacy legacyConfigData
|
||||
// if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||
// if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// }
|
||||
|
||||
// Hash remote management key if plaintext is detected (nested)
|
||||
// We consider a value to be already hashed if it looks like a bcrypt hash ($2a$, $2b$, or $2y$ prefix).
|
||||
@@ -556,6 +576,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
cfg.Pprof.Addr = strings.TrimSpace(cfg.Pprof.Addr)
|
||||
if cfg.Pprof.Addr == "" {
|
||||
cfg.Pprof.Addr = DefaultPprofAddr
|
||||
}
|
||||
|
||||
if cfg.LogsMaxTotalSizeMB < 0 {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
}
|
||||
@@ -564,9 +589,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
@@ -591,17 +613,20 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Validate raw payload rules and drop invalid entries.
|
||||
cfg.SanitizePayloadRules()
|
||||
|
||||
if cfg.legacyMigrationPending {
|
||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
if !optional && configFile != "" {
|
||||
if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||
}
|
||||
fmt.Println("Legacy configuration normalized and persisted.")
|
||||
} else {
|
||||
fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||
}
|
||||
}
|
||||
// NOTE: Legacy migration persistence is intentionally disabled together with
|
||||
// startup legacy migration to keep startup read-only for config.yaml.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if cfg.legacyMigrationPending {
|
||||
// fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
// if !optional && configFile != "" {
|
||||
// if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||
// return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||
// }
|
||||
// fmt.Println("Legacy configuration normalized and persisted.")
|
||||
// } else {
|
||||
// fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||
// }
|
||||
// }
|
||||
|
||||
// Return the populated configuration struct.
|
||||
return &cfg, nil
|
||||
@@ -797,18 +822,6 @@ func normalizeModelPrefix(prefix string) string {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if len(cfg.APIKeys) == 0 {
|
||||
if provider := cfg.ConfigAPIKeyProvider(); provider != nil && len(provider.APIKeys) > 0 {
|
||||
cfg.APIKeys = append([]string(nil), provider.APIKeys...)
|
||||
}
|
||||
}
|
||||
cfg.Access.Providers = nil
|
||||
}
|
||||
|
||||
// looksLikeBcrypt returns true if the provided string appears to be a bcrypt hash.
|
||||
func looksLikeBcrypt(s string) bool {
|
||||
return len(s) > 4 && (s[:4] == "$2a$" || s[:4] == "$2b$" || s[:4] == "$2y$")
|
||||
@@ -896,7 +909,7 @@ func hashSecret(secret string) (string, error) {
|
||||
// SaveConfigPreserveComments writes the config back to YAML while preserving existing comments
|
||||
// and key ordering by loading the original file into a yaml.Node tree and updating values in-place.
|
||||
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
persistCfg := sanitizeConfigForPersist(cfg)
|
||||
persistCfg := cfg
|
||||
// Load original YAML as a node tree to preserve comments and ordering.
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
@@ -964,16 +977,6 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// SaveConfigPreserveCommentsUpdateNestedScalar updates a nested scalar key path like ["a","b"]
|
||||
// while preserving comments and positions.
|
||||
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
|
||||
@@ -1070,8 +1073,13 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
||||
|
||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||
// key order and comments of existing keys in dst. New keys are only added if their
|
||||
// value is non-zero to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
// value is non-zero and not a known default to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node, path ...[]string) {
|
||||
var currentPath []string
|
||||
if len(path) > 0 {
|
||||
currentPath = path[0]
|
||||
}
|
||||
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
@@ -1085,16 +1093,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
sk := src.Content[i]
|
||||
sv := src.Content[i+1]
|
||||
idx := findMapKeyIndex(dst, sk.Value)
|
||||
childPath := appendPath(currentPath, sk.Value)
|
||||
if idx >= 0 {
|
||||
// Merge into existing value node (always update, even to zero values)
|
||||
dv := dst.Content[idx+1]
|
||||
mergeNodePreserve(dv, sv)
|
||||
mergeNodePreserve(dv, sv, childPath)
|
||||
} else {
|
||||
// New key: only add if value is non-zero to avoid polluting config with defaults
|
||||
if isZeroValueNode(sv) {
|
||||
// New key: only add if value is non-zero and not a known default
|
||||
candidate := deepCopyNode(sv)
|
||||
pruneKnownDefaultsInNewNode(childPath, candidate)
|
||||
if isKnownDefaultValue(childPath, candidate) {
|
||||
continue
|
||||
}
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), candidate)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1102,7 +1113,12 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
// mergeNodePreserve merges src into dst for scalars, mappings and sequences while
|
||||
// reusing destination nodes to keep comments and anchors. For sequences, it updates
|
||||
// in-place by index.
|
||||
func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
func mergeNodePreserve(dst, src *yaml.Node, path ...[]string) {
|
||||
var currentPath []string
|
||||
if len(path) > 0 {
|
||||
currentPath = path[0]
|
||||
}
|
||||
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
@@ -1111,7 +1127,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
if dst.Kind != yaml.MappingNode {
|
||||
copyNodeShallow(dst, src)
|
||||
}
|
||||
mergeMappingPreserve(dst, src)
|
||||
mergeMappingPreserve(dst, src, currentPath)
|
||||
case yaml.SequenceNode:
|
||||
// Preserve explicit null style if dst was null and src is empty sequence
|
||||
if dst.Kind == yaml.ScalarNode && dst.Tag == "!!null" && len(src.Content) == 0 {
|
||||
@@ -1134,7 +1150,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
dst.Content[i] = deepCopyNode(src.Content[i])
|
||||
continue
|
||||
}
|
||||
mergeNodePreserve(dst.Content[i], src.Content[i])
|
||||
mergeNodePreserve(dst.Content[i], src.Content[i], currentPath)
|
||||
if dst.Content[i] != nil && src.Content[i] != nil &&
|
||||
dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
|
||||
pruneMissingMapKeys(dst.Content[i], src.Content[i])
|
||||
@@ -1176,6 +1192,94 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
// appendPath appends a key to the path, returning a new slice to avoid modifying the original.
|
||||
func appendPath(path []string, key string) []string {
|
||||
if len(path) == 0 {
|
||||
return []string{key}
|
||||
}
|
||||
newPath := make([]string, len(path)+1)
|
||||
copy(newPath, path)
|
||||
newPath[len(path)] = key
|
||||
return newPath
|
||||
}
|
||||
|
||||
// isKnownDefaultValue returns true if the given node at the specified path
|
||||
// represents a known default value that should not be written to the config file.
|
||||
// This prevents non-zero defaults from polluting the config.
|
||||
func isKnownDefaultValue(path []string, node *yaml.Node) bool {
|
||||
// First check if it's a zero value
|
||||
if isZeroValueNode(node) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match known non-zero defaults by exact dotted path.
|
||||
if len(path) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fullPath := strings.Join(path, ".")
|
||||
|
||||
// Check string defaults
|
||||
if node.Kind == yaml.ScalarNode && node.Tag == "!!str" {
|
||||
switch fullPath {
|
||||
case "pprof.addr":
|
||||
return node.Value == DefaultPprofAddr
|
||||
case "remote-management.panel-github-repository":
|
||||
return node.Value == DefaultPanelGitHubRepository
|
||||
case "routing.strategy":
|
||||
return node.Value == "round-robin"
|
||||
}
|
||||
}
|
||||
|
||||
// Check integer defaults
|
||||
if node.Kind == yaml.ScalarNode && node.Tag == "!!int" {
|
||||
switch fullPath {
|
||||
case "error-logs-max-files":
|
||||
return node.Value == "10"
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// pruneKnownDefaultsInNewNode removes default-valued descendants from a new node
|
||||
// before it is appended into the destination YAML tree.
|
||||
func pruneKnownDefaultsInNewNode(path []string, node *yaml.Node) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch node.Kind {
|
||||
case yaml.MappingNode:
|
||||
filtered := make([]*yaml.Node, 0, len(node.Content))
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
keyNode := node.Content[i]
|
||||
valueNode := node.Content[i+1]
|
||||
if keyNode == nil || valueNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
childPath := appendPath(path, keyNode.Value)
|
||||
if isKnownDefaultValue(childPath, valueNode) {
|
||||
continue
|
||||
}
|
||||
|
||||
pruneKnownDefaultsInNewNode(childPath, valueNode)
|
||||
if (valueNode.Kind == yaml.MappingNode || valueNode.Kind == yaml.SequenceNode) &&
|
||||
len(valueNode.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered = append(filtered, keyNode, valueNode)
|
||||
}
|
||||
node.Content = filtered
|
||||
case yaml.SequenceNode:
|
||||
for _, child := range node.Content {
|
||||
pruneKnownDefaultsInNewNode(path, child)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||
// that should not be written as a new key to preserve config cleanliness.
|
||||
// For mappings and sequences, recursively checks if all children are zero values.
|
||||
|
||||
@@ -17,6 +17,7 @@ var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
@@ -30,6 +31,7 @@ func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
||||
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -131,6 +131,9 @@ func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-6-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
||||
|
||||
@@ -20,9 +20,6 @@ type SDKConfig struct {
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
|
||||
@@ -42,65 +39,3 @@ type StreamingConfig struct {
|
||||
// <= 0 disables bootstrap retries. Default is 0.
|
||||
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -131,7 +131,10 @@ func ResolveLogDirectory(cfg *config.Config) string {
|
||||
return logDir
|
||||
}
|
||||
if !isDirWritable(logDir) {
|
||||
authDir := strings.TrimSpace(cfg.AuthDir)
|
||||
authDir, err := util.ResolveAuthDir(cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to resolve auth-dir %q for log directory: %v", cfg.AuthDir, err)
|
||||
}
|
||||
if authDir != "" {
|
||||
logDir = filepath.Join(authDir, "logs")
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,6 +29,7 @@ const (
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
managementSyncMinInterval = 30 * time.Second
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
@@ -37,11 +39,10 @@ const ManagementFileName = managementAssetName
|
||||
var (
|
||||
lastUpdateCheckMu sync.Mutex
|
||||
lastUpdateCheckTime time.Time
|
||||
|
||||
currentConfigPtr atomic.Pointer[config.Config]
|
||||
disableControlPanel atomic.Bool
|
||||
schedulerOnce sync.Once
|
||||
schedulerConfigPath atomic.Value
|
||||
sfGroup singleflight.Group
|
||||
)
|
||||
|
||||
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
|
||||
@@ -50,16 +51,7 @@ func SetCurrentConfig(cfg *config.Config) {
|
||||
currentConfigPtr.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
prevDisabled := disableControlPanel.Load()
|
||||
currentConfigPtr.Store(cfg)
|
||||
disableControlPanel.Store(cfg.RemoteManagement.DisableControlPanel)
|
||||
|
||||
if prevDisabled && !cfg.RemoteManagement.DisableControlPanel {
|
||||
lastUpdateCheckMu.Lock()
|
||||
lastUpdateCheckTime = time.Time{}
|
||||
lastUpdateCheckMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
|
||||
@@ -92,7 +84,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
log.Debug("management asset auto-updater skipped: config not yet available")
|
||||
return
|
||||
}
|
||||
if disableControlPanel.Load() {
|
||||
if cfg.RemoteManagement.DisableControlPanel {
|
||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||
return
|
||||
}
|
||||
@@ -181,103 +173,106 @@ func FilePath(configFilePath string) string {
|
||||
}
|
||||
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
// It coalesces concurrent sync attempts and returns whether the asset exists after the sync attempt.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) bool {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if disableControlPanel.Load() {
|
||||
log.Debug("management asset sync skipped: control panel disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
staticDir = strings.TrimSpace(staticDir)
|
||||
if staticDir == "" {
|
||||
log.Debug("management asset sync skipped: empty static directory")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
timeSinceLastCheck := now.Sub(lastUpdateCheckTime)
|
||||
if timeSinceLastCheck < updateCheckInterval {
|
||||
_, _, _ = sfGroup.Do(localPath, func() (interface{}, error) {
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
timeSinceLastAttempt := now.Sub(lastUpdateCheckTime)
|
||||
if !lastUpdateCheckTime.IsZero() && timeSinceLastAttempt < managementSyncMinInterval {
|
||||
lastUpdateCheckMu.Unlock()
|
||||
log.Debugf(
|
||||
"management asset sync skipped by throttle: last attempt %v ago (interval %v)",
|
||||
timeSinceLastAttempt.Round(time.Second),
|
||||
managementSyncMinInterval,
|
||||
)
|
||||
return nil, nil
|
||||
}
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
log.Debugf("management asset update check skipped: last check was %v ago (interval: %v)", timeSinceLastCheck.Round(time.Second), updateCheckInterval)
|
||||
return
|
||||
}
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.WithError(err).Debug("failed to read local management asset hash")
|
||||
}
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
|
||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||
log.Debug("management asset is already up to date")
|
||||
return
|
||||
}
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.WithError(err).Debug("failed to read local management asset hash")
|
||||
}
|
||||
return
|
||||
localHash = ""
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
}
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to update management asset on disk")
|
||||
return
|
||||
}
|
||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||
log.Debug("management asset is already up to date")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to update management asset on disk")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
_, err := os.Stat(localPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
|
||||
@@ -15,7 +15,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4.5 Haiku",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
// Thinking: not supported for Haiku models
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
@@ -28,6 +28,18 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Object: "model",
|
||||
Created: 1770318000, // 2026-02-05
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Opus",
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 1000000,
|
||||
MaxCompletionTokens: 128000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Object: "model",
|
||||
@@ -716,6 +728,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex",
|
||||
Object: "model",
|
||||
Created: 1770307200,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex",
|
||||
Description: "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -803,6 +829,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -839,8 +866,50 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 128000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
}
|
||||
|
||||
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions
|
||||
func GetKimiModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "kimi-k2",
|
||||
Object: "model",
|
||||
Created: 1752192000, // 2025-07-11
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2",
|
||||
Description: "Kimi K2 - Moonshot AI's flagship coding model",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2-thinking",
|
||||
Object: "model",
|
||||
Created: 1762387200, // 2025-11-06
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2 Thinking",
|
||||
Description: "Kimi K2 Thinking - Extended reasoning model",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2.5",
|
||||
Object: "model",
|
||||
Created: 1769472000, // 2026-01-26
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2.5",
|
||||
Description: "Kimi K2.5 - Latest Moonshot AI coding model with improved capabilities",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -156,14 +156,14 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||
if len(wsResp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
|
||||
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
|
||||
}
|
||||
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
||||
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -199,7 +199,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -225,7 +225,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
var body bytes.Buffer
|
||||
if len(firstEvent.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(firstEvent.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
|
||||
body.Write(firstEvent.Payload)
|
||||
}
|
||||
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -244,7 +244,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
body.Write(event.Payload)
|
||||
}
|
||||
if event.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -274,12 +274,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
case wsrelay.MessageTypeStreamChunk:
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
filtered := FilterSSEUsageMetadata(event.Payload)
|
||||
if detail, ok := parseGeminiStreamUsage(filtered); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
@@ -293,9 +293,9 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
@@ -350,7 +350,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -364,7 +364,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
||||
if len(resp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
|
||||
appendAPIResponseChunk(ctx, e.cfg, resp.Body)
|
||||
}
|
||||
if resp.Status < 200 || resp.Status >= 300 {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
||||
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
if totalTokens <= 0 {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
||||
}
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
@@ -393,12 +393,13 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, translatedPayload{}, err
|
||||
|
||||
@@ -133,12 +133,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -230,7 +231,7 @@ attemptLoop:
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
@@ -274,12 +275,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -433,7 +435,7 @@ attemptLoop:
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
@@ -665,12 +667,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -800,12 +803,12 @@ attemptLoop:
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
}
|
||||
@@ -872,7 +875,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
// Prepare payload once (doesn't depend on baseURL)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -1280,51 +1283,40 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
||||
payload = []byte(strJSON)
|
||||
} else {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
// Clean tool schemas for Gemini to remove unsupported JSON Schema keywords
|
||||
// without adding empty-schema placeholders.
|
||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||
payload = []byte(strJSON)
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high")
|
||||
payloadStr := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
systemInstructionPartsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
if useAntigravitySchema {
|
||||
payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr)
|
||||
} else {
|
||||
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||
}
|
||||
|
||||
if useAntigravitySchema {
|
||||
systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
|
||||
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(partResult.Raw))
|
||||
payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
} else {
|
||||
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
|
||||
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), strings.NewReader(payloadStr))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
@@ -1346,11 +1338,15 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
var payloadLog []byte
|
||||
if e.cfg != nil && e.cfg.RequestLog {
|
||||
payloadLog = []byte(payloadStr)
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Body: payloadLog,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
|
||||
@@ -100,12 +100,13 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -216,7 +217,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
opts.OriginalRequest,
|
||||
bodyForTranslation,
|
||||
data,
|
||||
¶m,
|
||||
@@ -240,12 +241,13 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -381,7 +383,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
opts.OriginalRequest,
|
||||
bodyForTranslation,
|
||||
bytes.Clone(line),
|
||||
¶m,
|
||||
@@ -411,7 +413,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
if !strings.HasPrefix(baseModel, "claude-3-5-haiku") {
|
||||
|
||||
@@ -27,6 +27,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.98.0"
|
||||
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
|
||||
// CodexExecutor is a stateless executor for Codex (OpenAI Responses API entrypoint).
|
||||
@@ -88,12 +93,13 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -176,7 +182,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, line, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -197,12 +203,13 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai-response")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -265,7 +272,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -286,12 +293,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -378,7 +386,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -397,7 +405,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -634,10 +642,10 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "0.21.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "codex_cli_rs/0.50.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
@@ -119,12 +119,13 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -223,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -272,12 +273,13 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -399,14 +401,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -428,12 +430,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -485,7 +487,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
||||
// Gemini CLI endpoint when iterating fallback variants.
|
||||
for range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -116,12 +116,13 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -203,7 +204,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -222,12 +223,13 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -318,12 +320,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseGeminiStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(payload), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -344,7 +346,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -318,12 +318,13 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -417,7 +418,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -432,12 +433,13 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -521,7 +523,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -536,12 +538,13 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -632,12 +635,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -660,12 +663,13 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -756,12 +760,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -781,7 +785,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -865,7 +869,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -4,12 +4,16 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
@@ -87,12 +91,13 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
@@ -163,7 +168,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -189,12 +194,13 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
@@ -274,7 +280,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -296,7 +302,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
enc, err := tokenizerForModel(baseModel)
|
||||
if err != nil {
|
||||
@@ -451,6 +457,20 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
r.Header.Set("User-Agent", iflowUserAgent)
|
||||
|
||||
// Generate session-id
|
||||
sessionID := "session-" + generateUUID()
|
||||
r.Header.Set("session-id", sessionID)
|
||||
|
||||
// Generate timestamp and signature
|
||||
timestamp := time.Now().UnixMilli()
|
||||
r.Header.Set("x-iflow-timestamp", fmt.Sprintf("%d", timestamp))
|
||||
|
||||
signature := createIFlowSignature(iflowUserAgent, sessionID, timestamp, apiKey)
|
||||
if signature != "" {
|
||||
r.Header.Set("x-iflow-signature", signature)
|
||||
}
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
@@ -458,6 +478,23 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// createIFlowSignature generates HMAC-SHA256 signature for iFlow API requests.
|
||||
// The signature payload format is: userAgent:sessionId:timestamp
|
||||
func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return ""
|
||||
}
|
||||
payload := fmt.Sprintf("%s:%s:%d", userAgent, sessionID, timestamp)
|
||||
h := hmac.New(sha256.New, []byte(apiKey))
|
||||
h.Write([]byte(payload))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// generateUUID generates a random UUID v4 string.
|
||||
func generateUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
func iflowCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
618
internal/runtime/executor/kimi_executor.go
Normal file
618
internal/runtime/executor/kimi_executor.go
Normal file
@@ -0,0 +1,618 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// KimiExecutor is a stateless executor for Kimi API using OpenAI-compatible chat completions.
|
||||
type KimiExecutor struct {
|
||||
ClaudeExecutor
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiExecutor creates a new Kimi executor.
|
||||
func NewKimiExecutor(cfg *config.Config) *KimiExecutor { return &KimiExecutor{cfg: cfg} }
|
||||
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *KimiExecutor) Identifier() string { return "kimi" }
|
||||
|
||||
// PrepareRequest injects Kimi credentials into the outgoing HTTP request.
|
||||
func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token := kimiCreds(auth)
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Kimi credentials into the request and executes it.
|
||||
func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("kimi executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.Execute(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := bytes.Clone(originalPayloadSource)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
// Strip kimi- prefix for upstream API
|
||||
upstreamModel := stripKimiPrefix(baseModel)
|
||||
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||
}
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.ExecuteStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := bytes.Clone(originalPayloadSource)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
// Strip kimi- prefix for upstream API
|
||||
upstreamModel := stripKimiPrefix(baseModel)
|
||||
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||
}
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 1_048_576) // 1MB
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for Kimi requests.
|
||||
func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.CountTokens(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
out := body
|
||||
pending := make([]string, 0)
|
||||
patched := 0
|
||||
patchedReasoning := 0
|
||||
ambiguous := 0
|
||||
latestReasoning := ""
|
||||
hasLatestReasoning := false
|
||||
|
||||
removePending := func(id string) {
|
||||
for idx := range pending {
|
||||
if pending[idx] != id {
|
||||
continue
|
||||
}
|
||||
pending = append(pending[:idx], pending[idx+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
msgs := messages.Array()
|
||||
for msgIdx := range msgs {
|
||||
msg := msgs[msgIdx]
|
||||
role := strings.TrimSpace(msg.Get("role").String())
|
||||
switch role {
|
||||
case "assistant":
|
||||
reasoning := msg.Get("reasoning_content")
|
||||
if reasoning.Exists() {
|
||||
reasoningText := reasoning.String()
|
||||
if strings.TrimSpace(reasoningText) != "" {
|
||||
latestReasoning = reasoningText
|
||||
hasLatestReasoning = true
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if !toolCalls.Exists() || !toolCalls.IsArray() || len(toolCalls.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !reasoning.Exists() || strings.TrimSpace(reasoning.String()) == "" {
|
||||
reasoningText := fallbackAssistantReasoning(msg, hasLatestReasoning, latestReasoning)
|
||||
path := fmt.Sprintf("messages.%d.reasoning_content", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, reasoningText)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to set assistant reasoning_content: %w", err)
|
||||
}
|
||||
out = next
|
||||
patchedReasoning++
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls.Array() {
|
||||
id := strings.TrimSpace(tc.Get("id").String())
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
pending = append(pending, id)
|
||||
}
|
||||
case "tool":
|
||||
toolCallID := strings.TrimSpace(msg.Get("tool_call_id").String())
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(msg.Get("call_id").String())
|
||||
if toolCallID != "" {
|
||||
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to set tool_call_id from call_id: %w", err)
|
||||
}
|
||||
out = next
|
||||
patched++
|
||||
}
|
||||
}
|
||||
if toolCallID == "" {
|
||||
if len(pending) == 1 {
|
||||
toolCallID = pending[0]
|
||||
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to infer tool_call_id: %w", err)
|
||||
}
|
||||
out = next
|
||||
patched++
|
||||
} else if len(pending) > 1 {
|
||||
ambiguous++
|
||||
}
|
||||
}
|
||||
if toolCallID != "" {
|
||||
removePending(toolCallID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if patched > 0 || patchedReasoning > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"patched_tool_messages": patched,
|
||||
"patched_reasoning_messages": patchedReasoning,
|
||||
}).Debug("kimi executor: normalized tool message fields")
|
||||
}
|
||||
if ambiguous > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"ambiguous_tool_messages": ambiguous,
|
||||
"pending_tool_calls": len(pending),
|
||||
}).Warn("kimi executor: tool messages missing tool_call_id with ambiguous candidates")
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
|
||||
if hasLatest && strings.TrimSpace(latest) != "" {
|
||||
return latest
|
||||
}
|
||||
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
if text := strings.TrimSpace(content.String()); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
if content.IsArray() {
|
||||
parts := make([]string, 0, len(content.Array()))
|
||||
for _, item := range content.Array() {
|
||||
text := strings.TrimSpace(item.Get("text").String())
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
return "[reasoning unavailable]"
|
||||
}
|
||||
|
||||
// Refresh refreshes the Kimi token using the refresh token.
|
||||
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
log.Debugf("kimi executor: refresh called")
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("kimi executor: auth is nil")
|
||||
}
|
||||
// Expect refresh_token in metadata for OAuth-based accounts
|
||||
var refreshToken string
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
refreshToken = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
// Nothing to refresh
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
client := kimiauth.NewDeviceFlowClientWithDeviceID(e.cfg, resolveKimiDeviceID(auth))
|
||||
td, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["access_token"] = td.AccessToken
|
||||
if td.RefreshToken != "" {
|
||||
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||
}
|
||||
if td.ExpiresAt > 0 {
|
||||
exp := time.Unix(td.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
auth.Metadata["expired"] = exp
|
||||
}
|
||||
auth.Metadata["type"] = "kimi"
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
auth.Metadata["last_refresh"] = now
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// applyKimiHeaders sets required headers for Kimi API requests.
|
||||
// Headers match kimi-cli client for compatibility.
|
||||
func applyKimiHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
// Match kimi-cli headers exactly
|
||||
r.Header.Set("User-Agent", "KimiCLI/1.10.6")
|
||||
r.Header.Set("X-Msh-Platform", "kimi_cli")
|
||||
r.Header.Set("X-Msh-Version", "1.10.6")
|
||||
r.Header.Set("X-Msh-Device-Name", getKimiHostname())
|
||||
r.Header.Set("X-Msh-Device-Model", getKimiDeviceModel())
|
||||
r.Header.Set("X-Msh-Device-Id", getKimiDeviceID())
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
}
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
func resolveKimiDeviceIDFromAuth(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
deviceIDRaw, ok := auth.Metadata["device_id"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
deviceID, ok := deviceIDRaw.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(deviceID)
|
||||
}
|
||||
|
||||
func resolveKimiDeviceIDFromStorage(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
storage, ok := auth.Storage.(*kimiauth.KimiTokenStorage)
|
||||
if !ok || storage == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(storage.DeviceID)
|
||||
}
|
||||
|
||||
func resolveKimiDeviceID(auth *cliproxyauth.Auth) string {
|
||||
deviceID := resolveKimiDeviceIDFromAuth(auth)
|
||||
if deviceID != "" {
|
||||
return deviceID
|
||||
}
|
||||
return resolveKimiDeviceIDFromStorage(auth)
|
||||
}
|
||||
|
||||
func applyKimiHeadersWithAuth(r *http.Request, token string, stream bool, auth *cliproxyauth.Auth) {
|
||||
applyKimiHeaders(r, token, stream)
|
||||
|
||||
if deviceID := resolveKimiDeviceID(auth); deviceID != "" {
|
||||
r.Header.Set("X-Msh-Device-Id", deviceID)
|
||||
}
|
||||
}
|
||||
|
||||
// getKimiHostname returns the machine hostname.
|
||||
func getKimiHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// getKimiDeviceModel returns a device model string matching kimi-cli format.
|
||||
func getKimiDeviceModel() string {
|
||||
return fmt.Sprintf("%s %s", runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
// getKimiDeviceID returns a stable device ID, matching kimi-cli storage location.
|
||||
func getKimiDeviceID() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "cli-proxy-api-device"
|
||||
}
|
||||
// Check kimi-cli's device_id location first (platform-specific)
|
||||
var kimiShareDir string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
kimiShareDir = filepath.Join(homeDir, "Library", "Application Support", "kimi")
|
||||
case "windows":
|
||||
appData := os.Getenv("APPDATA")
|
||||
if appData == "" {
|
||||
appData = filepath.Join(homeDir, "AppData", "Roaming")
|
||||
}
|
||||
kimiShareDir = filepath.Join(appData, "kimi")
|
||||
default: // linux and other unix-like
|
||||
kimiShareDir = filepath.Join(homeDir, ".local", "share", "kimi")
|
||||
}
|
||||
deviceIDPath := filepath.Join(kimiShareDir, "device_id")
|
||||
if data, err := os.ReadFile(deviceIDPath); err == nil {
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
return "cli-proxy-api-device"
|
||||
}
|
||||
|
||||
// kimiCreds extracts the access token from auth.
|
||||
func kimiCreds(a *cliproxyauth.Auth) (token string) {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
// Check metadata first (OAuth flow stores tokens here)
|
||||
if a.Metadata != nil {
|
||||
if v, ok := a.Metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
// Fallback to attributes (API key style)
|
||||
if a.Attributes != nil {
|
||||
if v := a.Attributes["access_token"]; v != "" {
|
||||
return v
|
||||
}
|
||||
if v := a.Attributes["api_key"]; v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// stripKimiPrefix removes the "kimi-" prefix from model names for the upstream API.
|
||||
func stripKimiPrefix(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if strings.HasPrefix(strings.ToLower(model), "kimi-") {
|
||||
return model[5:]
|
||||
}
|
||||
return model
|
||||
}
|
||||
205
internal/runtime/executor/kimi_executor_test.go
Normal file
205
internal/runtime/executor/kimi_executor_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_UsesCallIDFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"list_directory:1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"tool","call_id":"list_directory:1","content":"[]"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "list_directory:1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "list_directory:1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InferSinglePendingID(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_123","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||
{"role":"tool","content":"file-content"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "call_123" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_AmbiguousMissingIDIsNotInferred(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[
|
||||
{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}},
|
||||
{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}
|
||||
]},
|
||||
{"role":"tool","content":"result-without-id"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out, "messages.1.tool_call_id").Exists() {
|
||||
t.Fatalf("messages.1.tool_call_id should be absent for ambiguous case, got %q", gjson.GetBytes(out, "messages.1.tool_call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesExistingToolCallID(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"tool","tool_call_id":"call_1","call_id":"different-id","content":"result"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "call_1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InheritsPreviousReasoningForAssistantToolCalls(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"plan","reasoning_content":"previous reasoning"},
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.reasoning_content").String()
|
||||
if got != "previous reasoning" {
|
||||
t.Fatalf("messages.1.reasoning_content = %q, want %q", got, "previous reasoning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InsertsFallbackReasoningWhenMissing(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
reasoning := gjson.GetBytes(out, "messages.0.reasoning_content")
|
||||
if !reasoning.Exists() {
|
||||
t.Fatalf("messages.0.reasoning_content should exist")
|
||||
}
|
||||
if reasoning.String() != "[reasoning unavailable]" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", reasoning.String(), "[reasoning unavailable]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_UsesContentAsReasoningFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":[{"type":"text","text":"first line"},{"type":"text","text":"second line"}],"tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "first line\nsecond line" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "first line\nsecond line")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_ReplacesEmptyReasoningContent(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"assistant summary","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":""}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "assistant summary" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "assistant summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesExistingAssistantReasoning(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"keep me"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "keep me" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "keep me")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"r1"},
|
||||
{"role":"tool","call_id":"call_1","content":"[]"},
|
||||
{"role":"assistant","tool_calls":[{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||
{"role":"tool","call_id":"call_2","content":"file"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "call_1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.3.tool_call_id").String(); got != "call_2" {
|
||||
t.Fatalf("messages.3.tool_call_id = %q, want %q", got, "call_2")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.2.reasoning_content").String(); got != "r1" {
|
||||
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
|
||||
}
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
|
||||
writeHeaders(builder, info.Headers)
|
||||
builder.WriteString("\nBody:\n")
|
||||
if len(info.Body) > 0 {
|
||||
builder.WriteString(string(bytes.Clone(info.Body)))
|
||||
builder.WriteString(string(info.Body))
|
||||
} else {
|
||||
builder.WriteString("<empty>")
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(bytes.Clone(chunk))
|
||||
data := bytes.TrimSpace(chunk)
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -88,12 +88,13 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
endpoint = "/responses/compact"
|
||||
}
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
if opts.Alt == "responses/compact" {
|
||||
@@ -170,7 +171,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, body, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -189,12 +190,13 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
@@ -283,7 +285,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
|
||||
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
|
||||
// Pass through translator; it yields one or more chunks for the target schema.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -304,7 +306,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
modelForCounting := baseModel
|
||||
|
||||
|
||||
@@ -81,12 +81,13 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -150,7 +151,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -171,12 +172,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -253,12 +255,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
}
|
||||
@@ -276,7 +278,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
|
||||
@@ -7,5 +7,6 @@ import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,9 @@ import (
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// gcInterval defines minimum time between garbage collection runs.
|
||||
const gcInterval = 5 * time.Minute
|
||||
|
||||
// GitTokenStore persists token records and auth metadata using git as the backing storage.
|
||||
type GitTokenStore struct {
|
||||
mu sync.Mutex
|
||||
@@ -31,6 +34,7 @@ type GitTokenStore struct {
|
||||
remote string
|
||||
username string
|
||||
password string
|
||||
lastGC time.Time
|
||||
}
|
||||
|
||||
// NewGitTokenStore creates a token store that saves credentials to disk through the
|
||||
@@ -613,6 +617,7 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
|
||||
} else if errRewrite := s.rewriteHeadAsSingleCommit(repo, headRef.Name(), commitHash, message, signature); errRewrite != nil {
|
||||
return errRewrite
|
||||
}
|
||||
s.maybeRunGC(repo)
|
||||
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
|
||||
if errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||
return nil
|
||||
@@ -652,6 +657,23 @@ func (s *GitTokenStore) rewriteHeadAsSingleCommit(repo *git.Repository, branch p
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GitTokenStore) maybeRunGC(repo *git.Repository) {
|
||||
now := time.Now()
|
||||
if now.Sub(s.lastGC) < gcInterval {
|
||||
return
|
||||
}
|
||||
s.lastGC = now
|
||||
|
||||
pruneOpts := git.PruneOptions{
|
||||
OnlyObjectsOlderThan: now,
|
||||
Handler: repo.DeleteObject,
|
||||
}
|
||||
if err := repo.Prune(pruneOpts); err != nil && !errors.Is(err, git.ErrLooseObjectsNotSupported) {
|
||||
return
|
||||
}
|
||||
_ = repo.RepackObjects(&git.RepackConfig{})
|
||||
}
|
||||
|
||||
// PersistConfig commits and pushes configuration changes to git.
|
||||
func (s *GitTokenStore) PersistConfig(_ context.Context) error {
|
||||
if err := s.EnsureRepository(); err != nil {
|
||||
|
||||
@@ -18,6 +18,7 @@ var providerAppliers = map[string]ProviderApplier{
|
||||
"codex": nil,
|
||||
"iflow": nil,
|
||||
"antigravity": nil,
|
||||
"kimi": nil,
|
||||
}
|
||||
|
||||
// GetProviderApplier returns the ProviderApplier for the given provider name.
|
||||
@@ -326,6 +327,9 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
|
||||
return config
|
||||
}
|
||||
return extractOpenAIConfig(body)
|
||||
case "kimi":
|
||||
// Kimi uses OpenAI-compatible reasoning_effort format
|
||||
return extractOpenAIConfig(body)
|
||||
default:
|
||||
return ThinkingConfig{}
|
||||
}
|
||||
@@ -388,7 +392,12 @@ func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
|
||||
}
|
||||
|
||||
// Check thinkingLevel first (Gemini 3 format takes precedence)
|
||||
if level := gjson.GetBytes(body, prefix+".thinkingLevel"); level.Exists() {
|
||||
level := gjson.GetBytes(body, prefix+".thinkingLevel")
|
||||
if !level.Exists() {
|
||||
// Google official Gemini Python SDK sends snake_case field names
|
||||
level = gjson.GetBytes(body, prefix+".thinking_level")
|
||||
}
|
||||
if level.Exists() {
|
||||
value := level.String()
|
||||
switch value {
|
||||
case "none":
|
||||
@@ -401,7 +410,12 @@ func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
|
||||
}
|
||||
|
||||
// Check thinkingBudget (Gemini 2.5 format)
|
||||
if budget := gjson.GetBytes(body, prefix+".thinkingBudget"); budget.Exists() {
|
||||
budget := gjson.GetBytes(body, prefix+".thinkingBudget")
|
||||
if !budget.Exists() {
|
||||
// Google official Gemini Python SDK sends snake_case field names
|
||||
budget = gjson.GetBytes(body, prefix+".thinking_budget")
|
||||
}
|
||||
if budget.Exists() {
|
||||
value := int(budget.Int())
|
||||
switch value {
|
||||
case 0:
|
||||
|
||||
@@ -94,8 +94,10 @@ func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig, m
|
||||
}
|
||||
|
||||
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -114,28 +116,30 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo, isClaude bool) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
|
||||
// Apply Claude-specific constraints
|
||||
// Apply Claude-specific constraints first to get the final budget value
|
||||
if isClaude && modelInfo != nil {
|
||||
budget, result = a.normalizeClaudeBudget(budget, result, modelInfo)
|
||||
// Check if budget was removed entirely
|
||||
@@ -144,6 +148,37 @@ func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
|
||||
@@ -118,8 +118,10 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
|
||||
// ValidateConfig sets config.Level to the lowest level when ModeNone + Budget > 0.
|
||||
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -138,29 +140,58 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
// ModeNone semantics:
|
||||
// - ModeNone + Budget=0: completely disable thinking
|
||||
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
|
||||
// When ZeroAllowed=false, ValidateConfig clamps Budget to Min while preserving ModeNone.
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
|
||||
@@ -79,8 +79,10 @@ func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) (
|
||||
}
|
||||
|
||||
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -99,25 +101,58 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
|
||||
126
internal/thinking/provider/kimi/apply.go
Normal file
126
internal/thinking/provider/kimi/apply.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
||||
//
|
||||
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
|
||||
// (low/medium/high). The provider strips any existing thinking config and applies
|
||||
// the unified ThinkingConfig in OpenAI format.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// Applier implements thinking.ProviderApplier for Kimi models.
|
||||
//
|
||||
// Kimi-specific behavior:
|
||||
// - Output format: reasoning_effort (string: low/medium/high)
|
||||
// - Uses OpenAI-compatible format
|
||||
// - Supports budget-to-level conversion
|
||||
type Applier struct{}
|
||||
|
||||
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||
|
||||
// NewApplier creates a new Kimi thinking applier.
|
||||
func NewApplier() *Applier {
|
||||
return &Applier{}
|
||||
}
|
||||
|
||||
func init() {
|
||||
thinking.RegisterProvider("kimi", NewApplier())
|
||||
}
|
||||
|
||||
// Apply applies thinking configuration to Kimi request body.
|
||||
//
|
||||
// Expected output format:
|
||||
//
|
||||
// {
|
||||
// "reasoning_effort": "high"
|
||||
// }
|
||||
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||
if thinking.IsUserDefinedModel(modelInfo) {
|
||||
return applyCompatibleKimi(body, config)
|
||||
}
|
||||
if modelInfo.Thinking == nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var effort string
|
||||
switch config.Mode {
|
||||
case thinking.ModeLevel:
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
// Kimi uses "none" to disable thinking
|
||||
effort = string(thinking.LevelNone)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level using threshold mapping
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
effort = level
|
||||
case thinking.ModeAuto:
|
||||
// Auto mode maps to "auto" effort
|
||||
effort = string(thinking.LevelAuto)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if effort == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
||||
func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var effort string
|
||||
switch config.Mode {
|
||||
case thinking.ModeLevel:
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
effort = string(thinking.LevelNone)
|
||||
if config.Level != "" {
|
||||
effort = string(config.Level)
|
||||
}
|
||||
case thinking.ModeAuto:
|
||||
effort = string(thinking.LevelAuto)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
effort = level
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -6,7 +6,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
@@ -37,7 +36,7 @@ import (
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
enableThoughtTranslate := true
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
@@ -115,7 +114,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
|
||||
if len(arrayClientSignatures) == 2 {
|
||||
if modelName == arrayClientSignatures[0] {
|
||||
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
|
||||
clientSignature = arrayClientSignatures[1]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -34,7 +33,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -28,7 +27,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToAntigravity(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
|
||||
return ConvertGeminiRequestToAntigravity(modelName, rawJSON, stream)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -30,7 +28,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertGeminiCLIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
// Extract the inner request object and promote it to the top level
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -46,7 +45,7 @@ var (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
@@ -116,7 +115,11 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
level := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
switch level {
|
||||
case "":
|
||||
@@ -132,23 +135,29 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -44,7 +43,7 @@ var (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -32,7 +31,7 @@ var (
|
||||
// - max_output_tokens -> max_tokens
|
||||
// - stream passthrough via parameter
|
||||
func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -35,7 +34,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in internal client format
|
||||
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
template := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
|
||||
@@ -113,10 +113,10 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
||||
stopReason := rootResult.Get("response.stop_reason").String()
|
||||
if stopReason != "" {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||
} else if p {
|
||||
if p {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
|
||||
} else if stopReason == "max_tokens" || stopReason == "stop" {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -30,7 +28,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiCLIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -37,7 +36,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base template
|
||||
out := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
@@ -243,19 +242,30 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
|
||||
// Convert Gemini thinkingConfig to Codex reasoning.effort.
|
||||
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
|
||||
effortSet := false
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -29,7 +27,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in OpenAI Responses API format
|
||||
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Start with empty JSON object
|
||||
out := `{"instructions":""}`
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -9,7 +8,13 @@ import (
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
inputResult := gjson.GetBytes(rawJSON, "input")
|
||||
if inputResult.Type == gjson.String {
|
||||
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input))
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "store", false)
|
||||
|
||||
@@ -35,7 +35,7 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
@@ -116,6 +116,19 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
part, _ = sjson.Set(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
|
||||
case "image":
|
||||
source := contentResult.Get("source")
|
||||
if source.Get("type").String() == "base64" {
|
||||
mimeType := source.Get("media_type").String()
|
||||
data := source.Get("data").String()
|
||||
if mimeType != "" && data != "" {
|
||||
part := `{"inlineData":{"mime_type":"","data":""}}`
|
||||
part, _ = sjson.Set(part, "inlineData.mime_type", mimeType)
|
||||
part, _ = sjson.Set(part, "inlineData.data", data)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -33,7 +32,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -28,7 +27,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -77,14 +78,20 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
finishReason := ""
|
||||
if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() {
|
||||
finishReason = stopReasonResult.String()
|
||||
}
|
||||
if finishReason == "" {
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
finishReason = finishReasonResult.String()
|
||||
}
|
||||
}
|
||||
finishReason = strings.ToLower(finishReason)
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
@@ -97,6 +104,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("gemini-cli openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
@@ -187,6 +202,12 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
} else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 {
|
||||
// Only pass through specific finish reasons
|
||||
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
}
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/gemini"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToGeminiCLI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
|
||||
return ConvertGeminiRequestToGeminiCLI(modelName, rawJSON, stream)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request in Gemini CLI format.
|
||||
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -19,7 +18,7 @@ import (
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the internal client.
|
||||
func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -19,7 +18,7 @@ import (
|
||||
//
|
||||
// It keeps the payload otherwise unchanged.
|
||||
func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Fast path: if no contents field, only attach safety settings
|
||||
contents := gjson.GetBytes(rawJSON, "contents")
|
||||
if !contents.Exists() {
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -28,7 +27,7 @@ const geminiFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"contents":[]}`)
|
||||
|
||||
|
||||
@@ -129,11 +129,16 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
candidateIndex := int(candidate.Get("index").Int())
|
||||
template, _ = sjson.Set(template, "choices.0.index", candidateIndex)
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
finishReason := ""
|
||||
if stopReasonResult := gjson.GetBytes(rawJSON, "stop_reason"); stopReasonResult.Exists() {
|
||||
finishReason = stopReasonResult.String()
|
||||
}
|
||||
if finishReason == "" {
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
finishReason = finishReasonResult.String()
|
||||
}
|
||||
}
|
||||
finishReason = strings.ToLower(finishReason)
|
||||
|
||||
partsResult := candidate.Get("content.parts")
|
||||
hasFunctionCall := false
|
||||
@@ -225,6 +230,12 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
} else if finishReason != "" {
|
||||
// Only pass through specific finish reasons
|
||||
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
}
|
||||
}
|
||||
|
||||
responseStrings = append(responseStrings, template)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -12,7 +11,7 @@ import (
|
||||
const geminiResponsesThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
// Note: modelName and stream parameters are part of the fixed method signature
|
||||
_ = modelName // Unused but required by interface
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
@@ -18,7 +17,7 @@ import (
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base OpenAI Chat Completions API template
|
||||
out := `{"model":"","messages":[]}`
|
||||
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -17,7 +15,7 @@ import (
|
||||
// It extracts the model name, generation config, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertGeminiCLIRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
if gjson.GetBytes(rawJSON, "systemInstruction").Exists() {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -21,7 +20,7 @@ import (
|
||||
// It extracts the model name, generation config, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base OpenAI Chat Completions API template
|
||||
out := `{"model":"","messages":[]}`
|
||||
|
||||
@@ -83,16 +82,27 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Map Gemini thinkingConfig to OpenAI reasoning_effort.
|
||||
// Always perform conversion to support allowCompat models that may not be in registry
|
||||
// Always perform conversion to support allowCompat models that may not be in registry.
|
||||
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -25,7 +24,7 @@ func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool)
|
||||
// If there's an error, return the original JSON or handle the error appropriately.
|
||||
// For now, we'll return the original, but in a real scenario, logging or a more robust error
|
||||
// handling mechanism would be needed.
|
||||
return bytes.Clone(inputRawJSON)
|
||||
return inputRawJSON
|
||||
}
|
||||
return updatedJSON
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -28,7 +27,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in OpenAI chat completions format
|
||||
func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base OpenAI chat completions template with default values
|
||||
out := `{"model":"","messages":[],"stream":false}`
|
||||
|
||||
@@ -68,7 +67,10 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
case "message", "":
|
||||
// Handle regular message conversion
|
||||
role := item.Get("role").String()
|
||||
message := `{"role":"","content":""}`
|
||||
if role == "developer" {
|
||||
role = "user"
|
||||
}
|
||||
message := `{"role":"","content":[]}`
|
||||
message, _ = sjson.Set(message, "role", role)
|
||||
|
||||
if content := item.Get("content"); content.Exists() && content.IsArray() {
|
||||
@@ -82,20 +84,16 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case "input_text":
|
||||
case "input_text", "output_text":
|
||||
text := contentItem.Get("text").String()
|
||||
if messageContent != "" {
|
||||
messageContent += "\n" + text
|
||||
} else {
|
||||
messageContent = text
|
||||
}
|
||||
case "output_text":
|
||||
text := contentItem.Get("text").String()
|
||||
if messageContent != "" {
|
||||
messageContent += "\n" + text
|
||||
} else {
|
||||
messageContent = text
|
||||
}
|
||||
contentPart := `{"type":"text","text":""}`
|
||||
contentPart, _ = sjson.Set(contentPart, "text", text)
|
||||
message, _ = sjson.SetRaw(message, "content.-1", contentPart)
|
||||
case "input_image":
|
||||
imageURL := contentItem.Get("image_url").String()
|
||||
contentPart := `{"type":"image_url","image_url":{"url":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
|
||||
message, _ = sjson.SetRaw(message, "content.-1", contentPart)
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -167,7 +165,8 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
// Only function tools need structural conversion because Chat Completions nests details under "function".
|
||||
toolType := tool.Get("type").String()
|
||||
if toolType != "" && toolType != "function" && tool.IsObject() {
|
||||
chatCompletionsTools = append(chatCompletionsTools, tool.Value())
|
||||
// Almost all providers lack built-in tools, so we just ignore them.
|
||||
// chatCompletionsTools = append(chatCompletionsTools, tool.Value())
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ func TestIsClaudeThinkingModel(t *testing.T) {
|
||||
// Claude thinking models - should return true
|
||||
{"claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
|
||||
{"claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
|
||||
{"claude-opus-4-6-thinking", "claude-opus-4-6-thinking", true},
|
||||
{"Claude-Sonnet-Thinking uppercase", "Claude-Sonnet-4-5-Thinking", true},
|
||||
{"claude thinking mixed case", "Claude-THINKING-Model", true},
|
||||
|
||||
|
||||
@@ -61,14 +61,20 @@ func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
|
||||
|
||||
// removeKeywords removes all occurrences of specified keywords from the JSON schema.
|
||||
func removeKeywords(jsonStr string, keywords []string) string {
|
||||
deletePaths := make([]string, 0)
|
||||
pathsByField := findPathsByFields(jsonStr, keywords)
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
for _, p := range pathsByField[key] {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
deletePaths = append(deletePaths, p)
|
||||
}
|
||||
}
|
||||
sortByDepth(deletePaths)
|
||||
for _, p := range deletePaths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
@@ -235,8 +241,9 @@ var unsupportedConstraints = []string{
|
||||
}
|
||||
|
||||
func moveConstraintsToDescription(jsonStr string) string {
|
||||
pathsByField := findPathsByFields(jsonStr, unsupportedConstraints)
|
||||
for _, key := range unsupportedConstraints {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
for _, p := range pathsByField[key] {
|
||||
val := gjson.Get(jsonStr, p)
|
||||
if !val.Exists() || val.IsObject() || val.IsArray() {
|
||||
continue
|
||||
@@ -424,14 +431,21 @@ func removeUnsupportedKeywords(jsonStr string) string {
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||
"propertyNames", // Gemini doesn't support property name validation
|
||||
)
|
||||
|
||||
deletePaths := make([]string, 0)
|
||||
pathsByField := findPathsByFields(jsonStr, keywords)
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
for _, p := range pathsByField[key] {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
deletePaths = append(deletePaths, p)
|
||||
}
|
||||
}
|
||||
sortByDepth(deletePaths)
|
||||
for _, p := range deletePaths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
// Remove x-* extension fields (e.g., x-google-enum-descriptions) that are not supported by Gemini API
|
||||
jsonStr = removeExtensionFields(jsonStr)
|
||||
return jsonStr
|
||||
@@ -581,6 +595,42 @@ func findPaths(jsonStr, field string) []string {
|
||||
return paths
|
||||
}
|
||||
|
||||
func findPathsByFields(jsonStr string, fields []string) map[string][]string {
|
||||
set := make(map[string]struct{}, len(fields))
|
||||
for _, field := range fields {
|
||||
set[field] = struct{}{}
|
||||
}
|
||||
paths := make(map[string][]string, len(set))
|
||||
walkForFields(gjson.Parse(jsonStr), "", set, paths)
|
||||
return paths
|
||||
}
|
||||
|
||||
func walkForFields(value gjson.Result, path string, fields map[string]struct{}, paths map[string][]string) {
|
||||
switch value.Type {
|
||||
case gjson.JSON:
|
||||
value.ForEach(func(key, val gjson.Result) bool {
|
||||
keyStr := key.String()
|
||||
safeKey := escapeGJSONPathKey(keyStr)
|
||||
|
||||
var childPath string
|
||||
if path == "" {
|
||||
childPath = safeKey
|
||||
} else {
|
||||
childPath = path + "." + safeKey
|
||||
}
|
||||
|
||||
if _, ok := fields[keyStr]; ok {
|
||||
paths[keyStr] = append(paths[keyStr], childPath)
|
||||
}
|
||||
|
||||
walkForFields(val, childPath, fields, paths)
|
||||
return true
|
||||
})
|
||||
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
|
||||
// Terminal types - no further traversal needed
|
||||
}
|
||||
}
|
||||
|
||||
func sortByDepth(paths []string) {
|
||||
sort.Slice(paths, func(i, j int) bool { return len(paths[i]) > len(paths[j]) })
|
||||
}
|
||||
@@ -667,6 +717,9 @@ func orDefault(val, def string) string {
|
||||
}
|
||||
|
||||
func escapeGJSONPathKey(key string) string {
|
||||
if strings.IndexAny(key, ".*?") == -1 {
|
||||
return key
|
||||
}
|
||||
return gjsonPathKeyReplacer.Replace(key)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ package util
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -33,15 +32,15 @@ func Walk(value gjson.Result, path, field string, paths *[]string) {
|
||||
// . -> \.
|
||||
// * -> \*
|
||||
// ? -> \?
|
||||
var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
safeKey := keyReplacer.Replace(key.String())
|
||||
keyStr := key.String()
|
||||
safeKey := escapeGJSONPathKey(keyStr)
|
||||
|
||||
if path == "" {
|
||||
childPath = safeKey
|
||||
} else {
|
||||
childPath = path + "." + safeKey
|
||||
}
|
||||
if key.String() == field {
|
||||
if keyStr == field {
|
||||
*paths = append(*paths, childPath)
|
||||
}
|
||||
Walk(val, childPath, field, paths)
|
||||
@@ -87,15 +86,6 @@ func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
|
||||
return finalJson, nil
|
||||
}
|
||||
|
||||
func DeleteKey(jsonStr, keyName string) string {
|
||||
paths := make([]string, 0)
|
||||
Walk(gjson.Parse(jsonStr), "", keyName, &paths)
|
||||
for _, p := range paths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// FixJSON converts non-standard JSON that uses single quotes for strings into
|
||||
// RFC 8259-compliant JSON by converting those single-quoted strings to
|
||||
// double-quoted strings with proper escaping.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -72,6 +74,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
} else if resolvedAuthDir != "" {
|
||||
@@ -84,6 +87,11 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
sum := sha256.Sum256(data)
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
// Parse and cache auth content for future diff comparisons
|
||||
var auth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -127,6 +135,13 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
|
||||
// Parse new auth content for diff comparison
|
||||
var newAuth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &newAuth); errParse != nil {
|
||||
log.Errorf("failed to parse auth file %s: %v", filepath.Base(path), errParse)
|
||||
return
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
@@ -141,7 +156,26 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get old auth for diff comparison
|
||||
var oldAuth *coreauth.Auth
|
||||
if w.lastAuthContents != nil {
|
||||
oldAuth = w.lastAuthContents[normalized]
|
||||
}
|
||||
|
||||
// Compute and log field changes
|
||||
if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
|
||||
log.Debugf("auth field changes for %s:", filepath.Base(path))
|
||||
for _, c := range changes {
|
||||
log.Debugf(" %s", c)
|
||||
}
|
||||
}
|
||||
|
||||
// Update caches
|
||||
w.lastAuthHashes[normalized] = curHash
|
||||
if w.lastAuthContents == nil {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
@@ -160,6 +194,7 @@ func (w *Watcher) removeClient(path string) {
|
||||
|
||||
cfg := w.config
|
||||
delete(w.lastAuthHashes, normalized)
|
||||
delete(w.lastAuthContents, normalized)
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
|
||||
|
||||
44
internal/watcher/diff/auth_diff.go
Normal file
44
internal/watcher/diff/auth_diff.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// auth_diff.go computes human-readable diffs for auth file field changes.
|
||||
package diff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// BuildAuthChangeDetails computes a redacted, human-readable list of auth field changes.
|
||||
// Only prefix, proxy_url, and disabled fields are tracked; sensitive data is never printed.
|
||||
func BuildAuthChangeDetails(oldAuth, newAuth *coreauth.Auth) []string {
|
||||
changes := make([]string, 0, 3)
|
||||
|
||||
// Handle nil cases by using empty Auth as default
|
||||
if oldAuth == nil {
|
||||
oldAuth = &coreauth.Auth{}
|
||||
}
|
||||
if newAuth == nil {
|
||||
return changes
|
||||
}
|
||||
|
||||
// Compare prefix
|
||||
oldPrefix := strings.TrimSpace(oldAuth.Prefix)
|
||||
newPrefix := strings.TrimSpace(newAuth.Prefix)
|
||||
if oldPrefix != newPrefix {
|
||||
changes = append(changes, fmt.Sprintf("prefix: %s -> %s", oldPrefix, newPrefix))
|
||||
}
|
||||
|
||||
// Compare proxy_url (redacted)
|
||||
oldProxy := strings.TrimSpace(oldAuth.ProxyURL)
|
||||
newProxy := strings.TrimSpace(newAuth.ProxyURL)
|
||||
if oldProxy != newProxy {
|
||||
changes = append(changes, fmt.Sprintf("proxy_url: %s -> %s", formatProxyURL(oldProxy), formatProxyURL(newProxy)))
|
||||
}
|
||||
|
||||
// Compare disabled
|
||||
if oldAuth.Disabled != newAuth.Disabled {
|
||||
changes = append(changes, fmt.Sprintf("disabled: %t -> %t", oldAuth.Disabled, newAuth.Disabled))
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
@@ -27,6 +27,12 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.Debug != newCfg.Debug {
|
||||
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
|
||||
}
|
||||
if oldCfg.Pprof.Enable != newCfg.Pprof.Enable {
|
||||
changes = append(changes, fmt.Sprintf("pprof.enable: %t -> %t", oldCfg.Pprof.Enable, newCfg.Pprof.Enable))
|
||||
}
|
||||
if strings.TrimSpace(oldCfg.Pprof.Addr) != strings.TrimSpace(newCfg.Pprof.Addr) {
|
||||
changes = append(changes, fmt.Sprintf("pprof.addr: %s -> %s", strings.TrimSpace(oldCfg.Pprof.Addr), strings.TrimSpace(newCfg.Pprof.Addr)))
|
||||
}
|
||||
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
|
||||
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ type Watcher struct {
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastAuthContents map[string]*coreauth.Auth
|
||||
lastRemoveTimes map[string]time.Time
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
|
||||
@@ -1,12 +1,90 @@
|
||||
package access
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrNoCredentials indicates no recognizable credentials were supplied.
|
||||
ErrNoCredentials = errors.New("access: no credentials provided")
|
||||
// ErrInvalidCredential signals that supplied credentials were rejected by a provider.
|
||||
ErrInvalidCredential = errors.New("access: invalid credential")
|
||||
// ErrNotHandled tells the manager to continue trying other providers.
|
||||
ErrNotHandled = errors.New("access: not handled")
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AuthErrorCode classifies authentication failures.
|
||||
type AuthErrorCode string
|
||||
|
||||
const (
|
||||
AuthErrorCodeNoCredentials AuthErrorCode = "no_credentials"
|
||||
AuthErrorCodeInvalidCredential AuthErrorCode = "invalid_credential"
|
||||
AuthErrorCodeNotHandled AuthErrorCode = "not_handled"
|
||||
AuthErrorCodeInternal AuthErrorCode = "internal_error"
|
||||
)
|
||||
|
||||
// AuthError carries authentication failure details and HTTP status.
|
||||
type AuthError struct {
|
||||
Code AuthErrorCode
|
||||
Message string
|
||||
StatusCode int
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
message := strings.TrimSpace(e.Message)
|
||||
if message == "" {
|
||||
message = "authentication error"
|
||||
}
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %v", message, e.Cause)
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func (e *AuthError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// HTTPStatusCode returns a safe fallback for missing status codes.
|
||||
func (e *AuthError) HTTPStatusCode() int {
|
||||
if e == nil || e.StatusCode <= 0 {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
return e.StatusCode
|
||||
}
|
||||
|
||||
func newAuthError(code AuthErrorCode, message string, statusCode int, cause error) *AuthError {
|
||||
return &AuthError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
StatusCode: statusCode,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
func NewNoCredentialsError() *AuthError {
|
||||
return newAuthError(AuthErrorCodeNoCredentials, "Missing API key", http.StatusUnauthorized, nil)
|
||||
}
|
||||
|
||||
func NewInvalidCredentialError() *AuthError {
|
||||
return newAuthError(AuthErrorCodeInvalidCredential, "Invalid API key", http.StatusUnauthorized, nil)
|
||||
}
|
||||
|
||||
func NewNotHandledError() *AuthError {
|
||||
return newAuthError(AuthErrorCodeNotHandled, "authentication provider did not handle request", 0, nil)
|
||||
}
|
||||
|
||||
func NewInternalAuthError(message string, cause error) *AuthError {
|
||||
normalizedMessage := strings.TrimSpace(message)
|
||||
if normalizedMessage == "" {
|
||||
normalizedMessage = "Authentication service error"
|
||||
}
|
||||
return newAuthError(AuthErrorCodeInternal, normalizedMessage, http.StatusInternalServerError, cause)
|
||||
}
|
||||
|
||||
func IsAuthErrorCode(authErr *AuthError, code AuthErrorCode) bool {
|
||||
if authErr == nil {
|
||||
return false
|
||||
}
|
||||
return authErr.Code == code
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package access
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
@@ -43,7 +42,7 @@ func (m *Manager) Providers() []Provider {
|
||||
}
|
||||
|
||||
// Authenticate evaluates providers until one succeeds.
|
||||
func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, error) {
|
||||
func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, *AuthError) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -61,29 +60,29 @@ func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, e
|
||||
if provider == nil {
|
||||
continue
|
||||
}
|
||||
res, err := provider.Authenticate(ctx, r)
|
||||
if err == nil {
|
||||
res, authErr := provider.Authenticate(ctx, r)
|
||||
if authErr == nil {
|
||||
return res, nil
|
||||
}
|
||||
if errors.Is(err, ErrNotHandled) {
|
||||
if IsAuthErrorCode(authErr, AuthErrorCodeNotHandled) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, ErrNoCredentials) {
|
||||
if IsAuthErrorCode(authErr, AuthErrorCodeNoCredentials) {
|
||||
missing = true
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, ErrInvalidCredential) {
|
||||
if IsAuthErrorCode(authErr, AuthErrorCodeInvalidCredential) {
|
||||
invalid = true
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
return nil, authErr
|
||||
}
|
||||
|
||||
if invalid {
|
||||
return nil, ErrInvalidCredential
|
||||
return nil, NewInvalidCredentialError()
|
||||
}
|
||||
if missing {
|
||||
return nil, ErrNoCredentials
|
||||
return nil, NewNoCredentialsError()
|
||||
}
|
||||
return nil, ErrNoCredentials
|
||||
return nil, NewNoCredentialsError()
|
||||
}
|
||||
|
||||
@@ -2,17 +2,15 @@ package access
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
// Provider validates credentials for incoming requests.
|
||||
type Provider interface {
|
||||
Identifier() string
|
||||
Authenticate(ctx context.Context, r *http.Request) (*Result, error)
|
||||
Authenticate(ctx context.Context, r *http.Request) (*Result, *AuthError)
|
||||
}
|
||||
|
||||
// Result conveys authentication outcome.
|
||||
@@ -22,66 +20,64 @@ type Result struct {
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
// ProviderFactory builds a provider from configuration data.
|
||||
type ProviderFactory func(cfg *config.AccessProvider, root *config.SDKConfig) (Provider, error)
|
||||
|
||||
var (
|
||||
registryMu sync.RWMutex
|
||||
registry = make(map[string]ProviderFactory)
|
||||
registry = make(map[string]Provider)
|
||||
order []string
|
||||
)
|
||||
|
||||
// RegisterProvider registers a provider factory for a given type identifier.
|
||||
func RegisterProvider(typ string, factory ProviderFactory) {
|
||||
if typ == "" || factory == nil {
|
||||
// RegisterProvider registers a pre-built provider instance for a given type identifier.
|
||||
func RegisterProvider(typ string, provider Provider) {
|
||||
normalizedType := strings.TrimSpace(typ)
|
||||
if normalizedType == "" || provider == nil {
|
||||
return
|
||||
}
|
||||
|
||||
registryMu.Lock()
|
||||
registry[typ] = factory
|
||||
if _, exists := registry[normalizedType]; !exists {
|
||||
order = append(order, normalizedType)
|
||||
}
|
||||
registry[normalizedType] = provider
|
||||
registryMu.Unlock()
|
||||
}
|
||||
|
||||
func BuildProvider(cfg *config.AccessProvider, root *config.SDKConfig) (Provider, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("access: nil provider config")
|
||||
// UnregisterProvider removes a provider by type identifier.
|
||||
func UnregisterProvider(typ string) {
|
||||
normalizedType := strings.TrimSpace(typ)
|
||||
if normalizedType == "" {
|
||||
return
|
||||
}
|
||||
registryMu.RLock()
|
||||
factory, ok := registry[cfg.Type]
|
||||
registryMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("access: provider type %q is not registered", cfg.Type)
|
||||
registryMu.Lock()
|
||||
if _, exists := registry[normalizedType]; !exists {
|
||||
registryMu.Unlock()
|
||||
return
|
||||
}
|
||||
provider, err := factory(cfg, root)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("access: failed to build provider %q: %w", cfg.Name, err)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// BuildProviders constructs providers declared in configuration.
|
||||
func BuildProviders(root *config.SDKConfig) ([]Provider, error) {
|
||||
if root == nil {
|
||||
return nil, nil
|
||||
}
|
||||
providers := make([]Provider, 0, len(root.Access.Providers))
|
||||
for i := range root.Access.Providers {
|
||||
providerCfg := &root.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
delete(registry, normalizedType)
|
||||
for index := range order {
|
||||
if order[index] != normalizedType {
|
||||
continue
|
||||
}
|
||||
provider, err := BuildProvider(providerCfg, root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
order = append(order[:index], order[index+1:]...)
|
||||
break
|
||||
}
|
||||
registryMu.Unlock()
|
||||
}
|
||||
|
||||
// RegisteredProviders returns the global provider instances in registration order.
|
||||
func RegisteredProviders() []Provider {
|
||||
registryMu.RLock()
|
||||
if len(order) == 0 {
|
||||
registryMu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
providers := make([]Provider, 0, len(order))
|
||||
for _, providerType := range order {
|
||||
provider, exists := registry[providerType]
|
||||
if !exists || provider == nil {
|
||||
continue
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
if inline := config.MakeInlineAPIKeyProvider(root.APIKeys); inline != nil {
|
||||
provider, err := BuildProvider(inline, root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
}
|
||||
return providers, nil
|
||||
registryMu.RUnlock()
|
||||
return providers
|
||||
}
|
||||
|
||||
47
sdk/access/types.go
Normal file
47
sdk/access/types.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package access
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@@ -155,20 +155,6 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||
}
|
||||
|
||||
func mergeMetadata(base, overlay map[string]any) map[string]any {
|
||||
if len(base) == 0 && len(overlay) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(base)+len(overlay))
|
||||
for k, v := range base {
|
||||
out[k] = v
|
||||
}
|
||||
for k, v := range overlay {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// BaseAPIHandler contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service and manages
|
||||
// load balancing, client selection, and configuration.
|
||||
@@ -391,14 +377,18 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
Payload: payload,
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: false,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
OriginalRequest: rawJSON,
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = reqMeta
|
||||
@@ -418,7 +408,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
}
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
}
|
||||
return cloneBytes(resp.Payload), nil
|
||||
return resp.Payload, nil
|
||||
}
|
||||
|
||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
@@ -430,14 +420,18 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
Payload: payload,
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: false,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
OriginalRequest: rawJSON,
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = reqMeta
|
||||
@@ -457,7 +451,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
}
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
}
|
||||
return cloneBytes(resp.Payload), nil
|
||||
return resp.Payload, nil
|
||||
}
|
||||
|
||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||
@@ -472,14 +466,18 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
Payload: payload,
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: true,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
OriginalRequest: rawJSON,
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = reqMeta
|
||||
@@ -668,17 +666,6 @@ func cloneBytes(src []byte) []byte {
|
||||
return dst
|
||||
}
|
||||
|
||||
func cloneMetadata(src map[string]any) map[string]any {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make(map[string]any, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
||||
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -709,7 +696,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
||||
var previous []byte
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
previous = bytes.Clone(existingBytes)
|
||||
previous = existingBytes
|
||||
}
|
||||
}
|
||||
appendAPIResponse(c, body)
|
||||
|
||||
@@ -18,6 +18,7 @@ type ManagementTokenRequester interface {
|
||||
RequestCodexToken(*gin.Context)
|
||||
RequestAntigravityToken(*gin.Context)
|
||||
RequestQwenToken(*gin.Context)
|
||||
RequestKimiToken(*gin.Context)
|
||||
RequestIFlowToken(*gin.Context)
|
||||
RequestIFlowCookieToken(*gin.Context)
|
||||
GetAuthStatus(c *gin.Context)
|
||||
@@ -55,6 +56,10 @@ func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
|
||||
m.handler.RequestQwenToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestKimiToken(c *gin.Context) {
|
||||
m.handler.RequestKimiToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
|
||||
m.handler.RequestIFlowToken(c)
|
||||
}
|
||||
|
||||
123
sdk/auth/kimi.go
Normal file
123
sdk/auth/kimi.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// kimiRefreshLead is the duration before token expiry when refresh should occur.
|
||||
var kimiRefreshLead = 5 * time.Minute
|
||||
|
||||
// KimiAuthenticator implements the OAuth device flow login for Kimi (Moonshot AI).
|
||||
type KimiAuthenticator struct{}
|
||||
|
||||
// NewKimiAuthenticator constructs a new Kimi authenticator.
|
||||
func NewKimiAuthenticator() Authenticator {
|
||||
return &KimiAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for kimi.
|
||||
func (KimiAuthenticator) Provider() string {
|
||||
return "kimi"
|
||||
}
|
||||
|
||||
// RefreshLead returns the duration before token expiry when refresh should occur.
|
||||
// Kimi tokens expire and need to be refreshed before expiry.
|
||||
func (KimiAuthenticator) RefreshLead() *time.Duration {
|
||||
return &kimiRefreshLead
|
||||
}
|
||||
|
||||
// Login initiates the Kimi device flow authentication.
|
||||
func (a KimiAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
authSvc := kimi.NewKimiAuth(cfg)
|
||||
|
||||
// Start the device flow
|
||||
fmt.Println("Starting Kimi authentication...")
|
||||
deviceCode, err := authSvc.StartDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to start device flow: %w", err)
|
||||
}
|
||||
|
||||
// Display the verification URL
|
||||
verificationURL := deviceCode.VerificationURIComplete
|
||||
if verificationURL == "" {
|
||||
verificationURL = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
fmt.Printf("\nTo authenticate, please visit:\n%s\n\n", verificationURL)
|
||||
if deviceCode.UserCode != "" {
|
||||
fmt.Printf("User code: %s\n\n", deviceCode.UserCode)
|
||||
}
|
||||
|
||||
// Try to open the browser automatically
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(verificationURL); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
} else {
|
||||
fmt.Println("Browser opened automatically.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for authorization...")
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
|
||||
}
|
||||
|
||||
// Wait for user authorization
|
||||
authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: %w", err)
|
||||
}
|
||||
|
||||
// Create the token storage
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
// Build metadata with token information
|
||||
metadata := map[string]any{
|
||||
"type": "kimi",
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"refresh_token": authBundle.TokenData.RefreshToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if authBundle.TokenData.ExpiresAt > 0 {
|
||||
exp := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
metadata["expired"] = exp
|
||||
}
|
||||
if strings.TrimSpace(authBundle.DeviceID) != "" {
|
||||
metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
|
||||
}
|
||||
|
||||
// Generate a unique filename
|
||||
fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
|
||||
|
||||
fmt.Println("\nKimi authentication successful!")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: "Kimi User",
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
@@ -14,6 +14,7 @@ func init() {
|
||||
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
@@ -607,6 +607,9 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
if isRequestInvalidError(errExec) {
|
||||
return cliproxyexecutor.Response{}, errExec
|
||||
}
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
@@ -660,6 +663,9 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
if isRequestInvalidError(errExec) {
|
||||
return cliproxyexecutor.Response{}, errExec
|
||||
}
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
@@ -711,6 +717,9 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
if isRequestInvalidError(errStream) {
|
||||
return nil, errStream
|
||||
}
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
@@ -1110,6 +1119,9 @@ func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []stri
|
||||
if status := statusCodeFromError(err); status == http.StatusOK {
|
||||
return 0, false
|
||||
}
|
||||
if isRequestInvalidError(err) {
|
||||
return 0, false
|
||||
}
|
||||
wait, found := m.closestCooldownWait(providers, model, attempt)
|
||||
if !found || wait > maxWait {
|
||||
return 0, false
|
||||
@@ -1299,7 +1311,7 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
|
||||
stateUnavailable = true
|
||||
} else if state.Unavailable {
|
||||
if state.NextRetryAfter.IsZero() {
|
||||
stateUnavailable = true
|
||||
stateUnavailable = false
|
||||
} else if state.NextRetryAfter.After(now) {
|
||||
stateUnavailable = true
|
||||
if earliestRetry.IsZero() || state.NextRetryAfter.Before(earliestRetry) {
|
||||
@@ -1430,6 +1442,21 @@ func statusCodeFromResult(err *Error) int {
|
||||
return err.StatusCode()
|
||||
}
|
||||
|
||||
// isRequestInvalidError returns true if the error represents a client request
|
||||
// error that should not be retried. Specifically, it checks for 400 Bad Request
|
||||
// with "invalid_request_error" in the message, indicating the request itself is
|
||||
// malformed and switching to a different auth will not help.
|
||||
func isRequestInvalidError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
status := statusCodeFromError(err)
|
||||
if status != http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(err.Error(), "invalid_request_error")
|
||||
}
|
||||
|
||||
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
|
||||
if auth == nil {
|
||||
return
|
||||
|
||||
61
sdk/cliproxy/auth/conductor_availability_test.go
Normal file
61
sdk/cliproxy/auth/conductor_availability_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUpdateAggregatedAvailability_UnavailableWithoutNextRetryDoesNotBlockAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
model := "test-model"
|
||||
auth := &Auth{
|
||||
ID: "a",
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Status: StatusError,
|
||||
Unavailable: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updateAggregatedAvailability(auth, now)
|
||||
|
||||
if auth.Unavailable {
|
||||
t.Fatalf("auth.Unavailable = true, want false")
|
||||
}
|
||||
if !auth.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("auth.NextRetryAfter = %v, want zero", auth.NextRetryAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAggregatedAvailability_FutureNextRetryBlocksAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
model := "test-model"
|
||||
next := now.Add(5 * time.Minute)
|
||||
auth := &Auth{
|
||||
ID: "a",
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Status: StatusError,
|
||||
Unavailable: true,
|
||||
NextRetryAfter: next,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updateAggregatedAvailability(auth, now)
|
||||
|
||||
if !auth.Unavailable {
|
||||
t.Fatalf("auth.Unavailable = false, want true")
|
||||
}
|
||||
if auth.NextRetryAfter.IsZero() {
|
||||
t.Fatalf("auth.NextRetryAfter = zero, want %v", next)
|
||||
}
|
||||
if auth.NextRetryAfter.Sub(next) > time.Second || next.Sub(auth.NextRetryAfter) > time.Second {
|
||||
t.Fatalf("auth.NextRetryAfter = %v, want %v", auth.NextRetryAfter, next)
|
||||
}
|
||||
}
|
||||
@@ -221,7 +221,7 @@ func modelAliasChannel(auth *Auth) string {
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model alias (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
|
||||
func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
@@ -245,7 +245,7 @@ func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kimi":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
|
||||
@@ -70,6 +70,15 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
|
||||
input: "gemini-2.5-pro(none)",
|
||||
want: "gemini-2.5-pro-exp-03-25(none)",
|
||||
},
|
||||
{
|
||||
name: "kimi suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"kimi": {{Name: "kimi-k2.5", Alias: "k2.5"}},
|
||||
},
|
||||
channel: "kimi",
|
||||
input: "k2.5(high)",
|
||||
want: "kimi-k2.5(high)",
|
||||
},
|
||||
{
|
||||
name: "case insensitive alias lookup with suffix",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
@@ -152,11 +161,21 @@ func createAuthForChannel(channel string) *Auth {
|
||||
return &Auth{Provider: "qwen"}
|
||||
case "iflow":
|
||||
return &Auth{Provider: "iflow"}
|
||||
case "kimi":
|
||||
return &Auth{Provider: "kimi"}
|
||||
default:
|
||||
return &Auth{Provider: channel}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthModelAliasChannel_Kimi(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := OAuthModelAliasChannel("kimi", "oauth"); got != "kimi" {
|
||||
t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "kimi")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
type RoundRobinSelector struct {
|
||||
mu sync.Mutex
|
||||
cursors map[string]int
|
||||
maxKeys int
|
||||
}
|
||||
|
||||
// FillFirstSelector selects the first available credential (deterministic ordering).
|
||||
@@ -119,6 +121,19 @@ func authPriority(auth *Auth) int {
|
||||
return parsed
|
||||
}
|
||||
|
||||
func canonicalModelKey(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
parsed := thinking.ParseSuffix(model)
|
||||
modelName := strings.TrimSpace(parsed.ModelName)
|
||||
if modelName == "" {
|
||||
return model
|
||||
}
|
||||
return modelName
|
||||
}
|
||||
|
||||
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||
available = make(map[int][]*Auth)
|
||||
for i := 0; i < len(auths); i++ {
|
||||
@@ -185,11 +200,18 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := provider + ":" + model
|
||||
key := provider + ":" + canonicalModelKey(model)
|
||||
s.mu.Lock()
|
||||
if s.cursors == nil {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
limit := s.maxKeys
|
||||
if limit <= 0 {
|
||||
limit = 4096
|
||||
}
|
||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
index := s.cursors[key]
|
||||
|
||||
if index >= 2_147_483_640 {
|
||||
@@ -223,7 +245,14 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block
|
||||
}
|
||||
if model != "" {
|
||||
if len(auth.ModelStates) > 0 {
|
||||
if state, ok := auth.ModelStates[model]; ok && state != nil {
|
||||
state, ok := auth.ModelStates[model]
|
||||
if (!ok || state == nil) && model != "" {
|
||||
baseModel := canonicalModelKey(model)
|
||||
if baseModel != "" && baseModel != model {
|
||||
state, ok = auth.ModelStates[baseModel]
|
||||
}
|
||||
}
|
||||
if ok && state != nil {
|
||||
if state.Status == StatusDisabled {
|
||||
return true, blockReasonDisabled, time.Time{}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -175,3 +177,228 @@ func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectorPick_AllCooldownReturnsModelCooldownError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := "test-model"
|
||||
now := time.Now()
|
||||
next := now.Add(60 * time.Second)
|
||||
auths := []*Auth{
|
||||
{
|
||||
ID: "a",
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Status: StatusActive,
|
||||
Unavailable: true,
|
||||
NextRetryAfter: next,
|
||||
Quota: QuotaState{
|
||||
Exceeded: true,
|
||||
NextRecoverAt: next,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Status: StatusActive,
|
||||
Unavailable: true,
|
||||
NextRetryAfter: next,
|
||||
Quota: QuotaState{
|
||||
Exceeded: true,
|
||||
NextRecoverAt: next,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("mixed provider redacts provider field", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &FillFirstSelector{}
|
||||
_, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, auths)
|
||||
if err == nil {
|
||||
t.Fatalf("Pick() error = nil")
|
||||
}
|
||||
|
||||
var mce *modelCooldownError
|
||||
if !errors.As(err, &mce) {
|
||||
t.Fatalf("Pick() error = %T, want *modelCooldownError", err)
|
||||
}
|
||||
if mce.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("StatusCode() = %d, want %d", mce.StatusCode(), http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
headers := mce.Headers()
|
||||
if got := headers.Get("Retry-After"); got == "" {
|
||||
t.Fatalf("Headers().Get(Retry-After) = empty")
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(mce.Error()), &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal(Error()) error = %v", err)
|
||||
}
|
||||
rawErr, ok := payload["error"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Error() payload missing error object: %v", payload)
|
||||
}
|
||||
if got, _ := rawErr["code"].(string); got != "model_cooldown" {
|
||||
t.Fatalf("Error().error.code = %q, want %q", got, "model_cooldown")
|
||||
}
|
||||
if _, ok := rawErr["provider"]; ok {
|
||||
t.Fatalf("Error().error.provider exists for mixed provider: %v", rawErr["provider"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-mixed provider includes provider field", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &FillFirstSelector{}
|
||||
_, err := selector.Pick(context.Background(), "gemini", model, cliproxyexecutor.Options{}, auths)
|
||||
if err == nil {
|
||||
t.Fatalf("Pick() error = nil")
|
||||
}
|
||||
|
||||
var mce *modelCooldownError
|
||||
if !errors.As(err, &mce) {
|
||||
t.Fatalf("Pick() error = %T, want *modelCooldownError", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(mce.Error()), &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal(Error()) error = %v", err)
|
||||
}
|
||||
rawErr, ok := payload["error"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Error() payload missing error object: %v", payload)
|
||||
}
|
||||
if got, _ := rawErr["provider"].(string); got != "gemini" {
|
||||
t.Fatalf("Error().error.provider = %q, want %q", got, "gemini")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsAuthBlockedForModel_UnavailableWithoutNextRetryIsNotBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
model := "test-model"
|
||||
auth := &Auth{
|
||||
ID: "a",
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Status: StatusActive,
|
||||
Unavailable: true,
|
||||
Quota: QuotaState{
|
||||
Exceeded: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
||||
if blocked {
|
||||
t.Fatalf("blocked = true, want false")
|
||||
}
|
||||
if reason != blockReasonNone {
|
||||
t.Fatalf("reason = %v, want %v", reason, blockReasonNone)
|
||||
}
|
||||
if !next.IsZero() {
|
||||
t.Fatalf("next = %v, want zero", next)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillFirstSelectorPick_ThinkingSuffixFallsBackToBaseModelState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &FillFirstSelector{}
|
||||
now := time.Now()
|
||||
|
||||
baseModel := "test-model"
|
||||
requestedModel := "test-model(high)"
|
||||
|
||||
high := &Auth{
|
||||
ID: "high",
|
||||
Attributes: map[string]string{"priority": "10"},
|
||||
ModelStates: map[string]*ModelState{
|
||||
baseModel: {
|
||||
Status: StatusActive,
|
||||
Unavailable: true,
|
||||
NextRetryAfter: now.Add(30 * time.Minute),
|
||||
Quota: QuotaState{
|
||||
Exceeded: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
low := &Auth{
|
||||
ID: "low",
|
||||
Attributes: map[string]string{"priority": "0"},
|
||||
}
|
||||
|
||||
got, err := selector.Pick(context.Background(), "mixed", requestedModel, cliproxyexecutor.Options{}, []*Auth{high, low})
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() auth = nil")
|
||||
}
|
||||
if got.ID != "low" {
|
||||
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_ThinkingSuffixSharesCursor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
auths := []*Auth{
|
||||
{ID: "b"},
|
||||
{ID: "a"},
|
||||
}
|
||||
|
||||
first, err := selector.Pick(context.Background(), "gemini", "test-model(high)", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() first error = %v", err)
|
||||
}
|
||||
second, err := selector.Pick(context.Background(), "gemini", "test-model(low)", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() second error = %v", err)
|
||||
}
|
||||
if first == nil || second == nil {
|
||||
t.Fatalf("Pick() returned nil auth")
|
||||
}
|
||||
if first.ID != "a" {
|
||||
t.Fatalf("Pick() first auth.ID = %q, want %q", first.ID, "a")
|
||||
}
|
||||
if second.ID != "b" {
|
||||
t.Fatalf("Pick() second auth.ID = %q, want %q", second.ID, "b")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{maxKeys: 2}
|
||||
auths := []*Auth{{ID: "a"}}
|
||||
|
||||
_, _ = selector.Pick(context.Background(), "gemini", "m1", cliproxyexecutor.Options{}, auths)
|
||||
_, _ = selector.Pick(context.Background(), "gemini", "m2", cliproxyexecutor.Options{}, auths)
|
||||
_, _ = selector.Pick(context.Background(), "gemini", "m3", cliproxyexecutor.Options{}, auths)
|
||||
|
||||
selector.mu.Lock()
|
||||
defer selector.mu.Unlock()
|
||||
|
||||
if selector.cursors == nil {
|
||||
t.Fatalf("selector.cursors = nil")
|
||||
}
|
||||
if len(selector.cursors) != 1 {
|
||||
t.Fatalf("len(selector.cursors) = %d, want %d", len(selector.cursors), 1)
|
||||
}
|
||||
if _, ok := selector.cursors["gemini:m3"]; !ok {
|
||||
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -186,11 +187,8 @@ func (b *Builder) Build() (*Service, error) {
|
||||
accessManager = sdkaccess.NewManager()
|
||||
}
|
||||
|
||||
providers, err := sdkaccess.BuildProviders(&b.cfg.SDKConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
configaccess.Register(&b.cfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
|
||||
coreManager := b.coreManager
|
||||
if coreManager == nil {
|
||||
|
||||
163
sdk/cliproxy/pprof_server.go
Normal file
163
sdk/cliproxy/pprof_server.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type pprofServer struct {
|
||||
mu sync.Mutex
|
||||
server *http.Server
|
||||
addr string
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func newPprofServer() *pprofServer {
|
||||
return &pprofServer{}
|
||||
}
|
||||
|
||||
func (s *Service) applyPprofConfig(cfg *config.Config) {
|
||||
if s == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
if s.pprofServer == nil {
|
||||
s.pprofServer = newPprofServer()
|
||||
}
|
||||
s.pprofServer.Apply(cfg)
|
||||
}
|
||||
|
||||
func (s *Service) shutdownPprof(ctx context.Context) error {
|
||||
if s == nil || s.pprofServer == nil {
|
||||
return nil
|
||||
}
|
||||
return s.pprofServer.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (p *pprofServer) Apply(cfg *config.Config) {
|
||||
if p == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
addr := strings.TrimSpace(cfg.Pprof.Addr)
|
||||
if addr == "" {
|
||||
addr = config.DefaultPprofAddr
|
||||
}
|
||||
enabled := cfg.Pprof.Enable
|
||||
|
||||
p.mu.Lock()
|
||||
currentServer := p.server
|
||||
currentAddr := p.addr
|
||||
p.addr = addr
|
||||
p.enabled = enabled
|
||||
if !enabled {
|
||||
p.server = nil
|
||||
p.mu.Unlock()
|
||||
if currentServer != nil {
|
||||
p.stopServer(currentServer, currentAddr, "disabled")
|
||||
}
|
||||
return
|
||||
}
|
||||
if currentServer != nil && currentAddr == addr {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.server = nil
|
||||
p.mu.Unlock()
|
||||
|
||||
if currentServer != nil {
|
||||
p.stopServer(currentServer, currentAddr, "restarted")
|
||||
}
|
||||
|
||||
p.startServer(addr)
|
||||
}
|
||||
|
||||
func (p *pprofServer) Shutdown(ctx context.Context) error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
currentServer := p.server
|
||||
currentAddr := p.addr
|
||||
p.server = nil
|
||||
p.enabled = false
|
||||
p.mu.Unlock()
|
||||
|
||||
if currentServer == nil {
|
||||
return nil
|
||||
}
|
||||
return p.stopServerWithContext(ctx, currentServer, currentAddr, "shutdown")
|
||||
}
|
||||
|
||||
func (p *pprofServer) startServer(addr string) {
|
||||
mux := newPprofMux()
|
||||
server := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if !p.enabled || p.addr != addr || p.server != nil {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.server = server
|
||||
p.mu.Unlock()
|
||||
|
||||
log.Infof("pprof server starting on %s", addr)
|
||||
go func() {
|
||||
if errServe := server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
||||
log.Errorf("pprof server failed on %s: %v", addr, errServe)
|
||||
p.mu.Lock()
|
||||
if p.server == server {
|
||||
p.server = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *pprofServer) stopServer(server *http.Server, addr string, reason string) {
|
||||
_ = p.stopServerWithContext(context.Background(), server, addr, reason)
|
||||
}
|
||||
|
||||
func (p *pprofServer) stopServerWithContext(ctx context.Context, server *http.Server, addr string, reason string) error {
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
stopCtx := ctx
|
||||
if stopCtx == nil {
|
||||
stopCtx = context.Background()
|
||||
}
|
||||
stopCtx, cancel := context.WithTimeout(stopCtx, 5*time.Second)
|
||||
defer cancel()
|
||||
if errStop := server.Shutdown(stopCtx); errStop != nil {
|
||||
log.Errorf("pprof server stop failed on %s: %v", addr, errStop)
|
||||
return errStop
|
||||
}
|
||||
log.Infof("pprof server stopped on %s (%s)", addr, reason)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newPprofMux() *http.ServeMux {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
|
||||
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
|
||||
mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
|
||||
mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
|
||||
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
|
||||
mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
|
||||
return mux
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user