Compare commits

...

90 Commits

Author SHA1 Message Date
Luis Pater
d24ea4ce2a Merge pull request #1664 from ciberponk/pr/responses-compaction-compat
feat: add codex responses compatibility for compaction payloads
2026-02-25 01:21:59 +08:00
Luis Pater
2c30c981ae Merge pull request #1687 from lyd123qw2008/fix/codex-refresh-token-reused-no-retry
fix(codex): stop retrying refresh_token_reused errors
2026-02-25 01:19:30 +08:00
Luis Pater
aa1da8a858 Merge pull request #1685 from lyd123qw2008/fix/auth-auto-refresh-interval
fix(auth): respect configured auto-refresh interval
2026-02-25 01:13:47 +08:00
Luis Pater
f1e9a787d7 Merge pull request #1676 from piexian/feat/qwen-quota-handling-clean
feat(qwen): add rate limiting and quota error handling
2026-02-25 01:07:55 +08:00
Luis Pater
c66cb0afd2 Merge pull request #1683 from dusty-du/codex/device-login-flow
Add additive Codex device-code login flow
2026-02-25 00:50:48 +08:00
Luis Pater
fb48eee973 Merge pull request #1680 from canxin121/fix/responses-stream-error-chunks
fix(responses): emit schema-valid SSE chunks
2026-02-25 00:49:06 +08:00
Luis Pater
bb44e5ec44 Merge pull request #1701 from router-for-me/openai
Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping"
2026-02-25 00:46:13 +08:00
hkfires
0659ffab75 Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping" 2026-02-24 19:47:53 +08:00
Luis Pater
7cb398d167 Merge pull request #1663 from rensumo/main
feat: implement credential-based round-robin for gemini-cli
2026-02-24 06:02:50 +08:00
Luis Pater
c3e12c5e58 Merge pull request #1654 from alexey-yanchenko/feature/pass-file-inputs
Pass file input from /chat/completions and /responses to codex and claude
2026-02-24 05:53:11 +08:00
Luis Pater
1825fc7503 Merge pull request #1643 from alexey-yanchenko/fix/gemini-prompt-tokens
Fix usage convertation from gemini response to openai format
2026-02-24 05:46:13 +08:00
Luis Pater
48732ba05e Merge pull request #1527 from HEUDavid/feat/auth-hook
feat(auth): add post-auth hook mechanism
2026-02-24 05:33:13 +08:00
canxin121
acf483c9e6 fix(responses): reject invalid SSE data JSON
Guard the openai-response streaming path against truncated/invalid SSE data payloads by validating data: JSON before forwarding; surface a 502 terminal error instead of letting clients crash with JSON parse errors.
2026-02-24 01:42:54 +08:00
lyd123qw2008
3b3e0d1141 test(codex): log non-retryable refresh error and cover single-attempt behavior 2026-02-23 22:41:33 +08:00
lyd123qw2008
7acd428507 fix(codex): stop retrying refresh_token_reused errors 2026-02-23 22:31:30 +08:00
lyd123qw2008
450d1227bd fix(auth): respect configured auto-refresh interval 2026-02-23 22:07:50 +08:00
test
492b9c46f0 Add additive Codex device-code login flow 2026-02-23 06:30:04 -05:00
canxin121
eb7571936c revert: translator changes (path guard)
CI blocks PRs that modify internal/translator. Revert translator edits and keep only the /v1/responses streaming error-chunk fix; file an issue for translator conformance work.
2026-02-23 13:30:43 +08:00
canxin121
5382764d8a fix(responses): include model and usage in translated streams
Ensure response.created and response.completed chunks produced by the OpenAI/Gemini/Claude translators always include required fields (response.model and response.usage) so clients validating Responses SSE do not fail schema validation.
2026-02-23 13:22:06 +08:00
canxin121
49c8ec69d0 fix(openai): emit valid responses stream error chunks
When /v1/responses streaming fails after headers are sent, we now emit a type=error chunk instead of an HTTP-style {error:{...}} payload, preventing AI SDK chunk validation errors.
2026-02-23 12:59:50 +08:00
piexian
3b421c8181 feat(qwen): add rate limiting and quota error handling
- Add 60 requests/minute rate limiting per credential using sliding window
- Detect insufficient_quota errors and set cooldown until next day (Beijing time)
- Map quota errors (HTTP 403/429) to 429 with retryAfter for conductor integration
- Cache Beijing timezone at package level to avoid repeated syscalls
- Add redactAuthID function to protect credentials in logs
- Extract wrapQwenError helper to consolidate error handling
2026-02-23 00:38:46 +08:00
Luis Pater
713388dd7b Fixed: #1675
fix(gemini): add model definitions for Gemini 3.1 Pro High and Image
2026-02-23 00:12:57 +08:00
Luis Pater
e6c7af0fa9 Merge pull request #1522 from soilSpoon/feature/canceled
feature(proxy): Adds special handling for client cancellations in proxy error handler
2026-02-22 22:02:59 +08:00
Luis Pater
d210be06c2 fix(gemini): update min Thinking value and add Gemini 3.1 Pro Preview model definition 2026-02-22 21:51:32 +08:00
fan
afc8a0f9be refactor: simplify context_management compatibility handling 2026-02-21 22:20:48 +08:00
Luis Pater
d6ec33e8e1 Merge pull request #1662 from matchch/contribute/cache-user-id
feat: add cache-user-id toggle for Claude cloaking
2026-02-21 20:51:30 +08:00
Luis Pater
081cfe806e fix(gemini): correct Created timestamps for Gemini 3.1 Pro Preview model definitions 2026-02-21 20:47:47 +08:00
hkfires
c1c62a6c04 feat(gemini): add Gemini 3.1 Pro Preview model definitions 2026-02-21 20:42:29 +08:00
ciberponk
d693d7993b feat: support responses compaction payload compatibility for codex translator 2026-02-21 12:56:10 +08:00
rensumo
5936f9895c feat: implement credential-based round-robin for gemini-cli virtual auths
Changes the RoundRobinSelector to use two-level round-robin when
gemini-cli virtual auths are detected (via gemini_virtual_parent attr):
- Level 1: cycle across credential groups (parent accounts)
- Level 2: cycle within each group's project auths

Credentials start from a random offset (rand.IntN) for fair distribution.
Non-virtual auths and single-credential scenarios fall back to flat RR.

Adds 3 test cases covering multi-credential grouping, single-parent
fallback, and mixed virtual/non-virtual fallback.
2026-02-21 12:49:48 +08:00
matchch
2fdf5d2793 feat: add cache-user-id toggle for Claude cloaking
Default to generating a fresh random user_id per request instead of
reusing cached IDs. Add cache-user-id config option to opt in to the
previous caching behavior.

- Add CacheUserID field to CloakConfig
- Extract user_id cache logic to dedicated file
- Generate fresh user_id by default, cache only when enabled
- Add tests for both paths
2026-02-21 12:31:20 +08:00
Luis Pater
7b0eb41ebc Merge pull request #1660 from Grivn/fix/claude-token-url
fix(claude): use api.anthropic.com for OAuth token exchange
2026-02-20 21:52:08 +08:00
Grivn
ef5901c81b fix(claude): use api.anthropic.com for OAuth token exchange
console.anthropic.com is now protected by a Cloudflare managed challenge
that blocks all non-browser POST requests to /v1/oauth/token, causing
`-claude-login` to fail with a 403 error.

Switch to api.anthropic.com which hosts the same OAuth token endpoint
without the Cloudflare managed challenge.

Fixes #1659

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 20:11:27 +08:00
Luis Pater
d4829c82f7 Merge pull request #1652 from thebtf/fix/claude-translator-arguments
fix(translator): handle tool call arguments in codex→claude streaming translator
2026-02-20 19:50:20 +08:00
Luis Pater
a5f4166a9b Merge pull request #1644 from possible055/main
feat: add Gemini 3.1 Pro Preview model definition
2026-02-20 19:44:59 +08:00
Alexey Yanchenko
0cbfe7f457 Pass file input from /chat/completions and /responses to codex and claude 2026-02-20 10:25:44 +07:00
Kirill Turanskiy
1cc21cc45b fix: prevent duplicate function call arguments when delta events precede done
Non-spark codex models (gpt-5.3-codex, gpt-5.2-codex) stream function call
arguments via multiple delta events followed by a done event. The done handler
unconditionally emitted the full arguments, duplicating what deltas already
streamed. This produced invalid double JSON that Claude Code couldn't parse,
causing tool calls to fail with missing parameters and infinite retry loops.

Add HasReceivedArgumentsDelta flag to track whether delta events were received.
The done handler now only emits arguments when no deltas preceded it (spark
models), while delta-based streaming continues to work for non-spark models.
2026-02-19 23:18:14 +03:00
Kirill Turanskiy
07cf616e2b fix: handle response.function_call_arguments.done in codex→claude streaming translator
Some Codex models (e.g. gpt-5.3-codex-spark) send function call arguments
in a single "done" event without preceding "delta" events. The streaming
translator only handled "delta" events, causing tool call arguments to be
lost — resulting in empty tool inputs and infinite retry loops in clients
like Claude Code.

Emit the full arguments from the "done" event as a single input_json_delta
so downstream clients receive the complete tool input.
2026-02-19 23:18:14 +03:00
Luis Pater
4445a165e9 test(handlers): add tests for passthrough headers behavior in WriteErrorResponse 2026-02-19 21:49:44 +08:00
Luis Pater
e92e2af71a Merge branch 'codex/pr-1626' into dev 2026-02-19 21:33:23 +08:00
Luis Pater
a6bdd9a652 feat: add passthrough headers configuration
- Introduced `passthrough-headers` option in configuration to control forwarding of upstream response headers.
- Updated handlers to respect the passthrough headers setting.
- Added tests to verify behavior when passthrough is enabled or disabled.
2026-02-19 21:31:29 +08:00
Luis Pater
349a6349b3 Merge pull request #1645 from tinyc0der/fix/antigravity-tool-result-json
fix(antigravity): prevent invalid JSON when tool_result has no content
2026-02-19 21:01:25 +08:00
TinyCoder
00822770ec fix(antigravity): prevent invalid JSON when tool_result has no content
sjson.SetRaw with an empty string produces malformed JSON (e.g. "result":}).
This happens when a Claude tool_result block has no content field, causing
functionResponseResult.Raw to be "". Guard against this by falling back to
sjson.Set with an empty string only when .Raw is empty.
2026-02-19 17:08:39 +07:00
apparition
1a0ceda0fc feat: add Gemini 3.1 Pro Preview model definition 2026-02-19 17:43:08 +08:00
Alexey Yanchenko
b9ae4ab803 Fix usage convertation from gemini response to openai format 2026-02-19 15:34:59 +07:00
Luis Pater
72add453d2 docs: add OmniRoute to README 2026-02-19 13:23:25 +08:00
Luis Pater
2789396435 fix: ensure connection-scoped headers are filtered in upstream requests
- Added `connectionScopedHeaders` utility to respect "Connection" header directives.
- Updated `FilterUpstreamHeaders` to remove connection-scoped headers dynamically.
- Refactored and tested upstream header filtering with additional validations.
- Adjusted upstream header handling during retries to replace headers safely.
2026-02-19 13:19:10 +08:00
Luis Pater
61da7bd981 Merge PR #1626 into codex/pr-1626 2026-02-19 04:49:14 +08:00
Luis Pater
9c040445af Merge pull request #1635 from thebtf/fix/openai-translator-tool-streaming
fix: handle tool call argument streaming in Codex→OpenAI translator
2026-02-19 04:22:12 +08:00
Luis Pater
fff866424e Merge pull request #1628 from thebtf/fix/masquerading-headers
fix: update Claude masquerading headers and configurable defaults
2026-02-19 04:19:59 +08:00
Luis Pater
2d12becfd6 Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping
fix: clamp reasoning_effort to valid OpenAI-format values
2026-02-19 04:15:19 +08:00
Luis Pater
252f7e0751 Merge pull request #1625 from thebtf/feat/tool-prefix-config
feat: add per-auth tool_prefix_disabled option
2026-02-19 04:07:22 +08:00
Luis Pater
b2b17528cb Merge branch 'pr-1624' into dev
# Conflicts:
#	internal/runtime/executor/claude_executor.go
#	internal/runtime/executor/claude_executor_test.go
2026-02-19 04:05:04 +08:00
Luis Pater
55f938164b Merge pull request #1618 from alexey-yanchenko/fix/completions-usage
Fix empty usage in /v1/completions
2026-02-19 03:57:11 +08:00
Luis Pater
76294f0c59 Merge pull request #1608 from thebtf/fix/tool-reference-proxy-prefix-mainline
fix: add proxy_ prefix handling for tool_reference content blocks
2026-02-19 03:50:34 +08:00
Luis Pater
2bcee78c6e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:19:18 +08:00
Luis Pater
7fe8246a9f Merge branch 'tui' into dev 2026-02-19 03:18:24 +08:00
Luis Pater
93fe58e31e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:18:08 +08:00
Luis Pater
e5b5dc870f chore(executor): remove unused Openai-Beta header from Codex executor 2026-02-19 02:19:48 +08:00
Kirill Turanskiy
5fa23c7f41 fix: handle tool call argument streaming in Codex→OpenAI translator
The OpenAI Chat Completions translator was silently dropping
response.function_call_arguments.delta and
response.function_call_arguments.done Codex SSE events, meaning
tool call arguments were never streamed incrementally to clients.

Add proper handling mirroring the proven Claude translator pattern:

- response.output_item.added: announce tool call (id, name, empty args)
- response.function_call_arguments.delta: stream argument chunks
- response.function_call_arguments.done: emit full args if no deltas
- response.output_item.done: defensive fallback for backward compat

State tracking via HasReceivedArgumentsDelta and HasToolCallAnnounced
ensures no duplicate argument emission and correct behavior for models
like codex-spark that skip delta events entirely.
2026-02-18 19:09:05 +03:00
Kirill Turanskiy
73dc0b10b8 fix: update Claude masquerading headers and make them configurable
Update hardcoded X-Stainless-* and User-Agent defaults to match
Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (verified via
diagnostic proxy capture 2026-02-17).

Changes:
- X-Stainless-Os/Arch: dynamic via runtime.GOOS/GOARCH
- X-Stainless-Package-Version: 0.55.1 → 0.74.0
- X-Stainless-Timeout: 60 → 600
- User-Agent: claude-cli/1.0.83 (external, cli) → claude-cli/2.1.44 (external, sdk-cli)

Add claude-header-defaults config section so values can be updated
without recompilation when Claude Code releases new versions.
2026-02-18 03:38:51 +03:00
Kirill Turanskiy
2ea95266e3 fix: clamp reasoning_effort to valid OpenAI-format values
CPA-internal thinking levels like 'xhigh' and 'minimal' are not accepted
by OpenAI-format providers (MiniMax, etc.). The OpenAI applier now maps
non-standard levels to the nearest valid reasoning_effort value before
writing to the request body:

  xhigh   → high
  minimal → low
  auto    → medium
2026-02-18 03:36:42 +03:00
Kirill Turanskiy
1f8f198c45 feat: passthrough upstream response headers to clients
CPA previously stripped ALL response headers from upstream AI provider
APIs, preventing clients from seeing rate-limit info, request IDs,
server-timing and other useful headers.

Changes:
- Add Headers field to Response and StreamResult structs
- Add FilterUpstreamHeaders helper (hop-by-hop + security denylist)
- Add WriteUpstreamHeaders helper (respects CPA-set headers)
- ExecuteWithAuthManager/ExecuteCountWithAuthManager now return headers
- ExecuteStreamWithAuthManager returns headers from initial connection
- All 11 provider executors populate Response.Headers
- All handler call sites write filtered upstream headers before response

Filtered headers (not forwarded):
- RFC 7230 hop-by-hop: Connection, Transfer-Encoding, Keep-Alive, etc.
- Security: Set-Cookie
- CPA-managed: Content-Length, Content-Encoding
2026-02-18 00:16:22 +03:00
Kirill Turanskiy
9261b0c20b feat: add per-auth tool_prefix_disabled option
Allow disabling the proxy_ tool name prefix on a per-account basis.
Users who route their own Anthropic account through CPA can set
"tool_prefix_disabled": true in their OAuth auth JSON to send tool
names unchanged to Anthropic.

Default behavior is fully preserved — prefix is applied unless
explicitly disabled.

Changes:
- Add ToolPrefixDisabled() accessor to Auth (reads metadata key
  "tool_prefix_disabled" or "tool-prefix-disabled")
- Gate all 6 prefix apply/strip points with the new flag
- Add unit tests for the accessor
2026-02-17 21:48:19 +03:00
Kirill Turanskiy
7cc725496e fix: skip proxy_ prefix for built-in tools in message history
The proxy_ prefix logic correctly skips built-in tools (those with a
non-empty "type" field) in tools[] definitions but does not skip them
in messages[].content[] tool_use blocks or tool_choice. This causes
web_search in conversation history to become proxy_web_search, which
Anthropic does not recognize.

Fix: collect built-in tool names from tools[] into a set and also
maintain a hardcoded fallback set (web_search, code_execution,
text_editor, computer) for cases where the built-in tool appears in
history but not in the current request's tools[] array. Skip prefixing
in messages and tool_choice when name matches a built-in.
2026-02-17 21:42:32 +03:00
Alexey Yanchenko
709d999f9f Add usage to /v1/completions 2026-02-17 17:21:03 +07:00
Kirill Turanskiy
24c18614f0 fix: skip built-in tools in tool_reference prefix + refactor to switch
- Collect built-in tool names (those with a "type" field like
  web_search, code_execution) and skip prefixing tool_reference
  blocks that reference them, preventing name mismatch.
- Refactor if-else if chains to switch statements in all three
  prefix functions for idiomatic Go style.
2026-02-16 19:37:11 +03:00
Kirill Turanskiy
603f06a762 fix: handle tool_reference nested inside tool_result.content[]
tool_reference blocks can appear nested inside tool_result.content[]
arrays, not just at the top level of messages[].content[]. The prefix
logic now iterates into tool_result blocks with array content to find
and prefix/strip nested tool_reference.tool_name fields.
2026-02-16 19:06:24 +03:00
Kirill Turanskiy
98f0a3e3bd fix: add proxy_ prefix handling for tool_reference content blocks (#1)
applyClaudeToolPrefix, stripClaudeToolPrefixFromResponse, and
stripClaudeToolPrefixFromStreamLine now handle "tool_reference" blocks
(field "tool_name") in addition to "tool_use" blocks (field "name").

Without this fix, tool_reference blocks in conversation history retain
their original unprefixed names while tool definitions carry the proxy_
prefix, causing Anthropic API 400 errors: "Tool reference 'X' not found
in available tools."

Co-authored-by: Kirill Turanskiy <kt@novamedia.ru>
2026-02-16 19:06:24 +03:00
lhpqaq
2c8821891c fix(tui): update with review 2026-02-16 00:24:25 +08:00
haopeng
0a2555b0f3 Update internal/tui/auth_tab.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-16 00:11:31 +08:00
lhpqaq
020df41efe chore(tui): update readme, fix usage 2026-02-16 00:04:04 +08:00
lhpqaq
f31f7f701a feat(tui): add i18n 2026-02-15 15:42:59 +08:00
lhpqaq
54ad7c1b6b feat(tui): add manager tui 2026-02-15 14:52:40 +08:00
이대희
a45c6defa7 Merge remote-tracking branch 'upstream/main' into feature/canceled 2026-02-13 15:07:32 +09:00
이대희
40bee3e8d9 Merge branch 'main' into feature/canceled 2026-02-13 13:37:55 +09:00
이대희
93147dddeb Improves error handling for canceled requests
Adds explicit handling for context.Canceled errors in the reverse proxy error handler to return 499 status code without logging, which is more appropriate for client-side cancellations during polling.

Also adds a test case to verify this behavior.
2026-02-12 10:39:45 +09:00
이대희
c0f9b15a58 Merge remote-tracking branch 'upstream/main' into feature/canceled 2026-02-12 10:33:49 +09:00
이대희
6f2fbdcbae Update internal/api/modules/amp/proxy.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-12 10:30:05 +09:00
HEUDavid
65debb874f feat/auth-hook: refactor RequstInfo to preserve original HTTP semantics 2026-02-12 07:11:17 +08:00
HEUDavid
3caadac003 feat/auth-hook: add post auth hook [CR] 2026-02-12 07:11:17 +08:00
HEUDavid
6a9e3a6b84 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
269972440a feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
cce13e6ad2 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
8a565dcad8 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
d536110404 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
48e957ddff feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
94563d622c feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
이대희
ce0c6aa82b Update internal/api/modules/amp/proxy.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-10 19:07:49 +09:00
이대희
3c85d2a4d7 feature(proxy): Adds special handling for client cancellations in proxy error handler
Silences logging for client cancellations during polling to reduce noise in logs.
Client-side cancellations are common during long-running operations and should not be treated as errors.
2026-02-10 18:02:08 +09:00
95 changed files with 7528 additions and 295 deletions

View File

@@ -161,6 +161,12 @@ Those projects are ports of CLIProxyAPI or inspired by it:
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
Never stop coding. Smart routing to FREE & low-cost AI models with automatic fallback.
OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoint with smart routing, load balancing, retries, and fallbacks. Add policies, rate limits, caching, and observability for reliable, cost-aware inference.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.

View File

@@ -160,6 +160,12 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
基于 Next.js 的实现,灵感来自 CLIProxyAPI易于安装使用自研格式转换OpenAI/Claude/Gemini/Ollama、组合系统与自动回退、多账户管理指数退避、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
代码不止,创新不停。智能路由至免费及低成本 AI 模型,并支持自动故障转移。
OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼容 OpenAI 的端点,具备智能路由、负载均衡、重试及回退机制。通过添加策略、速率限制、缓存和可观测性,确保推理过程既可靠又具备成本意识。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。

View File

@@ -8,6 +8,7 @@ import (
"errors"
"flag"
"fmt"
"io"
"io/fs"
"net/url"
"os"
@@ -25,6 +26,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -56,6 +58,7 @@ func main() {
// Command-line flags to control the application's behavior.
var login bool
var codexLogin bool
var codexDeviceLogin bool
var claudeLogin bool
var qwenLogin bool
var iflowLogin bool
@@ -68,10 +71,13 @@ func main() {
var vertexImport string
var configPath string
var password string
var tuiMode bool
var standalone bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
@@ -84,6 +90,8 @@ func main() {
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output()
@@ -461,6 +469,9 @@ func main() {
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
} else if codexDeviceLogin {
// Handle Codex device-code login
cmd.DoCodexDeviceLogin(cfg, options)
} else if claudeLogin {
// Handle Claude login
cmd.DoClaudeLogin(cfg, options)
@@ -479,8 +490,83 @@ func main() {
cmd.WaitForCloudDeploy()
return
}
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
cmd.StartService(cfg, configFilePath, password)
if tuiMode {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
origStdout := os.Stdout
origStderr := os.Stderr
origLogOutput := log.StandardLogger().Out
log.SetOutput(io.Discard)
devNull, errOpenDevNull := os.Open(os.DevNull)
if errOpenDevNull == nil {
os.Stdout = devNull
os.Stderr = devNull
}
restoreIO := func() {
os.Stdout = origStdout
os.Stderr = origStderr
log.SetOutput(origLogOutput)
if devNull != nil {
_ = devNull.Close()
}
}
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
if password == "" {
password = localMgmtPassword
}
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
client := tui.NewClient(cfg.Port, password)
ready := false
backoff := 100 * time.Millisecond
for i := 0; i < 30; i++ {
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
ready = true
break
}
time.Sleep(backoff)
if backoff < time.Second {
backoff = time.Duration(float64(backoff) * 1.5)
}
}
if !ready {
restoreIO()
cancel()
<-done
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
return
}
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
restoreIO()
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
} else {
restoreIO()
}
cancel()
<-done
} else {
// Default TUI mode: pure management client.
// The proxy server must already be running.
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
cmd.StartService(cfg, configFilePath, password)
}
}
}

View File

@@ -68,6 +68,10 @@ proxy-url: ""
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
# When true, forward filtered upstream response headers to downstream clients.
# Default is false (disabled).
passthrough-headers: false
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
request-retry: 3
@@ -155,6 +159,15 @@ nonstream-keepalive-interval: 0
# sensitive-words: # optional: words to obfuscate with zero-width characters
# - "API"
# - "proxy"
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# These are used as fallbacks when the client does not send its own headers.
# claude-header-defaults:
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
# package-version: "0.74.0"
# runtime-version: "v24.3.0"
# timeout: "600"
# OpenAI compatibility providers
# openai-compatibility:

View File

@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
return clipexec.Response{}, errors.New("count tokens not implemented")
}
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
ch := make(chan clipexec.StreamChunk, 1)
go func() {
defer close(ch)
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
}()
return ch, nil
return &clipexec.StreamResult{Chunks: ch}, nil
}
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {

View File

@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
}
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
return nil, errors.New("echo executor: ExecuteStream not implemented")
}

21
go.mod
View File

@@ -4,6 +4,10 @@ go 1.26.0
require (
github.com/andybalholm/brotli v1.0.6
github.com/atotto/clipboard v0.1.4
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.1
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
@@ -31,8 +35,16 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
@@ -40,6 +52,7 @@ require (
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-git/gcfg/v2 v2.0.2 // indirect
@@ -56,19 +69,27 @@ require (
github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sergi/go-diff v1.4.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect

45
go.sum
View File

@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -99,8 +125,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
@@ -112,6 +144,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
@@ -120,6 +158,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
@@ -159,17 +199,22 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=

View File

@@ -808,6 +808,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
var req struct {
Name string `json:"name"`
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Priority *int `json:"priority"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
name := strings.TrimSpace(req.Name)
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
ctx := c.Request.Context()
// Find auth by name or ID
var targetAuth *coreauth.Auth
if auth, ok := h.authManager.GetByID(name); ok {
targetAuth = auth
} else {
auths := h.authManager.List()
for _, auth := range auths {
if auth.FileName == name {
targetAuth = auth
break
}
}
}
if targetAuth == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
return
}
changed := false
if req.Prefix != nil {
targetAuth.Prefix = *req.Prefix
changed = true
}
if req.ProxyURL != nil {
targetAuth.ProxyURL = *req.ProxyURL
changed = true
}
if req.Priority != nil {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
}
changed = true
}
if !changed {
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
return
}
targetAuth.UpdatedAt = time.Now()
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h == nil || h.authManager == nil {
return
@@ -864,11 +945,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
if store == nil {
return "", fmt.Errorf("token store unavailable")
}
if h.postAuthHook != nil {
if err := h.postAuthHook(ctx, record); err != nil {
return "", fmt.Errorf("post-auth hook failed: %w", err)
}
}
return store.Save(ctx, record)
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...")
@@ -1013,6 +1100,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
@@ -1271,6 +1359,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Codex authentication...")
@@ -1416,6 +1505,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Antigravity authentication...")
@@ -1580,6 +1670,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Qwen authentication...")
@@ -1635,6 +1726,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Kimi authentication...")
@@ -1711,6 +1803,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
func (h *Handler) RequestIFlowToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing iFlow authentication...")
@@ -2331,3 +2424,12 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
}
c.JSON(http.StatusOK, gin.H{"status": "wait"})
}
// PopulateAuthContext extracts request info and adds it to the context
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
info := &coreauth.RequestInfo{
Query: c.Request.URL.Query(),
Headers: c.Request.Header,
}
return coreauth.WithRequestInfo(ctx, info)
}

View File

@@ -47,6 +47,7 @@ type Handler struct {
allowRemoteOverride bool
envSecret string
logDir string
postAuthHook coreauth.PostAuthHook
}
// NewHandler creates a new management handler instance.
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
h.logDir = dir
}
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true.

View File

@@ -3,6 +3,8 @@ package amp
import (
"bytes"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"net/http"
@@ -188,6 +190,10 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// Error handler for proxy failures
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
// Client-side cancellations are common during polling; suppress logging in this case
if errors.Is(err, context.Canceled) {
return
}
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusBadGateway)

View File

@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
}
}
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
// Test that context.Canceled errors return 499 without generic error response
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
if err != nil {
t.Fatal(err)
}
// Create a canceled context to trigger the cancellation path
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
rr := httptest.NewRecorder()
// Directly invoke the ErrorHandler with context.Canceled
proxy.ErrorHandler(rr, req, context.Canceled)
// Body should be empty for canceled requests (no JSON error response)
body := rr.Body.Bytes()
if len(body) > 0 {
t.Fatalf("expected empty body for canceled context, got: %s", body)
}
}
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
// Upstream returns gzipped JSON without Content-Encoding header
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@@ -51,6 +51,7 @@ type serverOptionConfig struct {
keepAliveEnabled bool
keepAliveTimeout time.Duration
keepAliveOnTimeout func()
postAuthHook auth.PostAuthHook
}
// ServerOption customises HTTP server construction.
@@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
}
}
// WithPostAuthHook registers a hook to be called after auth record creation.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.postAuthHook = hook
}
}
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
@@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
}
logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir)
if optionState.postAuthHook != nil {
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
}
s.localPassword = optionState.localPassword
// Setup routes
@@ -284,8 +295,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
optionState.routerConfigurator(engine, s.handlers, cfg)
}
// Register management routes when configuration or environment secrets are available.
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
// Register management routes when configuration or environment secrets are available,
// or when a local management password is provided (e.g. TUI mode).
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
s.managementRoutesEnabled.Store(hasManagementSecret)
if hasManagementSecret {
s.registerManagementRoutes()
@@ -617,6 +629,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)

View File

@@ -20,7 +20,7 @@ import (
// OAuth configuration constants for Claude/Anthropic
const (
AuthURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
TokenURL = "https://api.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
)

View File

@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Claude token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
// Encode and write the token data as JSON
if err = json.NewEncoder(f).Encode(ts); err != nil {
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
// authorization code and PKCE verifier.
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
}
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
// a caller-provided redirect URI. This supports alternate auth flows such as device
// login while preserving the existing token parsing and storage behavior.
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("PKCE codes are required for token exchange")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("redirect URI is required for token exchange")
}
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"client_id": {ClientID},
"code": {code},
"redirect_uri": {RedirectURI},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {pkceCodes.CodeVerifier},
}
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
if err == nil {
return tokenData, nil
}
if isNonRetryableRefreshErr(err) {
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
return nil, err
}
lastErr = err
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
}
func isNonRetryableRefreshErr(err error) bool {
if err == nil {
return false
}
raw := strings.ToLower(err.Error())
return strings.Contains(raw, "refresh_token_reused")
}
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
// This is typically called after a successful token refresh to persist the new credentials.
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {

View File

@@ -0,0 +1,44 @@
package codex
import (
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
var calls int32
auth := &CodexAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected error for non-retryable refresh failure")
}
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt, got %d", got)
}
}

View File

@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Codex token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
// Type indicates the authentication provider type, always "gemini" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "gemini"
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
Scope string `json:"scope"`
Cookie string `json:"cookie"`
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serialises the token storage to disk.
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
}
defer func() { _ = f.Close() }()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("iflow token: encode token failed: %w", err)
}
return nil

View File

@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
Expired string `json:"expired,omitempty"`
// Type indicates the authentication provider type, always "kimi" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// KimiTokenData holds the raw OAuth token response from Kimi.
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(f)
encoder.SetIndent("", " ")
if err = encoder.Encode(ts); err != nil {
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -0,0 +1,60 @@
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
)
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
// existing codex-login OAuth callback flow intact.
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
codexLoginModeMetadataKey: codexLoginModeDevice,
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)
}
return
}
fmt.Printf("Codex device authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Codex device authentication successful!")
}

View File

@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
}
}
// StartServiceBackground starts the proxy service in a background goroutine
// and returns a cancel function for shutdown and a done channel.
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
builder := cliproxy.NewBuilder().
WithConfig(cfg).
WithConfigPath(configPath).
WithLocalManagementPassword(localPassword)
ctx, cancelFn := context.WithCancel(context.Background())
doneCh := make(chan struct{})
service, err := builder.Build()
if err != nil {
log.Errorf("failed to build proxy service: %v", err)
close(doneCh)
return cancelFn, doneCh
}
go func() {
defer close(doneCh)
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("proxy service exited with error: %v", err)
}
}()
return cancelFn, doneCh
}
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
// when no configuration file is available.
func WaitForCloudDeploy() {

View File

@@ -90,6 +90,10 @@ type Config struct {
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
// ClaudeHeaderDefaults configures default header values for Claude API requests.
// These are used as fallbacks when the client does not send its own headers.
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
@@ -117,6 +121,15 @@ type Config struct {
legacyMigrationPending bool `yaml:"-" json:"-"`
}
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
// when the client does not send them. Update these when Claude Code releases a new version.
type ClaudeHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
PackageVersion string `yaml:"package-version" json:"package-version"`
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
Timeout string `yaml:"timeout" json:"timeout"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
@@ -288,6 +301,10 @@ type CloakConfig struct {
// SensitiveWords is a list of words to obfuscate with zero-width characters.
// This can help bypass certain content filters.
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
// CacheUserID controls whether Claude user_id values are cached per API key.
// When false, a fresh random user_id is generated for every request.
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
}
// ClaudeKey represents the configuration for a Claude API key,

View File

@@ -20,6 +20,10 @@ type SDKConfig struct {
// APIKeys is a list of keys for authenticating clients to this proxy server.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
// Default is false (disabled).
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`

View File

@@ -1,6 +1,7 @@
package misc
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
func LogCredentialSeparator() {
log.Debug(credentialSeparator)
}
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
var data map[string]any
// Fast path: if source is already a map, just copy it to avoid mutation of original
if srcMap, ok := source.(map[string]any); ok {
data = make(map[string]any, len(srcMap)+len(metadata))
for k, v := range srcMap {
data[k] = v
}
} else {
// Slow path: marshal to JSON and back to map to respect JSON tags
temp, err := json.Marshal(source)
if err != nil {
return nil, fmt.Errorf("failed to marshal source: %w", err)
}
if err := json.Unmarshal(temp, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
}
// Merge extra metadata
if metadata != nil {
if data == nil {
data = make(map[string]any)
}
for k, v := range metadata {
data[k] = v
}
}
return data, nil
}

View File

@@ -184,6 +184,21 @@ func GetGeminiModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -294,6 +309,21 @@ func GetGeminiVertexModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
@@ -436,6 +466,21 @@ func GetGeminiCLIModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -517,6 +562,21 @@ func GetAIStudioModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -903,6 +963,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},

View File

@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the AI Studio API.
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(first wsrelay.StreamEvent) {
defer close(out)
var param any
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
}
}(firstEvent)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the AI Studio API.

View File

@@ -232,7 +232,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -436,7 +436,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
@@ -645,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
}
// ExecuteStream performs a streaming request to the Antigravity API.
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -775,7 +775,6 @@ attemptLoop:
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response) {
defer close(out)
defer func() {
@@ -820,7 +819,7 @@ attemptLoop:
reporter.ensurePublished(ctx)
}
}(httpResp)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
switch {
@@ -968,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
}
lastStatus = httpResp.StatusCode

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"runtime"
"strings"
"time"
@@ -116,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if err != nil {
return resp, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.publish(ctx, parseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
var param any
@@ -222,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
data,
&param,
)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -257,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if err != nil {
return nil, err
}
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -329,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -348,7 +348,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
// Forward the line as-is to preserve SSE format
@@ -375,7 +375,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
chunks := sdktranslator.TranslateStream(
@@ -398,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -423,7 +423,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -432,7 +432,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
if err != nil {
return cliproxyexecutor.Response{}, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -487,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
}
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
@@ -638,7 +638,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
return body, nil
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
switch runtime.GOOS {
case "darwin":
return "MacOS"
case "windows":
return "Windows"
case "linux":
return "Linux"
case "freebsd":
return "FreeBSD"
default:
return "Other::" + runtime.GOOS
}
}
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "arm64":
return "arm64"
case "386":
return "x86"
default:
return "other::" + runtime.GOARCH
}
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
hdrDefault := func(cfgVal, fallback string) string {
if cfgVal != "" {
return cfgVal
}
return fallback
}
var hd config.ClaudeHeaderDefaults
if cfg != nil {
hd = cfg.ClaudeHeaderDefaults
}
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
@@ -685,16 +727,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
r.Header.Set("Connection", "keep-alive")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
if stream {
@@ -702,6 +745,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
} else {
r.Header.Set("Accept", "application/json")
}
// Keep OS/Arch mapping dynamic (not configurable).
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
@@ -753,11 +798,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return body
}
// Collect built-in tool names (those with a non-empty "type" field) so we can
// skip them consistently in both tools and message history.
builtinTools := map[string]bool{}
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
builtinTools[name] = true
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
// Skip built-in tools (web_search, code_execution, etc.) which have
// a "type" field and require their name to remain unchanged.
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
if n := tool.Get("name").String(); n != "" {
builtinTools[n] = true
}
return true
}
name := tool.Get("name").String()
@@ -772,7 +827,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
name := gjson.GetBytes(body, "tool_choice.name").String()
if name != "" && !strings.HasPrefix(name, prefix) {
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
}
}
@@ -784,15 +839,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return true
}
content.ForEach(func(contentIndex, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
case "tool_reference":
toolName := part.Get("tool_name").String()
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+toolName)
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
}
}
return true
})
}
}
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
return true
@@ -811,15 +889,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
return body
}
content.ForEach(func(index, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
case "tool_reference":
toolName := part.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return true
}
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if strings.HasPrefix(nestedToolName, prefix) {
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
}
}
return true
})
}
}
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
return true
})
return body
@@ -834,15 +935,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
return line
}
contentBlock := gjson.GetBytes(payload, "content_block")
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
if !contentBlock.Exists() {
return line
}
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
blockType := contentBlock.Get("type").String()
var updated []byte
var err error
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
return line
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
if err != nil {
return line
}
default:
return line
}
@@ -862,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string {
}
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
// Returns (cloakMode, strictMode, sensitiveWords).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
if auth == nil || auth.Attributes == nil {
return "auto", false, nil
return "auto", false, nil, false
}
cloakMode := auth.Attributes["cloak_mode"]
@@ -883,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
}
}
return cloakMode, strictMode, sensitiveWords
cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true")
return cloakMode, strictMode, sensitiveWords, cacheUserID
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
@@ -916,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c
}
// injectFakeUserID generates and injects a fake user ID into the request metadata.
func injectFakeUserID(payload []byte) []byte {
// When useCache is false, a new user ID is generated for every call.
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
generateID := func() string {
if useCache {
return cachedUserID(apiKey)
}
return generateFakeUserID()
}
metadata := gjson.GetBytes(payload, "metadata")
if !metadata.Exists() {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
return payload
}
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
if existingUserID == "" || !isValidUserID(existingUserID) {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
}
return payload
}
@@ -962,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// applyCloaking applies cloaking transformations to the payload based on config and client.
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
clientUserAgent := getClientUserAgent(ctx)
// Get cloak config from ClaudeKey configuration
@@ -972,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
var cloakMode string
var strictMode bool
var sensitiveWords []string
var cacheUserID bool
if cloakCfg != nil {
cloakMode = cloakCfg.Mode
strictMode = cloakCfg.StrictMode
sensitiveWords = cloakCfg.SensitiveWords
if cloakCfg.CacheUserID != nil {
cacheUserID = *cloakCfg.CacheUserID
}
}
// Fallback to auth attributes if no config found
if cloakMode == "" {
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
cloakMode = attrMode
if !strictMode {
strictMode = attrStrict
@@ -989,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
if len(sensitiveWords) == 0 {
sensitiveWords = attrWords
}
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
cacheUserID = attrCache
}
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
_, _, _, attrCache := getCloakConfigFromAuth(auth)
cacheUserID = attrCache
}
// Determine if cloaking should be applied
@@ -1002,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
}
// Inject fake user ID
payload = injectFakeUserID(payload)
payload = injectFakeUserID(payload, apiKey, cacheUserID)
// Apply sensitive word obfuscation
if len(sensitiveWords) > 0 {

View File

@@ -2,9 +2,18 @@ package executor
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
func TestApplyClaudeToolPrefix(t *testing.T) {
@@ -25,6 +34,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
}
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
@@ -37,6 +58,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
}
}
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
body := []byte(`{
"tools": [
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
}
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
body := []byte(`{
"tools": [{"name": "Read"}, {"name": "Write"}],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
}
}
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"name": "Read"}
],
"tool_choice": {"type": "tool", "name": "web_search"}
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -49,6 +161,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
}
}
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
}
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
}
}
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
@@ -61,3 +185,166 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
}
}
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
payload := bytes.TrimSpace(out)
if bytes.HasPrefix(payload, []byte("data:")) {
payload = bytes.TrimSpace(payload[len("data:"):])
}
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
}
}
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "proxy_mcp__nia__manage_resource" {
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
}
}
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
resetUserIDCache()
var userIDs []string
var requestModels []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
userID := gjson.GetBytes(body, "metadata.user_id").String()
model := gjson.GetBytes(body, "model").String()
userIDs = append(userIDs, userID)
requestModels = append(requestModels, model)
t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String())
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)
cacheEnabled := true
executor := NewClaudeExecutor(&config.Config{
ClaudeKey: []config.ClaudeKey{
{
APIKey: "key-123",
BaseURL: server.URL,
Cloak: &config.CloakConfig{
CacheUserID: &cacheEnabled,
},
},
},
})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"}
for _, model := range models {
t.Logf("Sending request for model: %s", model)
modelPayload, _ := sjson.SetBytes(payload, "model", model)
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: model,
Payload: modelPayload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
}); err != nil {
t.Fatalf("Execute(%s) error: %v", model, err)
}
}
if len(userIDs) != 2 {
t.Fatalf("expected 2 requests, got %d", len(userIDs))
}
if userIDs[0] == "" || userIDs[1] == "" {
t.Fatal("expected user_id to be populated")
}
t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0])
t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1])
if userIDs[0] != userIDs[1] {
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
}
if !isValidUserID(userIDs[0]) {
t.Fatalf("user_id %q is not valid", userIDs[0])
}
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
}
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
resetUserIDCache()
var userIDs []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String())
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
for i := 0; i < 2; i++ {
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
}); err != nil {
t.Fatalf("Execute call %d error: %v", i, err)
}
}
if len(userIDs) != 2 {
t.Fatalf("expected 2 requests, got %d", len(userIDs))
}
if userIDs[0] == "" || userIDs[1] == "" {
t.Fatal("expected user_id to be populated")
}
if userIDs[0] == userIDs[1] {
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
}
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
}
}
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
if got != "mcp__nia__manage_resource" {
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
}
}
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
// tool_result.content can be a string - should not be processed
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
if got != "plain string result" {
t.Fatalf("string content should remain unchanged = %q", got)
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "web_search" {
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
}
}

View File

@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
}
@@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -643,7 +642,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)

View File

@@ -363,7 +363,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
}
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
if ctx == nil {
ctx = context.Background()
@@ -436,7 +436,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
})
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
var upstreamHeaders http.Header
if respHS != nil {
upstreamHeaders = respHS.Header.Clone()
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
@@ -516,7 +518,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
terminateReason := "completed"
var terminateErr error
@@ -627,7 +628,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil
}
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
@@ -1343,7 +1344,7 @@ func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return e.httpExec.Execute(ctx, auth, req, opts)
}
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
if e == nil || e.httpExec == nil || e.wsExec == nil {
return nil, fmt.Errorf("codex auto executor: executor is nil")
}

View File

@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
// ExecuteStream performs a streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out)
defer func() {
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
if len(lastBody) > 0 {
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
lastStatus = resp.StatusCode
lastBody = append([]byte(nil), data...)

View File

@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the Gemini API.
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the Gemini API.
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
// Refresh refreshes the authentication credentials (no-op for Gemini API key).

View File

@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
// ExecuteStream performs a streaming request to the Vertex AI API.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
to := sdktranslator.FromString("gemini")
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// countTokensWithServiceAccount counts tokens using service account credentials.
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.

View File

@@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request.
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter.ensurePublished(ctx)
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {

View File

@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request to Kimi.
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
from := opts.SourceFormat
if from.String() == "claude" {
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens estimates token count for Kimi requests.

View File

@@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate response back to source format when needed
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
// Ensure we record the request if no usage chunk was ever seen
reporter.ensurePublished(ctx)
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
@@ -22,9 +23,151 @@ import (
)
const (
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
qwenRateLimitWindow = time.Minute // sliding window duration
)
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
var qwenBeijingLoc = func() *time.Location {
loc, err := time.LoadLocation("Asia/Shanghai")
if err != nil || loc == nil {
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
return time.FixedZone("CST", 8*3600)
}
return loc
}()
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
var qwenQuotaCodes = map[string]struct{}{
"insufficient_quota": {},
"quota_exceeded": {},
}
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
// Qwen has a limit of 60 requests per minute per account.
var qwenRateLimiter = struct {
sync.Mutex
requests map[string][]time.Time // authID -> request timestamps
}{
requests: make(map[string][]time.Time),
}
// redactAuthID returns a redacted version of the auth ID for safe logging.
// Keeps a small prefix/suffix to allow correlation across events.
func redactAuthID(id string) string {
if id == "" {
return ""
}
if len(id) <= 8 {
return id
}
return id[:4] + "..." + id[len(id)-4:]
}
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
func checkQwenRateLimit(authID string) error {
if authID == "" {
// Empty authID should not bypass rate limiting in production
// Use debug level to avoid log spam for certain auth flows
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
return nil
}
now := time.Now()
windowStart := now.Add(-qwenRateLimitWindow)
qwenRateLimiter.Lock()
defer qwenRateLimiter.Unlock()
// Get and filter timestamps within the window
timestamps := qwenRateLimiter.requests[authID]
var validTimestamps []time.Time
for _, ts := range timestamps {
if ts.After(windowStart) {
validTimestamps = append(validTimestamps, ts)
}
}
// Always prune expired entries to prevent memory leak
// Delete empty entries, otherwise update with pruned slice
if len(validTimestamps) == 0 {
delete(qwenRateLimiter.requests, authID)
}
// Check if rate limit exceeded
if len(validTimestamps) >= qwenRateLimitPerMin {
// Calculate when the oldest request will expire
oldestInWindow := validTimestamps[0]
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
if retryAfter < time.Second {
retryAfter = time.Second
}
retryAfterSec := int(retryAfter.Seconds())
return statusErr{
code: http.StatusTooManyRequests,
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
retryAfter: &retryAfter,
}
}
// Record this request and update the map with pruned timestamps
validTimestamps = append(validTimestamps, now)
qwenRateLimiter.requests[authID] = validTimestamps
return nil
}
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
func isQwenQuotaError(body []byte) bool {
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
// Primary check: exact match on error.code or error.type (most reliable)
if _, ok := qwenQuotaCodes[code]; ok {
return true
}
if _, ok := qwenQuotaCodes[errType]; ok {
return true
}
// Fallback: check message only if code/type don't match (less reliable)
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
strings.Contains(msg, "free allocated quota exceeded") {
return true
}
return false
}
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
// Returns the appropriate status code and retryAfter duration for statusErr.
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
errCode = httpCode
// Only check quota errors for expected status codes to avoid false positives
// Qwen returns 403 for quota errors, 429 for rate limits
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
cooldown := timeUntilNextDay()
retryAfter = &cooldown
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
}
return errCode, retryAfter
}
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
// Qwen's daily quota resets at 00:00 Beijing time.
func timeUntilNextDay() time.Duration {
now := time.Now()
nowLocal := now.In(qwenBeijingLoc)
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
return tomorrow.Sub(now)
}
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
// If access token is unavailable, it falls back to legacy via ClientAdapter.
type QwenExecutor struct {
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return resp, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
applyQwenHeaders(httpReq, token, false)
var authID, authLabel, authType, authValue string
var authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
@@ -150,14 +305,25 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return nil, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
applyQwenHeaders(httpReq, token, true)
var authID, authLabel, authType, authValue string
var authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
@@ -228,15 +393,16 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -268,7 +434,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {

View File

@@ -0,0 +1,89 @@
package executor
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
)
type userIDCacheEntry struct {
value string
expire time.Time
}
var (
userIDCache = make(map[string]userIDCacheEntry)
userIDCacheMu sync.RWMutex
userIDCacheCleanupOnce sync.Once
)
const (
userIDTTL = time.Hour
userIDCacheCleanupPeriod = 15 * time.Minute
)
func startUserIDCacheCleanup() {
go func() {
ticker := time.NewTicker(userIDCacheCleanupPeriod)
defer ticker.Stop()
for range ticker.C {
purgeExpiredUserIDs()
}
}()
}
func purgeExpiredUserIDs() {
now := time.Now()
userIDCacheMu.Lock()
for key, entry := range userIDCache {
if !entry.expire.After(now) {
delete(userIDCache, key)
}
}
userIDCacheMu.Unlock()
}
func userIDCacheKey(apiKey string) string {
sum := sha256.Sum256([]byte(apiKey))
return hex.EncodeToString(sum[:])
}
func cachedUserID(apiKey string) string {
if apiKey == "" {
return generateFakeUserID()
}
userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)
key := userIDCacheKey(apiKey)
now := time.Now()
userIDCacheMu.RLock()
entry, ok := userIDCache[key]
valid := ok && entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value)
userIDCacheMu.RUnlock()
if valid {
userIDCacheMu.Lock()
entry = userIDCache[key]
if entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) {
entry.expire = now.Add(userIDTTL)
userIDCache[key] = entry
userIDCacheMu.Unlock()
return entry.value
}
userIDCacheMu.Unlock()
}
newID := generateFakeUserID()
userIDCacheMu.Lock()
entry, ok = userIDCache[key]
if !ok || entry.value == "" || !entry.expire.After(now) || !isValidUserID(entry.value) {
entry.value = newID
}
entry.expire = now.Add(userIDTTL)
userIDCache[key] = entry
userIDCacheMu.Unlock()
return entry.value
}

View File

@@ -0,0 +1,86 @@
package executor
import (
"testing"
"time"
)
func resetUserIDCache() {
userIDCacheMu.Lock()
userIDCache = make(map[string]userIDCacheEntry)
userIDCacheMu.Unlock()
}
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
resetUserIDCache()
first := cachedUserID("api-key-1")
second := cachedUserID("api-key-1")
if first == "" {
t.Fatal("expected generated user_id to be non-empty")
}
if first != second {
t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
}
}
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
resetUserIDCache()
expiredID := cachedUserID("api-key-expired")
cacheKey := userIDCacheKey("api-key-expired")
userIDCacheMu.Lock()
userIDCache[cacheKey] = userIDCacheEntry{
value: expiredID,
expire: time.Now().Add(-time.Minute),
}
userIDCacheMu.Unlock()
newID := cachedUserID("api-key-expired")
if newID == expiredID {
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
}
if newID == "" {
t.Fatal("expected regenerated user_id to be non-empty")
}
}
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
resetUserIDCache()
first := cachedUserID("api-key-1")
second := cachedUserID("api-key-2")
if first == second {
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
}
}
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
resetUserIDCache()
key := "api-key-renew"
id := cachedUserID(key)
cacheKey := userIDCacheKey(key)
soon := time.Now()
userIDCacheMu.Lock()
userIDCache[cacheKey] = userIDCacheEntry{
value: id,
expire: soon.Add(2 * time.Second),
}
userIDCacheMu.Unlock()
if refreshed := cachedUserID(key); refreshed != id {
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
}
userIDCacheMu.RLock()
entry := userIDCache[cacheKey]
userIDCacheMu.RUnlock()
if entry.expire.Sub(soon) < 30*time.Minute {
t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon))
}
}

View File

@@ -231,8 +231,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if functionResponseResult.IsObject() {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
} else {
} else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
} else {
// Content field is missing entirely — .Raw is empty which
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
}
partJSON := `{}`

View File

@@ -661,6 +661,85 @@ func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
// Bug repro: tool_result with no content field produces invalid JSON
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "MyTool-123-456",
"name": "MyTool",
"input": {"key": "value"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "MyTool-123-456"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Errorf("Result is not valid JSON:\n%s", outputStr)
}
// Verify the functionResponse has a valid result value
fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
if !fr.Exists() {
t.Error("functionResponse.response.result should exist")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
// Bug repro: tool_result with null content produces invalid JSON
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "MyTool-123-456",
"name": "MyTool",
"input": {"key": "value"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "MyTool-123-456",
"content": null
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Errorf("Result is not valid JSON:\n%s", outputStr)
}
}
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
// When tools + thinking but no system instruction, should create one with hint
inputJSON := []byte(`{

View File

@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
}
}
}
return true
})

View File

@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var textAggregate strings.Builder
var partsJSON []string
hasImage := false
hasFile := false
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
ptype := part.Get("type").String()
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
hasImage = true
}
}
case "input_file":
fileData := part.Get("file_data").String()
if fileData != "" {
mediaType := "application/octet-stream"
data := fileData
if strings.HasPrefix(fileData, "data:") {
trimmed := strings.TrimPrefix(fileData, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mediaType = mediaAndData[0]
}
data = mediaAndData[1]
}
}
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
partsJSON = append(partsJSON, contentPart)
if role == "" {
role = "user"
}
hasFile = true
}
}
return true
})
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if len(partsJSON) > 0 {
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
if len(partsJSON) == 1 && !hasImage {
if len(partsJSON) == 1 && !hasImage && !hasFile {
// Preserve legacy behavior for single text content
msg, _ = sjson.Delete(msg, "content")
textPart := gjson.Parse(partsJSON[0])

View File

@@ -22,8 +22,9 @@ var (
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
type ConvertCodexResponseToClaudeParams struct {
HasToolCall bool
BlockIndex int
HasToolCall bool
BlockIndex int
HasReceivedArgumentsDelta bool
}
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
@@ -137,6 +138,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
@@ -171,12 +173,29 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output += fmt.Sprintf("data: %s\n\n", template)
}
} else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events.
// Emit the full arguments as a single input_json_delta so the
// downstream Claude client receives the complete tool input.
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", args)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
}
}
}
return []string{output}

View File

@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
case "file":
// Files are not specified in examples; skip for now
if role == "user" {
fileData := it.Get("file.file_data").String()
filename := it.Get("file.filename").String()
if fileData != "" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_file")
part, _ = sjson.Set(part, "file_data", fileData)
if filename != "" {
part, _ = sjson.Set(part, "filename", filename)
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
}
}
}
}

View File

@@ -20,10 +20,12 @@ var (
// ConvertCliToOpenAIParams holds parameters for response conversion.
type ConvertCliToOpenAIParams struct {
ResponseID string
CreatedAt int64
Model string
FunctionCallIndex int
ResponseID string
CreatedAt int64
Model string
FunctionCallIndex int
HasReceivedArgumentsDelta bool
HasToolCallAnnounced bool
}
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &ConvertCliToOpenAIParams{
Model: modelName,
CreatedAt: 0,
ResponseID: "",
FunctionCallIndex: -1,
Model: modelName,
CreatedAt: 0,
ResponseID: "",
FunctionCallIndex: -1,
HasReceivedArgumentsDelta: false,
HasToolCallAnnounced: false,
}
}
@@ -118,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
}
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
} else if dataType == "response.output_item.done" {
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
} else if dataType == "response.output_item.added" {
itemResult := rootResult.Get("item")
if itemResult.Exists() {
if itemResult.Get("type").String() != "function_call" {
return []string{}
}
// set the index
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened
name := itemResult.Get("name").String()
// Build reverse map on demand from original request tools
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
}
// Increment index for this new function call item.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.delta" {
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
deltaValue := rootResult.Get("delta").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.done" {
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
// Arguments were already streamed via delta events; nothing to emit.
return []string{}
}
// Fallback: no delta events were received, emit the full arguments as a single chunk.
fullArgs := rootResult.Get("arguments").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.output_item.done" {
itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
}
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
// Tool call was already announced via output_item.added; skip emission.
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
return []string{}
}
// Fallback path: model skipped output_item.added, so emit complete tool call now.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else {
return []string{}
}

View File

@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
// Delete the user field as it is not supported by the Codex upstream.
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
return rawJSON
}
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
// for Codex upstream compatibility.
//
// Codex /responses currently rejects context_management with:
// {"detail":"Unsupported parameter: context_management"}.
//
// Compatibility strategy:
// 1) Remove context_management before forwarding to Codex upstream.
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
return rawJSON
}
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
return rawJSON
}
// convertSystemRoleToDeveloper traverses the input array and converts any message items
// with role "system" to role "developer". This is necessary because Codex API does not
// accept "system" role in the input array.

View File

@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
}
}
func TestContextManagementCompactionCompatibility(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"context_management": [
{
"type": "compaction",
"compact_threshold": 12000
}
],
"input": [{"role":"user","content":"hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
if gjson.Get(outputStr, "context_management").Exists() {
t.Fatalf("context_management should be removed for Codex compatibility")
}
if gjson.Get(outputStr, "truncation").Exists() {
t.Fatalf("truncation should be removed for Codex compatibility")
}
}
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"truncation": "disabled",
"input": [{"role":"user","content":"hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
if gjson.Get(outputStr, "truncation").Exists() {
t.Fatalf("truncation should be removed for Codex compatibility")
}
}

View File

@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())

542
internal/tui/app.go Normal file
View File

@@ -0,0 +1,542 @@
package tui
import (
"fmt"
"io"
"os"
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// Tab identifiers
const (
tabDashboard = iota
tabConfig
tabAuthFiles
tabAPIKeys
tabOAuth
tabUsage
tabLogs
)
// App is the root bubbletea model that contains all tab sub-models.
type App struct {
activeTab int
tabs []string
standalone bool
logsEnabled bool
authenticated bool
authInput textinput.Model
authError string
authConnecting bool
dashboard dashboardModel
config configTabModel
auth authTabModel
keys keysTabModel
oauth oauthTabModel
usage usageTabModel
logs logsTabModel
client *Client
width int
height int
ready bool
// Track which tabs have been initialized (fetched data)
initialized [7]bool
}
type authConnectMsg struct {
cfg map[string]any
err error
}
// NewApp creates the root TUI application model.
func NewApp(port int, secretKey string, hook *LogHook) App {
standalone := hook != nil
authRequired := !standalone
ti := textinput.New()
ti.CharLimit = 512
ti.EchoMode = textinput.EchoPassword
ti.EchoCharacter = '*'
ti.SetValue(strings.TrimSpace(secretKey))
ti.Focus()
client := NewClient(port, secretKey)
app := App{
activeTab: tabDashboard,
standalone: standalone,
logsEnabled: true,
authenticated: !authRequired,
authInput: ti,
dashboard: newDashboardModel(client),
config: newConfigTabModel(client),
auth: newAuthTabModel(client),
keys: newKeysTabModel(client),
oauth: newOAuthTabModel(client),
usage: newUsageTabModel(client),
logs: newLogsTabModel(client, hook),
client: client,
initialized: [7]bool{
tabDashboard: true,
tabLogs: true,
},
}
app.refreshTabs()
if authRequired {
app.initialized = [7]bool{}
}
app.setAuthInputPrompt()
return app
}
func (a App) Init() tea.Cmd {
if !a.authenticated {
return textinput.Blink
}
cmds := []tea.Cmd{a.dashboard.Init()}
if a.logsEnabled {
cmds = append(cmds, a.logs.Init())
}
return tea.Batch(cmds...)
}
func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
a.width = msg.Width
a.height = msg.Height
a.ready = true
if a.width > 0 {
a.authInput.Width = a.width - 6
}
contentH := a.height - 4 // tab bar + status bar
if contentH < 1 {
contentH = 1
}
contentW := a.width
a.dashboard.SetSize(contentW, contentH)
a.config.SetSize(contentW, contentH)
a.auth.SetSize(contentW, contentH)
a.keys.SetSize(contentW, contentH)
a.oauth.SetSize(contentW, contentH)
a.usage.SetSize(contentW, contentH)
a.logs.SetSize(contentW, contentH)
return a, nil
case authConnectMsg:
a.authConnecting = false
if msg.err != nil {
a.authError = fmt.Sprintf(T("auth_gate_connect_fail"), msg.err.Error())
return a, nil
}
a.authError = ""
a.authenticated = true
a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
a.refreshTabs()
a.initialized = [7]bool{}
a.initialized[tabDashboard] = true
cmds := []tea.Cmd{a.dashboard.Init()}
if a.logsEnabled {
a.initialized[tabLogs] = true
cmds = append(cmds, a.logs.Init())
}
return a, tea.Batch(cmds...)
case configUpdateMsg:
var cmdLogs tea.Cmd
if !a.standalone && msg.err == nil && msg.path == "logging-to-file" {
logsEnabledConfig, okConfig := msg.value.(bool)
if okConfig {
logsEnabledBefore := a.logsEnabled
a.logsEnabled = logsEnabledConfig
if logsEnabledBefore != a.logsEnabled {
a.refreshTabs()
}
if !a.logsEnabled {
a.initialized[tabLogs] = false
}
if !logsEnabledBefore && a.logsEnabled {
a.initialized[tabLogs] = true
cmdLogs = a.logs.Init()
}
}
}
var cmdConfig tea.Cmd
a.config, cmdConfig = a.config.Update(msg)
if cmdConfig != nil && cmdLogs != nil {
return a, tea.Batch(cmdConfig, cmdLogs)
}
if cmdConfig != nil {
return a, cmdConfig
}
return a, cmdLogs
case tea.KeyMsg:
if !a.authenticated {
switch msg.String() {
case "ctrl+c", "q":
return a, tea.Quit
case "L":
ToggleLocale()
a.refreshTabs()
a.setAuthInputPrompt()
return a, nil
case "enter":
if a.authConnecting {
return a, nil
}
password := strings.TrimSpace(a.authInput.Value())
if password == "" {
a.authError = T("auth_gate_password_required")
return a, nil
}
a.authError = ""
a.authConnecting = true
return a, a.connectWithPassword(password)
default:
var cmd tea.Cmd
a.authInput, cmd = a.authInput.Update(msg)
return a, cmd
}
}
switch msg.String() {
case "ctrl+c":
return a, tea.Quit
case "q":
// Only quit if not in logs tab (where 'q' might be useful)
if !a.logsEnabled || a.activeTab != tabLogs {
return a, tea.Quit
}
case "L":
ToggleLocale()
a.refreshTabs()
return a.broadcastToAllTabs(localeChangedMsg{})
case "tab":
if len(a.tabs) == 0 {
return a, nil
}
prevTab := a.activeTab
a.activeTab = (a.activeTab + 1) % len(a.tabs)
return a, a.initTabIfNeeded(prevTab)
case "shift+tab":
if len(a.tabs) == 0 {
return a, nil
}
prevTab := a.activeTab
a.activeTab = (a.activeTab - 1 + len(a.tabs)) % len(a.tabs)
return a, a.initTabIfNeeded(prevTab)
}
}
if !a.authenticated {
var cmd tea.Cmd
a.authInput, cmd = a.authInput.Update(msg)
return a, cmd
}
// Route msg to active tab
var cmd tea.Cmd
switch a.activeTab {
case tabDashboard:
a.dashboard, cmd = a.dashboard.Update(msg)
case tabConfig:
a.config, cmd = a.config.Update(msg)
case tabAuthFiles:
a.auth, cmd = a.auth.Update(msg)
case tabAPIKeys:
a.keys, cmd = a.keys.Update(msg)
case tabOAuth:
a.oauth, cmd = a.oauth.Update(msg)
case tabUsage:
a.usage, cmd = a.usage.Update(msg)
case tabLogs:
a.logs, cmd = a.logs.Update(msg)
}
// Keep logs polling alive even when logs tab is not active.
if a.logsEnabled && a.activeTab != tabLogs {
switch msg.(type) {
case logsPollMsg, logsTickMsg, logLineMsg:
var logCmd tea.Cmd
a.logs, logCmd = a.logs.Update(msg)
if logCmd != nil {
cmd = logCmd
}
}
}
return a, cmd
}
// localeChangedMsg is broadcast to all tabs when the user toggles locale.
type localeChangedMsg struct{}
func (a *App) refreshTabs() {
names := TabNames()
if a.logsEnabled {
a.tabs = names
} else {
filtered := make([]string, 0, len(names)-1)
for idx, name := range names {
if idx == tabLogs {
continue
}
filtered = append(filtered, name)
}
a.tabs = filtered
}
if len(a.tabs) == 0 {
a.activeTab = tabDashboard
return
}
if a.activeTab >= len(a.tabs) {
a.activeTab = len(a.tabs) - 1
}
}
func (a *App) initTabIfNeeded(_ int) tea.Cmd {
if a.initialized[a.activeTab] {
return nil
}
a.initialized[a.activeTab] = true
switch a.activeTab {
case tabDashboard:
return a.dashboard.Init()
case tabConfig:
return a.config.Init()
case tabAuthFiles:
return a.auth.Init()
case tabAPIKeys:
return a.keys.Init()
case tabOAuth:
return a.oauth.Init()
case tabUsage:
return a.usage.Init()
case tabLogs:
if !a.logsEnabled {
return nil
}
return a.logs.Init()
}
return nil
}
func (a App) View() string {
if !a.authenticated {
return a.renderAuthView()
}
if !a.ready {
return T("initializing_tui")
}
var sb strings.Builder
// Tab bar
sb.WriteString(a.renderTabBar())
sb.WriteString("\n")
// Content
switch a.activeTab {
case tabDashboard:
sb.WriteString(a.dashboard.View())
case tabConfig:
sb.WriteString(a.config.View())
case tabAuthFiles:
sb.WriteString(a.auth.View())
case tabAPIKeys:
sb.WriteString(a.keys.View())
case tabOAuth:
sb.WriteString(a.oauth.View())
case tabUsage:
sb.WriteString(a.usage.View())
case tabLogs:
if a.logsEnabled {
sb.WriteString(a.logs.View())
}
}
// Status bar
sb.WriteString("\n")
sb.WriteString(a.renderStatusBar())
return sb.String()
}
func (a App) renderAuthView() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("auth_gate_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("auth_gate_help")))
sb.WriteString("\n\n")
if a.authConnecting {
sb.WriteString(warningStyle.Render(T("auth_gate_connecting")))
sb.WriteString("\n\n")
}
if strings.TrimSpace(a.authError) != "" {
sb.WriteString(errorStyle.Render(a.authError))
sb.WriteString("\n\n")
}
sb.WriteString(a.authInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("auth_gate_enter")))
return sb.String()
}
func (a App) renderTabBar() string {
var tabs []string
for i, name := range a.tabs {
if i == a.activeTab {
tabs = append(tabs, tabActiveStyle.Render(name))
} else {
tabs = append(tabs, tabInactiveStyle.Render(name))
}
}
tabBar := lipgloss.JoinHorizontal(lipgloss.Top, tabs...)
return tabBarStyle.Width(a.width).Render(tabBar)
}
func (a App) renderStatusBar() string {
left := strings.TrimRight(T("status_left"), " ")
right := strings.TrimRight(T("status_right"), " ")
width := a.width
if width < 1 {
width = 1
}
// statusBarStyle has left/right padding(1), so content area is width-2.
contentWidth := width - 2
if contentWidth < 0 {
contentWidth = 0
}
if lipgloss.Width(left) > contentWidth {
left = fitStringWidth(left, contentWidth)
right = ""
}
remaining := contentWidth - lipgloss.Width(left)
if remaining < 0 {
remaining = 0
}
if lipgloss.Width(right) > remaining {
right = fitStringWidth(right, remaining)
}
gap := contentWidth - lipgloss.Width(left) - lipgloss.Width(right)
if gap < 0 {
gap = 0
}
return statusBarStyle.Width(width).Render(left + strings.Repeat(" ", gap) + right)
}
func fitStringWidth(text string, maxWidth int) string {
if maxWidth <= 0 {
return ""
}
if lipgloss.Width(text) <= maxWidth {
return text
}
out := ""
for _, r := range text {
next := out + string(r)
if lipgloss.Width(next) > maxWidth {
break
}
out = next
}
return out
}
func isLogsEnabledFromConfig(cfg map[string]any) bool {
if cfg == nil {
return true
}
value, ok := cfg["logging-to-file"]
if !ok {
return true
}
enabled, ok := value.(bool)
if !ok {
return true
}
return enabled
}
func (a *App) setAuthInputPrompt() {
if a == nil {
return
}
a.authInput.Prompt = fmt.Sprintf(" %s: ", T("auth_gate_password"))
}
func (a App) connectWithPassword(password string) tea.Cmd {
return func() tea.Msg {
a.client.SetSecretKey(password)
cfg, errGetConfig := a.client.GetConfig()
return authConnectMsg{cfg: cfg, err: errGetConfig}
}
}
// Run starts the TUI application.
// output specifies where bubbletea renders. If nil, defaults to os.Stdout.
func Run(port int, secretKey string, hook *LogHook, output io.Writer) error {
if output == nil {
output = os.Stdout
}
app := NewApp(port, secretKey, hook)
p := tea.NewProgram(app, tea.WithAltScreen(), tea.WithOutput(output))
_, err := p.Run()
return err
}
func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
var cmd tea.Cmd
a.dashboard, cmd = a.dashboard.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.config, cmd = a.config.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.auth, cmd = a.auth.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.keys, cmd = a.keys.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.oauth, cmd = a.oauth.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.usage, cmd = a.usage.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.logs, cmd = a.logs.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
return a, tea.Batch(cmds...)
}

456
internal/tui/auth_tab.go Normal file
View File

@@ -0,0 +1,456 @@
package tui
import (
"fmt"
"strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// editableField represents an editable field on an auth file.
type editableField struct {
label string
key string // API field key: "prefix", "proxy_url", "priority"
}
var authEditableFields = []editableField{
{label: "Prefix", key: "prefix"},
{label: "Proxy URL", key: "proxy_url"},
{label: "Priority", key: "priority"},
}
// authTabModel displays auth credential files with interactive management.
type authTabModel struct {
client *Client
viewport viewport.Model
files []map[string]any
err error
width int
height int
ready bool
cursor int
expanded int // -1 = none expanded, >=0 = expanded index
confirm int // -1 = no confirmation, >=0 = confirm delete for index
status string
// Editing state
editing bool // true when editing a field
editField int // index into authEditableFields
editInput textinput.Model // text input for editing
editFileName string // name of file being edited
}
type authFilesMsg struct {
files []map[string]any
err error
}
type authActionMsg struct {
action string // "deleted", "toggled", "updated"
err error
}
func newAuthTabModel(client *Client) authTabModel {
ti := textinput.New()
ti.CharLimit = 256
return authTabModel{
client: client,
expanded: -1,
confirm: -1,
editInput: ti,
}
}
func (m authTabModel) Init() tea.Cmd {
return m.fetchFiles
}
func (m authTabModel) fetchFiles() tea.Msg {
files, err := m.client.GetAuthFiles()
return authFilesMsg{files: files, err: err}
}
func (m authTabModel) Update(msg tea.Msg) (authTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case authFilesMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.files = msg.files
if m.cursor >= len(m.files) {
m.cursor = max(0, len(m.files)-1)
}
m.status = ""
}
m.viewport.SetContent(m.renderContent())
return m, nil
case authActionMsg:
if msg.err != nil {
m.status = errorStyle.Render("✗ " + msg.err.Error())
} else {
m.status = successStyle.Render("✓ " + msg.action)
}
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, m.fetchFiles
case tea.KeyMsg:
// ---- Editing mode ----
if m.editing {
return m.handleEditInput(msg)
}
// ---- Delete confirmation mode ----
if m.confirm >= 0 {
return m.handleConfirmInput(msg)
}
// ---- Normal mode ----
return m.handleNormalInput(msg)
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
// startEdit activates inline editing for a field on the currently selected auth file.
func (m *authTabModel) startEdit(fieldIdx int) tea.Cmd {
if m.cursor >= len(m.files) {
return nil
}
f := m.files[m.cursor]
m.editFileName = getString(f, "name")
m.editField = fieldIdx
m.editing = true
// Pre-populate with current value
key := authEditableFields[fieldIdx].key
currentVal := getAnyString(f, key)
m.editInput.SetValue(currentVal)
m.editInput.Focus()
m.editInput.Prompt = fmt.Sprintf(" %s: ", authEditableFields[fieldIdx].label)
m.viewport.SetContent(m.renderContent())
return textinput.Blink
}
func (m *authTabModel) SetSize(w, h int) {
m.width = w
m.height = h
m.editInput.Width = w - 20
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m authTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m authTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("auth_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("auth_help1")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("auth_help2")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", m.width))
sb.WriteString("\n")
if m.err != nil {
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
sb.WriteString("\n")
return sb.String()
}
if len(m.files) == 0 {
sb.WriteString(subtitleStyle.Render(T("no_auth_files")))
sb.WriteString("\n")
return sb.String()
}
for i, f := range m.files {
name := getString(f, "name")
channel := getString(f, "channel")
email := getString(f, "email")
disabled := getBool(f, "disabled")
statusIcon := successStyle.Render("●")
statusText := T("status_active")
if disabled {
statusIcon = lipgloss.NewStyle().Foreground(colorMuted).Render("○")
statusText = T("status_disabled")
}
cursor := " "
rowStyle := lipgloss.NewStyle()
if i == m.cursor {
cursor = "▸ "
rowStyle = lipgloss.NewStyle().Bold(true)
}
displayName := name
if len(displayName) > 24 {
displayName = displayName[:21] + "..."
}
displayEmail := email
if len(displayEmail) > 28 {
displayEmail = displayEmail[:25] + "..."
}
row := fmt.Sprintf("%s%s %-24s %-12s %-28s %s",
cursor, statusIcon, displayName, channel, displayEmail, statusText)
sb.WriteString(rowStyle.Render(row))
sb.WriteString("\n")
// Delete confirmation
if m.confirm == i {
sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete"), name)))
sb.WriteString("\n")
}
// Inline edit input
if m.editing && i == m.cursor {
sb.WriteString(m.editInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(" " + T("enter_save") + " • " + T("esc_cancel")))
sb.WriteString("\n")
}
// Expanded detail view
if m.expanded == i {
sb.WriteString(m.renderDetail(f))
}
}
if m.status != "" {
sb.WriteString("\n")
sb.WriteString(m.status)
sb.WriteString("\n")
}
return sb.String()
}
func (m authTabModel) renderDetail(f map[string]any) string {
var sb strings.Builder
labelStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("111")).
Bold(true)
valueStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("252"))
editableMarker := lipgloss.NewStyle().
Foreground(lipgloss.Color("214")).
Render(" ✎")
sb.WriteString(" ┌─────────────────────────────────────────────\n")
fields := []struct {
label string
key string
editable bool
}{
{"Name", "name", false},
{"Channel", "channel", false},
{"Email", "email", false},
{"Status", "status", false},
{"Status Msg", "status_message", false},
{"File Name", "file_name", false},
{"Auth Type", "auth_type", false},
{"Prefix", "prefix", true},
{"Proxy URL", "proxy_url", true},
{"Priority", "priority", true},
{"Project ID", "project_id", false},
{"Disabled", "disabled", false},
{"Created", "created_at", false},
{"Updated", "updated_at", false},
}
for _, field := range fields {
val := getAnyString(f, field.key)
if val == "" || val == "<nil>" {
if field.editable {
val = T("not_set")
} else {
continue
}
}
editMark := ""
if field.editable {
editMark = editableMarker
}
line := fmt.Sprintf(" │ %s %s%s",
labelStyle.Render(fmt.Sprintf("%-12s:", field.label)),
valueStyle.Render(val),
editMark)
sb.WriteString(line)
sb.WriteString("\n")
}
sb.WriteString(" └─────────────────────────────────────────────\n")
return sb.String()
}
// getAnyString converts any value to its string representation.
func getAnyString(m map[string]any, key string) string {
v, ok := m[key]
if !ok || v == nil {
return ""
}
return fmt.Sprintf("%v", v)
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func (m authTabModel) handleEditInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
switch msg.String() {
case "enter":
value := m.editInput.Value()
fieldKey := authEditableFields[m.editField].key
fileName := m.editFileName
m.editing = false
m.editInput.Blur()
fields := map[string]any{}
if fieldKey == "priority" {
p, err := strconv.Atoi(value)
if err != nil {
return m, func() tea.Msg {
return authActionMsg{err: fmt.Errorf("%s: %s", T("invalid_int"), value)}
}
}
fields[fieldKey] = p
} else {
fields[fieldKey] = value
}
return m, func() tea.Msg {
err := m.client.PatchAuthFileFields(fileName, fields)
if err != nil {
return authActionMsg{err: err}
}
return authActionMsg{action: fmt.Sprintf(T("updated_field"), fieldKey, fileName)}
}
case "esc":
m.editing = false
m.editInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.editInput, cmd = m.editInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
func (m authTabModel) handleConfirmInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
switch msg.String() {
case "y", "Y":
idx := m.confirm
m.confirm = -1
if idx < len(m.files) {
name := getString(m.files[idx], "name")
return m, func() tea.Msg {
err := m.client.DeleteAuthFile(name)
if err != nil {
return authActionMsg{err: err}
}
return authActionMsg{action: fmt.Sprintf(T("deleted"), name)}
}
}
m.viewport.SetContent(m.renderContent())
return m, nil
case "n", "N", "esc":
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, nil
}
return m, nil
}
func (m authTabModel) handleNormalInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
switch msg.String() {
case "j", "down":
if len(m.files) > 0 {
m.cursor = (m.cursor + 1) % len(m.files)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "k", "up":
if len(m.files) > 0 {
m.cursor = (m.cursor - 1 + len(m.files)) % len(m.files)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "enter", " ":
if m.expanded == m.cursor {
m.expanded = -1
} else {
m.expanded = m.cursor
}
m.viewport.SetContent(m.renderContent())
return m, nil
case "d", "D":
if m.cursor < len(m.files) {
m.confirm = m.cursor
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "e", "E":
if m.cursor < len(m.files) {
f := m.files[m.cursor]
name := getString(f, "name")
disabled := getBool(f, "disabled")
newDisabled := !disabled
return m, func() tea.Msg {
err := m.client.ToggleAuthFile(name, newDisabled)
if err != nil {
return authActionMsg{err: err}
}
action := T("enabled")
if newDisabled {
action = T("disabled")
}
return authActionMsg{action: fmt.Sprintf("%s %s", action, name)}
}
}
return m, nil
case "1":
return m, m.startEdit(0) // prefix
case "2":
return m, m.startEdit(1) // proxy_url
case "3":
return m, m.startEdit(2) // priority
case "r":
m.status = ""
return m, m.fetchFiles
default:
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
}

20
internal/tui/browser.go Normal file
View File

@@ -0,0 +1,20 @@
package tui
import (
"os/exec"
"runtime"
)
// openBrowser opens the specified URL in the user's default browser.
func openBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
default:
return exec.Command("xdg-open", url).Start()
}
}

400
internal/tui/client.go Normal file
View File

@@ -0,0 +1,400 @@
package tui
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// Client wraps HTTP calls to the management API.
type Client struct {
baseURL string
secretKey string
http *http.Client
}
// NewClient creates a new management API client.
func NewClient(port int, secretKey string) *Client {
return &Client{
baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
secretKey: strings.TrimSpace(secretKey),
http: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// SetSecretKey updates management API bearer token used by this client.
func (c *Client) SetSecretKey(secretKey string) {
c.secretKey = strings.TrimSpace(secretKey)
}
func (c *Client) doRequest(method, path string, body io.Reader) ([]byte, int, error) {
url := c.baseURL + path
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, 0, err
}
if c.secretKey != "" {
req.Header.Set("Authorization", "Bearer "+c.secretKey)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.http.Do(req)
if err != nil {
return nil, 0, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, resp.StatusCode, err
}
return data, resp.StatusCode, nil
}
func (c *Client) get(path string) ([]byte, error) {
data, code, err := c.doRequest("GET", path, nil)
if err != nil {
return nil, err
}
if code >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
}
return data, nil
}
func (c *Client) put(path string, body io.Reader) ([]byte, error) {
data, code, err := c.doRequest("PUT", path, body)
if err != nil {
return nil, err
}
if code >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
}
return data, nil
}
func (c *Client) patch(path string, body io.Reader) ([]byte, error) {
data, code, err := c.doRequest("PATCH", path, body)
if err != nil {
return nil, err
}
if code >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
}
return data, nil
}
// getJSON fetches a path and unmarshals JSON into a generic map.
func (c *Client) getJSON(path string) (map[string]any, error) {
data, err := c.get(path)
if err != nil {
return nil, err
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
// postJSON sends a JSON body via POST and checks for errors.
func (c *Client) postJSON(path string, body any) error {
jsonBody, err := json.Marshal(body)
if err != nil {
return err
}
_, code, err := c.doRequest("POST", path, strings.NewReader(string(jsonBody)))
if err != nil {
return err
}
if code >= 400 {
return fmt.Errorf("HTTP %d", code)
}
return nil
}
// GetConfig fetches the parsed config.
func (c *Client) GetConfig() (map[string]any, error) {
return c.getJSON("/v0/management/config")
}
// GetConfigYAML fetches the raw config.yaml content.
func (c *Client) GetConfigYAML() (string, error) {
data, err := c.get("/v0/management/config.yaml")
if err != nil {
return "", err
}
return string(data), nil
}
// PutConfigYAML uploads new config.yaml content.
func (c *Client) PutConfigYAML(yamlContent string) error {
_, err := c.put("/v0/management/config.yaml", strings.NewReader(yamlContent))
return err
}
// GetUsage fetches usage statistics.
func (c *Client) GetUsage() (map[string]any, error) {
return c.getJSON("/v0/management/usage")
}
// GetAuthFiles lists auth credential files.
// API returns {"files": [...]}.
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
wrapper, err := c.getJSON("/v0/management/auth-files")
if err != nil {
return nil, err
}
return extractList(wrapper, "files")
}
// DeleteAuthFile deletes a single auth file by name.
func (c *Client) DeleteAuthFile(name string) error {
query := url.Values{}
query.Set("name", name)
path := "/v0/management/auth-files?" + query.Encode()
_, code, err := c.doRequest("DELETE", path, nil)
if err != nil {
return err
}
if code >= 400 {
return fmt.Errorf("delete failed (HTTP %d)", code)
}
return nil
}
// ToggleAuthFile enables or disables an auth file.
func (c *Client) ToggleAuthFile(name string, disabled bool) error {
body, _ := json.Marshal(map[string]any{"name": name, "disabled": disabled})
_, err := c.patch("/v0/management/auth-files/status", strings.NewReader(string(body)))
return err
}
// PatchAuthFileFields updates editable fields on an auth file.
func (c *Client) PatchAuthFileFields(name string, fields map[string]any) error {
fields["name"] = name
body, _ := json.Marshal(fields)
_, err := c.patch("/v0/management/auth-files/fields", strings.NewReader(string(body)))
return err
}
// GetLogs fetches log lines from the server.
func (c *Client) GetLogs(after int64, limit int) ([]string, int64, error) {
query := url.Values{}
if limit > 0 {
query.Set("limit", strconv.Itoa(limit))
}
if after > 0 {
query.Set("after", strconv.FormatInt(after, 10))
}
path := "/v0/management/logs"
encodedQuery := query.Encode()
if encodedQuery != "" {
path += "?" + encodedQuery
}
wrapper, err := c.getJSON(path)
if err != nil {
return nil, after, err
}
lines := []string{}
if rawLines, ok := wrapper["lines"]; ok && rawLines != nil {
rawJSON, errMarshal := json.Marshal(rawLines)
if errMarshal != nil {
return nil, after, errMarshal
}
if errUnmarshal := json.Unmarshal(rawJSON, &lines); errUnmarshal != nil {
return nil, after, errUnmarshal
}
}
latest := after
if rawLatest, ok := wrapper["latest-timestamp"]; ok {
switch value := rawLatest.(type) {
case float64:
latest = int64(value)
case json.Number:
if parsed, errParse := value.Int64(); errParse == nil {
latest = parsed
}
case int64:
latest = value
case int:
latest = int64(value)
}
}
if latest < after {
latest = after
}
return lines, latest, nil
}
// GetAPIKeys fetches the list of API keys.
// API returns {"api-keys": [...]}.
func (c *Client) GetAPIKeys() ([]string, error) {
wrapper, err := c.getJSON("/v0/management/api-keys")
if err != nil {
return nil, err
}
arr, ok := wrapper["api-keys"]
if !ok {
return nil, nil
}
raw, err := json.Marshal(arr)
if err != nil {
return nil, err
}
var result []string
if err := json.Unmarshal(raw, &result); err != nil {
return nil, err
}
return result, nil
}
// AddAPIKey adds a new API key by sending old=nil, new=key which appends.
func (c *Client) AddAPIKey(key string) error {
body := map[string]any{"old": nil, "new": key}
jsonBody, _ := json.Marshal(body)
_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
return err
}
// EditAPIKey replaces an API key at the given index.
func (c *Client) EditAPIKey(index int, newValue string) error {
body := map[string]any{"index": index, "value": newValue}
jsonBody, _ := json.Marshal(body)
_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
return err
}
// DeleteAPIKey deletes an API key by index.
func (c *Client) DeleteAPIKey(index int) error {
_, code, err := c.doRequest("DELETE", fmt.Sprintf("/v0/management/api-keys?index=%d", index), nil)
if err != nil {
return err
}
if code >= 400 {
return fmt.Errorf("delete failed (HTTP %d)", code)
}
return nil
}
// GetGeminiKeys fetches Gemini API keys.
// API returns {"gemini-api-key": [...]}.
func (c *Client) GetGeminiKeys() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/gemini-api-key", "gemini-api-key")
}
// GetClaudeKeys fetches Claude API keys.
func (c *Client) GetClaudeKeys() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/claude-api-key", "claude-api-key")
}
// GetCodexKeys fetches Codex API keys.
func (c *Client) GetCodexKeys() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/codex-api-key", "codex-api-key")
}
// GetVertexKeys fetches Vertex API keys.
func (c *Client) GetVertexKeys() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/vertex-api-key", "vertex-api-key")
}
// GetOpenAICompat fetches OpenAI compatibility entries.
func (c *Client) GetOpenAICompat() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/openai-compatibility", "openai-compatibility")
}
// getWrappedKeyList fetches a wrapped list from the API.
func (c *Client) getWrappedKeyList(path, key string) ([]map[string]any, error) {
wrapper, err := c.getJSON(path)
if err != nil {
return nil, err
}
return extractList(wrapper, key)
}
// extractList pulls an array of maps from a wrapper object by key.
func extractList(wrapper map[string]any, key string) ([]map[string]any, error) {
arr, ok := wrapper[key]
if !ok || arr == nil {
return nil, nil
}
raw, err := json.Marshal(arr)
if err != nil {
return nil, err
}
var result []map[string]any
if err := json.Unmarshal(raw, &result); err != nil {
return nil, err
}
return result, nil
}
// GetDebug fetches the current debug setting.
func (c *Client) GetDebug() (bool, error) {
wrapper, err := c.getJSON("/v0/management/debug")
if err != nil {
return false, err
}
if v, ok := wrapper["debug"]; ok {
if b, ok := v.(bool); ok {
return b, nil
}
}
return false, nil
}
// GetAuthStatus polls the OAuth session status.
// Returns status ("wait", "ok", "error") and optional error message.
func (c *Client) GetAuthStatus(state string) (string, string, error) {
query := url.Values{}
query.Set("state", state)
path := "/v0/management/get-auth-status?" + query.Encode()
wrapper, err := c.getJSON(path)
if err != nil {
return "", "", err
}
status := getString(wrapper, "status")
errMsg := getString(wrapper, "error")
return status, errMsg, nil
}
// ----- Config field update methods -----
// PutBoolField updates a boolean config field.
func (c *Client) PutBoolField(path string, value bool) error {
body, _ := json.Marshal(map[string]any{"value": value})
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
return err
}
// PutIntField updates an integer config field.
func (c *Client) PutIntField(path string, value int) error {
body, _ := json.Marshal(map[string]any{"value": value})
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
return err
}
// PutStringField updates a string config field.
func (c *Client) PutStringField(path string, value string) error {
body, _ := json.Marshal(map[string]any{"value": value})
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
return err
}
// DeleteField sends a DELETE request for a config field.
func (c *Client) DeleteField(path string) error {
_, _, err := c.doRequest("DELETE", "/v0/management/"+path, nil)
return err
}

413
internal/tui/config_tab.go Normal file
View File

@@ -0,0 +1,413 @@
package tui
import (
"fmt"
"strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// configField represents a single editable config field.
type configField struct {
label string
apiPath string // management API path (e.g. "debug", "proxy-url")
kind string // "bool", "int", "string", "readonly"
value string // current display value
rawValue any // raw value from API
}
// configTabModel displays parsed config with interactive editing.
type configTabModel struct {
client *Client
viewport viewport.Model
fields []configField
cursor int
editing bool
textInput textinput.Model
err error
message string // status message (success/error)
width int
height int
ready bool
}
type configDataMsg struct {
config map[string]any
err error
}
type configUpdateMsg struct {
path string
value any
err error
}
func newConfigTabModel(client *Client) configTabModel {
ti := textinput.New()
ti.CharLimit = 256
return configTabModel{
client: client,
textInput: ti,
}
}
func (m configTabModel) Init() tea.Cmd {
return m.fetchConfig
}
func (m configTabModel) fetchConfig() tea.Msg {
cfg, err := m.client.GetConfig()
return configDataMsg{config: cfg, err: err}
}
func (m configTabModel) Update(msg tea.Msg) (configTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case configDataMsg:
if msg.err != nil {
m.err = msg.err
m.fields = nil
} else {
m.err = nil
m.fields = m.parseConfig(msg.config)
}
m.viewport.SetContent(m.renderContent())
return m, nil
case configUpdateMsg:
if msg.err != nil {
m.message = errorStyle.Render("✗ " + msg.err.Error())
} else {
m.message = successStyle.Render(T("updated_ok"))
}
m.viewport.SetContent(m.renderContent())
// Refresh config from server
return m, m.fetchConfig
case tea.KeyMsg:
if m.editing {
return m.handleEditingKey(msg)
}
return m.handleNormalKey(msg)
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m configTabModel) handleNormalKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
switch msg.String() {
case "r":
m.message = ""
return m, m.fetchConfig
case "up", "k":
if m.cursor > 0 {
m.cursor--
m.viewport.SetContent(m.renderContent())
// Ensure cursor is visible
m.ensureCursorVisible()
}
return m, nil
case "down", "j":
if m.cursor < len(m.fields)-1 {
m.cursor++
m.viewport.SetContent(m.renderContent())
m.ensureCursorVisible()
}
return m, nil
case "enter", " ":
if m.cursor >= 0 && m.cursor < len(m.fields) {
f := m.fields[m.cursor]
if f.kind == "readonly" {
return m, nil
}
if f.kind == "bool" {
// Toggle directly
return m, m.toggleBool(m.cursor)
}
// Start editing for int/string
m.editing = true
m.textInput.SetValue(configFieldEditValue(f))
m.textInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
}
return m, nil
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m configTabModel) handleEditingKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
switch msg.String() {
case "enter":
m.editing = false
m.textInput.Blur()
return m, m.submitEdit(m.cursor, m.textInput.Value())
case "esc":
m.editing = false
m.textInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.textInput, cmd = m.textInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
func (m configTabModel) toggleBool(idx int) tea.Cmd {
return func() tea.Msg {
f := m.fields[idx]
current := f.value == "true"
newValue := !current
errPutBool := m.client.PutBoolField(f.apiPath, newValue)
return configUpdateMsg{
path: f.apiPath,
value: newValue,
err: errPutBool,
}
}
}
func (m configTabModel) submitEdit(idx int, newValue string) tea.Cmd {
return func() tea.Msg {
f := m.fields[idx]
var err error
var value any
switch f.kind {
case "int":
valueInt, errAtoi := strconv.Atoi(newValue)
if errAtoi != nil {
return configUpdateMsg{
path: f.apiPath,
err: fmt.Errorf("%s: %s", T("invalid_int"), newValue),
}
}
value = valueInt
err = m.client.PutIntField(f.apiPath, valueInt)
case "string":
value = newValue
err = m.client.PutStringField(f.apiPath, newValue)
}
return configUpdateMsg{
path: f.apiPath,
value: value,
err: err,
}
}
}
func configFieldEditValue(f configField) string {
if rawString, ok := f.rawValue.(string); ok {
return rawString
}
return f.value
}
func (m *configTabModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m *configTabModel) ensureCursorVisible() {
// Each field takes ~1 line, header takes ~4 lines
targetLine := m.cursor + 5
if targetLine < m.viewport.YOffset {
m.viewport.SetYOffset(targetLine)
}
if targetLine >= m.viewport.YOffset+m.viewport.Height {
m.viewport.SetYOffset(targetLine - m.viewport.Height + 1)
}
}
func (m configTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m configTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("config_title")))
sb.WriteString("\n")
if m.message != "" {
sb.WriteString(" " + m.message)
sb.WriteString("\n")
}
sb.WriteString(helpStyle.Render(T("config_help1")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("config_help2")))
sb.WriteString("\n\n")
if m.err != nil {
sb.WriteString(errorStyle.Render(" ⚠ Error: " + m.err.Error()))
return sb.String()
}
if len(m.fields) == 0 {
sb.WriteString(subtitleStyle.Render(T("no_config")))
return sb.String()
}
currentSection := ""
for i, f := range m.fields {
// Section headers
section := fieldSection(f.apiPath)
if section != currentSection {
currentSection = section
sb.WriteString("\n")
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(" ── " + section + " "))
sb.WriteString("\n")
}
isSelected := i == m.cursor
prefix := " "
if isSelected {
prefix = "▸ "
}
labelStr := lipgloss.NewStyle().
Foreground(colorInfo).
Bold(isSelected).
Width(32).
Render(f.label)
var valueStr string
if m.editing && isSelected {
valueStr = m.textInput.View()
} else {
switch f.kind {
case "bool":
if f.value == "true" {
valueStr = successStyle.Render("● ON")
} else {
valueStr = lipgloss.NewStyle().Foreground(colorMuted).Render("○ OFF")
}
case "readonly":
valueStr = lipgloss.NewStyle().Foreground(colorSubtext).Render(f.value)
default:
valueStr = valueStyle.Render(f.value)
}
}
line := prefix + labelStr + " " + valueStr
if isSelected && !m.editing {
line = lipgloss.NewStyle().Background(colorSurface).Render(line)
}
sb.WriteString(line + "\n")
}
return sb.String()
}
func (m configTabModel) parseConfig(cfg map[string]any) []configField {
var fields []configField
// Server settings
fields = append(fields, configField{"Port", "port", "readonly", fmt.Sprintf("%.0f", getFloat(cfg, "port")), nil})
fields = append(fields, configField{"Host", "host", "readonly", getString(cfg, "host"), nil})
fields = append(fields, configField{"Debug", "debug", "bool", fmt.Sprintf("%v", getBool(cfg, "debug")), nil})
fields = append(fields, configField{"Proxy URL", "proxy-url", "string", getString(cfg, "proxy-url"), nil})
fields = append(fields, configField{"Request Retry", "request-retry", "int", fmt.Sprintf("%.0f", getFloat(cfg, "request-retry")), nil})
fields = append(fields, configField{"Max Retry Interval (s)", "max-retry-interval", "int", fmt.Sprintf("%.0f", getFloat(cfg, "max-retry-interval")), nil})
fields = append(fields, configField{"Force Model Prefix", "force-model-prefix", "string", getString(cfg, "force-model-prefix"), nil})
// Logging
fields = append(fields, configField{"Logging to File", "logging-to-file", "bool", fmt.Sprintf("%v", getBool(cfg, "logging-to-file")), nil})
fields = append(fields, configField{"Logs Max Total Size (MB)", "logs-max-total-size-mb", "int", fmt.Sprintf("%.0f", getFloat(cfg, "logs-max-total-size-mb")), nil})
fields = append(fields, configField{"Error Logs Max Files", "error-logs-max-files", "int", fmt.Sprintf("%.0f", getFloat(cfg, "error-logs-max-files")), nil})
fields = append(fields, configField{"Usage Stats Enabled", "usage-statistics-enabled", "bool", fmt.Sprintf("%v", getBool(cfg, "usage-statistics-enabled")), nil})
fields = append(fields, configField{"Request Log", "request-log", "bool", fmt.Sprintf("%v", getBool(cfg, "request-log")), nil})
// Quota exceeded
fields = append(fields, configField{"Switch Project on Quota", "quota-exceeded/switch-project", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-project")), nil})
fields = append(fields, configField{"Switch Preview Model", "quota-exceeded/switch-preview-model", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-preview-model")), nil})
// Routing
if routing, ok := cfg["routing"].(map[string]any); ok {
fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", getString(routing, "strategy"), nil})
} else {
fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", "", nil})
}
// WebSocket auth
fields = append(fields, configField{"WebSocket Auth", "ws-auth", "bool", fmt.Sprintf("%v", getBool(cfg, "ws-auth")), nil})
// AMP settings
if amp, ok := cfg["ampcode"].(map[string]any); ok {
upstreamURL := getString(amp, "upstream-url")
upstreamAPIKey := getString(amp, "upstream-api-key")
fields = append(fields, configField{"AMP Upstream URL", "ampcode/upstream-url", "string", upstreamURL, upstreamURL})
fields = append(fields, configField{"AMP Upstream API Key", "ampcode/upstream-api-key", "string", maskIfNotEmpty(upstreamAPIKey), upstreamAPIKey})
fields = append(fields, configField{"AMP Restrict Mgmt Localhost", "ampcode/restrict-management-to-localhost", "bool", fmt.Sprintf("%v", getBool(amp, "restrict-management-to-localhost")), nil})
}
return fields
}
func fieldSection(apiPath string) string {
if strings.HasPrefix(apiPath, "ampcode/") {
return T("section_ampcode")
}
if strings.HasPrefix(apiPath, "quota-exceeded/") {
return T("section_quota")
}
if strings.HasPrefix(apiPath, "routing/") {
return T("section_routing")
}
switch apiPath {
case "port", "host", "debug", "proxy-url", "request-retry", "max-retry-interval", "force-model-prefix":
return T("section_server")
case "logging-to-file", "logs-max-total-size-mb", "error-logs-max-files", "usage-statistics-enabled", "request-log":
return T("section_logging")
case "ws-auth":
return T("section_websocket")
default:
return T("section_other")
}
}
func getBoolNested(m map[string]any, keys ...string) bool {
current := m
for i, key := range keys {
if i == len(keys)-1 {
return getBool(current, key)
}
if nested, ok := current[key].(map[string]any); ok {
current = nested
} else {
return false
}
}
return false
}
func maskIfNotEmpty(s string) string {
if s == "" {
return T("not_set")
}
return maskKey(s)
}

360
internal/tui/dashboard.go Normal file
View File

@@ -0,0 +1,360 @@
package tui
import (
"encoding/json"
"fmt"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// dashboardModel displays server info, stats cards, and config overview.
type dashboardModel struct {
client *Client
viewport viewport.Model
content string
err error
width int
height int
ready bool
// Cached data for re-rendering on locale change
lastConfig map[string]any
lastUsage map[string]any
lastAuthFiles []map[string]any
lastAPIKeys []string
}
type dashboardDataMsg struct {
config map[string]any
usage map[string]any
authFiles []map[string]any
apiKeys []string
err error
}
func newDashboardModel(client *Client) dashboardModel {
return dashboardModel{
client: client,
}
}
func (m dashboardModel) Init() tea.Cmd {
return m.fetchData
}
func (m dashboardModel) fetchData() tea.Msg {
cfg, cfgErr := m.client.GetConfig()
usage, usageErr := m.client.GetUsage()
authFiles, authErr := m.client.GetAuthFiles()
apiKeys, keysErr := m.client.GetAPIKeys()
var err error
for _, e := range []error{cfgErr, usageErr, authErr, keysErr} {
if e != nil {
err = e
break
}
}
return dashboardDataMsg{config: cfg, usage: usage, authFiles: authFiles, apiKeys: apiKeys, err: err}
}
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
// Re-render immediately with cached data using new locale
m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
m.viewport.SetContent(m.content)
// Also fetch fresh data in background
return m, m.fetchData
case dashboardDataMsg:
if msg.err != nil {
m.err = msg.err
m.content = errorStyle.Render("⚠ Error: " + msg.err.Error())
} else {
m.err = nil
// Cache data for locale switching
m.lastConfig = msg.config
m.lastUsage = msg.usage
m.lastAuthFiles = msg.authFiles
m.lastAPIKeys = msg.apiKeys
m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
}
m.viewport.SetContent(m.content)
return m, nil
case tea.KeyMsg:
if msg.String() == "r" {
return m, m.fetchData
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *dashboardModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.content)
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m dashboardModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("dashboard_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("dashboard_help")))
sb.WriteString("\n\n")
// ━━━ Connection Status ━━━
connStyle := lipgloss.NewStyle().Bold(true).Foreground(colorSuccess)
sb.WriteString(connStyle.Render(T("connected")))
sb.WriteString(fmt.Sprintf(" %s", m.client.baseURL))
sb.WriteString("\n\n")
// ━━━ Stats Cards ━━━
cardWidth := 25
if m.width > 0 {
cardWidth = (m.width - 6) / 4
if cardWidth < 18 {
cardWidth = 18
}
}
cardStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240")).
Padding(0, 1).
Width(cardWidth).
Height(2)
// Card 1: API Keys
keyCount := len(apiKeys)
card1 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("🔑 %d", keyCount)),
lipgloss.NewStyle().Foreground(colorMuted).Render(T("mgmt_keys")),
))
// Card 2: Auth Files
authCount := len(authFiles)
activeAuth := 0
for _, f := range authFiles {
if !getBool(f, "disabled") {
activeAuth++
}
}
card2 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("📄 %d", authCount)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
))
// Card 3: Total Requests
totalReqs := int64(0)
successReqs := int64(0)
failedReqs := int64(0)
totalTokens := int64(0)
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
totalReqs = int64(getFloat(usageMap, "total_requests"))
successReqs = int64(getFloat(usageMap, "success_count"))
failedReqs = int64(getFloat(usageMap, "failure_count"))
totalTokens = int64(getFloat(usageMap, "total_tokens"))
}
}
card3 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
))
// Card 4: Total Tokens
tokenStr := formatLargeNumber(totalTokens)
card4 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString("\n\n")
// ━━━ Current Config ━━━
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("current_config")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
if cfg != nil {
debug := getBool(cfg, "debug")
retry := getFloat(cfg, "request-retry")
proxyURL := getString(cfg, "proxy-url")
loggingToFile := getBool(cfg, "logging-to-file")
usageEnabled := true
if v, ok := cfg["usage-statistics-enabled"]; ok {
if b, ok2 := v.(bool); ok2 {
usageEnabled = b
}
}
configItems := []struct {
label string
value string
}{
{T("debug_mode"), boolEmoji(debug)},
{T("usage_stats"), boolEmoji(usageEnabled)},
{T("log_to_file"), boolEmoji(loggingToFile)},
{T("retry_count"), fmt.Sprintf("%.0f", retry)},
}
if proxyURL != "" {
configItems = append(configItems, struct {
label string
value string
}{T("proxy_url"), proxyURL})
}
// Render config items as a compact row
for _, item := range configItems {
sb.WriteString(fmt.Sprintf(" %s %s\n",
labelStyle.Render(item.label+":"),
valueStyle.Render(item.value)))
}
// Routing strategy
strategy := "round-robin"
if routing, ok := cfg["routing"].(map[string]any); ok {
if s := getString(routing, "strategy"); s != "" {
strategy = s
}
}
sb.WriteString(fmt.Sprintf(" %s %s\n",
labelStyle.Render(T("routing_strategy")+":"),
valueStyle.Render(strategy)))
}
sb.WriteString("\n")
// ━━━ Per-Model Usage ━━━
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for _, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
reqs := int64(getFloat(stats, "total_requests"))
toks := int64(getFloat(stats, "total_tokens"))
row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
sb.WriteString(tableCellStyle.Render(row))
sb.WriteString("\n")
}
}
}
}
}
}
}
}
return sb.String()
}
func formatKV(key, value string) string {
return fmt.Sprintf(" %s %s\n", labelStyle.Render(key+":"), valueStyle.Render(value))
}
func getString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func getFloat(m map[string]any, key string) float64 {
if v, ok := m[key]; ok {
switch n := v.(type) {
case float64:
return n
case json.Number:
f, _ := n.Float64()
return f
}
}
return 0
}
func getBool(m map[string]any, key string) bool {
if v, ok := m[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
return false
}
func boolEmoji(b bool) string {
if b {
return T("bool_yes")
}
return T("bool_no")
}
func formatLargeNumber(n int64) string {
if n >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
}
if n >= 1_000 {
return fmt.Sprintf("%.1fK", float64(n)/1_000)
}
return fmt.Sprintf("%d", n)
}
func truncate(s string, maxLen int) string {
if len(s) > maxLen {
return s[:maxLen-3] + "..."
}
return s
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}

364
internal/tui/i18n.go Normal file
View File

@@ -0,0 +1,364 @@
package tui
// i18n provides a simple internationalization system for the TUI.
// Supported locales: "zh" (Chinese, default), "en" (English).
var currentLocale = "en"
// SetLocale changes the active locale.
func SetLocale(locale string) {
if _, ok := locales[locale]; ok {
currentLocale = locale
}
}
// CurrentLocale returns the active locale code.
func CurrentLocale() string {
return currentLocale
}
// ToggleLocale switches between zh and en.
func ToggleLocale() {
if currentLocale == "zh" {
currentLocale = "en"
} else {
currentLocale = "zh"
}
}
// T returns the translated string for the given key.
func T(key string) string {
if m, ok := locales[currentLocale]; ok {
if v, ok := m[key]; ok {
return v
}
}
// Fallback to English
if m, ok := locales["en"]; ok {
if v, ok := m[key]; ok {
return v
}
}
return key
}
var locales = map[string]map[string]string{
"zh": zhStrings,
"en": enStrings,
}
// ──────────────────────────────────────────
// Tab names
// ──────────────────────────────────────────
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
// TabNames returns tab names in the current locale.
func TabNames() []string {
if currentLocale == "zh" {
return zhTabNames
}
return enTabNames
}
var zhStrings = map[string]string{
// ── Common ──
"loading": "加载中...",
"refresh": "刷新",
"save": "保存",
"cancel": "取消",
"confirm": "确认",
"yes": "是",
"no": "否",
"error": "错误",
"success": "成功",
"navigate": "导航",
"scroll": "滚动",
"enter_save": "Enter: 保存",
"esc_cancel": "Esc: 取消",
"enter_submit": "Enter: 提交",
"press_r": "[r] 刷新",
"press_scroll": "[↑↓] 滚动",
"not_set": "(未设置)",
"error_prefix": "⚠ 错误: ",
// ── Status bar ──
"status_left": " CLIProxyAPI 管理终端",
"status_right": "Tab/Shift+Tab: 切换 • L: 语言 • q/Ctrl+C: 退出 ",
"initializing_tui": "正在初始化...",
"auth_gate_title": "🔐 连接管理 API",
"auth_gate_help": " 请输入管理密码并按 Enter 连接",
"auth_gate_password": "密码",
"auth_gate_enter": " Enter: 连接 • q/Ctrl+C: 退出 • L: 语言",
"auth_gate_connecting": "正在连接...",
"auth_gate_connect_fail": "连接失败:%s",
"auth_gate_password_required": "请输入密码",
// ── Dashboard ──
"dashboard_title": "📊 仪表盘",
"dashboard_help": " [r] 刷新 • [↑↓] 滚动",
"connected": "● 已连接",
"mgmt_keys": "管理密钥",
"auth_files_label": "认证文件",
"active_suffix": "活跃",
"total_requests": "请求",
"success_label": "成功",
"failure_label": "失败",
"total_tokens": "总 Tokens",
"current_config": "当前配置",
"debug_mode": "启用调试模式",
"usage_stats": "启用使用统计",
"log_to_file": "启用日志记录到文件",
"retry_count": "重试次数",
"proxy_url": "代理 URL",
"routing_strategy": "路由策略",
"model_stats": "模型统计",
"model": "模型",
"requests": "请求数",
"tokens": "Tokens",
"bool_yes": "是 ✓",
"bool_no": "否",
// ── Config ──
"config_title": "⚙ 配置",
"config_help1": " [↑↓/jk] 导航 • [Enter/Space] 编辑 • [r] 刷新",
"config_help2": " 布尔: Enter 切换 • 文本/数字: Enter 输入, Enter 确认, Esc 取消",
"updated_ok": "✓ 更新成功",
"no_config": " 未加载配置",
"invalid_int": "无效整数",
"section_server": "服务器",
"section_logging": "日志与统计",
"section_quota": "配额超限处理",
"section_routing": "路由",
"section_websocket": "WebSocket",
"section_ampcode": "AMP Code",
"section_other": "其他",
// ── Auth Files ──
"auth_title": "🔑 认证文件",
"auth_help1": " [↑↓/jk] 导航 • [Enter] 展开 • [e] 启用/停用 • [d] 删除 • [r] 刷新",
"auth_help2": " [1] 编辑 prefix • [2] 编辑 proxy_url • [3] 编辑 priority",
"no_auth_files": " 无认证文件",
"confirm_delete": "⚠ 删除 %s? [y/n]",
"deleted": "已删除 %s",
"enabled": "已启用",
"disabled": "已停用",
"updated_field": "已更新 %s 的 %s",
"status_active": "活跃",
"status_disabled": "已停用",
// ── API Keys ──
"keys_title": "🔐 API 密钥",
"keys_help": " [↑↓/jk] 导航 • [a] 添加 • [e] 编辑 • [d] 删除 • [c] 复制 • [r] 刷新",
"no_keys": " 无 API Key按 [a] 添加",
"access_keys": "Access API Keys",
"confirm_delete_key": "⚠ 确认删除 %s? [y/n]",
"key_added": "已添加 API Key",
"key_updated": "已更新 API Key",
"key_deleted": "已删除 API Key",
"copied": "✓ 已复制到剪贴板",
"copy_failed": "✗ 复制失败",
"new_key_prompt": " New Key: ",
"edit_key_prompt": " Edit Key: ",
"enter_add": " Enter: 添加 • Esc: 取消",
"enter_save_esc": " Enter: 保存 • Esc: 取消",
// ── OAuth ──
"oauth_title": "🔐 OAuth 登录",
"oauth_select": " 选择提供商并按 [Enter] 开始 OAuth 登录:",
"oauth_help": " [↑↓/jk] 导航 • [Enter] 登录 • [Esc] 清除状态",
"oauth_initiating": "⏳ 正在初始化 %s 登录...",
"oauth_success": "认证成功! 请刷新 Auth Files 标签查看新凭证。",
"oauth_completed": "认证流程已完成。",
"oauth_failed": "认证失败",
"oauth_timeout": "OAuth 流程超时 (5 分钟)",
"oauth_press_esc": " 按 [Esc] 取消",
"oauth_auth_url": " 授权链接:",
"oauth_remote_hint": " 远程浏览器模式:在浏览器中打开上述链接完成授权后,将回调 URL 粘贴到下方。",
"oauth_callback_url": " 回调 URL:",
"oauth_press_c": " 按 [c] 输入回调 URL • [Esc] 返回",
"oauth_submitting": "⏳ 提交回调中...",
"oauth_submit_ok": "✓ 回调已提交,等待处理...",
"oauth_submit_fail": "✗ 提交回调失败",
"oauth_waiting": " 等待认证中...",
// ── Usage ──
"usage_title": "📈 使用统计",
"usage_help": " [r] 刷新 • [↑↓] 滚动",
"usage_no_data": " 使用数据不可用",
"usage_total_reqs": "总请求数",
"usage_total_tokens": "总 Token 数",
"usage_success": "成功",
"usage_failure": "失败",
"usage_total_token_l": "总Token",
"usage_rpm": "RPM",
"usage_tpm": "TPM",
"usage_req_by_hour": "请求趋势 (按小时)",
"usage_tok_by_hour": "Token 使用趋势 (按小时)",
"usage_req_by_day": "请求趋势 (按天)",
"usage_api_detail": "API 详细统计",
"usage_input": "输入",
"usage_output": "输出",
"usage_cached": "缓存",
"usage_reasoning": "思考",
// ── Logs ──
"logs_title": "📋 日志",
"logs_auto_scroll": "● 自动滚动",
"logs_paused": "○ 已暂停",
"logs_filter": "过滤",
"logs_lines": "行数",
"logs_help": " [a] 自动滚动 • [c] 清除 • [1] 全部 [2] info+ [3] warn+ [4] error • [↑↓] 滚动",
"logs_waiting": " 等待日志输出...",
}
var enStrings = map[string]string{
// ── Common ──
"loading": "Loading...",
"refresh": "Refresh",
"save": "Save",
"cancel": "Cancel",
"confirm": "Confirm",
"yes": "Yes",
"no": "No",
"error": "Error",
"success": "Success",
"navigate": "Navigate",
"scroll": "Scroll",
"enter_save": "Enter: Save",
"esc_cancel": "Esc: Cancel",
"enter_submit": "Enter: Submit",
"press_r": "[r] Refresh",
"press_scroll": "[↑↓] Scroll",
"not_set": "(not set)",
"error_prefix": "⚠ Error: ",
// ── Status bar ──
"status_left": " CLIProxyAPI Management TUI",
"status_right": "Tab/Shift+Tab: switch • L: lang • q/Ctrl+C: quit ",
"initializing_tui": "Initializing...",
"auth_gate_title": "🔐 Connect Management API",
"auth_gate_help": " Enter management password and press Enter to connect",
"auth_gate_password": "Password",
"auth_gate_enter": " Enter: connect • q/Ctrl+C: quit • L: lang",
"auth_gate_connecting": "Connecting...",
"auth_gate_connect_fail": "Connection failed: %s",
"auth_gate_password_required": "password is required",
// ── Dashboard ──
"dashboard_title": "📊 Dashboard",
"dashboard_help": " [r] Refresh • [↑↓] Scroll",
"connected": "● Connected",
"mgmt_keys": "Mgmt Keys",
"auth_files_label": "Auth Files",
"active_suffix": "active",
"total_requests": "Requests",
"success_label": "Success",
"failure_label": "Failed",
"total_tokens": "Total Tokens",
"current_config": "Current Config",
"debug_mode": "Debug Mode",
"usage_stats": "Usage Statistics",
"log_to_file": "Log to File",
"retry_count": "Retry Count",
"proxy_url": "Proxy URL",
"routing_strategy": "Routing Strategy",
"model_stats": "Model Stats",
"model": "Model",
"requests": "Requests",
"tokens": "Tokens",
"bool_yes": "Yes ✓",
"bool_no": "No",
// ── Config ──
"config_title": "⚙ Configuration",
"config_help1": " [↑↓/jk] Navigate • [Enter/Space] Edit • [r] Refresh",
"config_help2": " Bool: Enter to toggle • String/Int: Enter to type, Enter to confirm, Esc to cancel",
"updated_ok": "✓ Updated successfully",
"no_config": " No configuration loaded",
"invalid_int": "invalid integer",
"section_server": "Server",
"section_logging": "Logging & Stats",
"section_quota": "Quota Exceeded Handling",
"section_routing": "Routing",
"section_websocket": "WebSocket",
"section_ampcode": "AMP Code",
"section_other": "Other",
// ── Auth Files ──
"auth_title": "🔑 Auth Files",
"auth_help1": " [↑↓/jk] Navigate • [Enter] Expand • [e] Enable/Disable • [d] Delete • [r] Refresh",
"auth_help2": " [1] Edit prefix • [2] Edit proxy_url • [3] Edit priority",
"no_auth_files": " No auth files found",
"confirm_delete": "⚠ Delete %s? [y/n]",
"deleted": "Deleted %s",
"enabled": "Enabled",
"disabled": "Disabled",
"updated_field": "Updated %s on %s",
"status_active": "active",
"status_disabled": "disabled",
// ── API Keys ──
"keys_title": "🔐 API Keys",
"keys_help": " [↑↓/jk] Navigate • [a] Add • [e] Edit • [d] Delete • [c] Copy • [r] Refresh",
"no_keys": " No API Keys. Press [a] to add",
"access_keys": "Access API Keys",
"confirm_delete_key": "⚠ Delete %s? [y/n]",
"key_added": "API Key added",
"key_updated": "API Key updated",
"key_deleted": "API Key deleted",
"copied": "✓ Copied to clipboard",
"copy_failed": "✗ Copy failed",
"new_key_prompt": " New Key: ",
"edit_key_prompt": " Edit Key: ",
"enter_add": " Enter: Add • Esc: Cancel",
"enter_save_esc": " Enter: Save • Esc: Cancel",
// ── OAuth ──
"oauth_title": "🔐 OAuth Login",
"oauth_select": " Select a provider and press [Enter] to start OAuth login:",
"oauth_help": " [↑↓/jk] Navigate • [Enter] Login • [Esc] Clear status",
"oauth_initiating": "⏳ Initiating %s login...",
"oauth_success": "Authentication successful! Refresh Auth Files tab to see the new credential.",
"oauth_completed": "Authentication flow completed.",
"oauth_failed": "Authentication failed",
"oauth_timeout": "OAuth flow timed out (5 minutes)",
"oauth_press_esc": " Press [Esc] to cancel",
"oauth_auth_url": " Authorization URL:",
"oauth_remote_hint": " Remote browser mode: Open the URL above in browser, paste the callback URL below after authorization.",
"oauth_callback_url": " Callback URL:",
"oauth_press_c": " Press [c] to enter callback URL • [Esc] to go back",
"oauth_submitting": "⏳ Submitting callback...",
"oauth_submit_ok": "✓ Callback submitted, waiting...",
"oauth_submit_fail": "✗ Callback submission failed",
"oauth_waiting": " Waiting for authentication...",
// ── Usage ──
"usage_title": "📈 Usage Statistics",
"usage_help": " [r] Refresh • [↑↓] Scroll",
"usage_no_data": " Usage data not available",
"usage_total_reqs": "Total Requests",
"usage_total_tokens": "Total Tokens",
"usage_success": "Success",
"usage_failure": "Failed",
"usage_total_token_l": "Total Tokens",
"usage_rpm": "RPM",
"usage_tpm": "TPM",
"usage_req_by_hour": "Requests by Hour",
"usage_tok_by_hour": "Token Usage by Hour",
"usage_req_by_day": "Requests by Day",
"usage_api_detail": "API Detail Statistics",
"usage_input": "Input",
"usage_output": "Output",
"usage_cached": "Cached",
"usage_reasoning": "Reasoning",
// ── Logs ──
"logs_title": "📋 Logs",
"logs_auto_scroll": "● AUTO-SCROLL",
"logs_paused": "○ PAUSED",
"logs_filter": "Filter",
"logs_lines": "Lines",
"logs_help": " [a] Auto-scroll • [c] Clear • [1] All [2] info+ [3] warn+ [4] error • [↑↓] Scroll",
"logs_waiting": " Waiting for log output...",
}

405
internal/tui/keys_tab.go Normal file
View File

@@ -0,0 +1,405 @@
package tui
import (
"fmt"
"strings"
"github.com/atotto/clipboard"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// keysTabModel displays and manages API keys.
type keysTabModel struct {
client *Client
viewport viewport.Model
keys []string
gemini []map[string]any
claude []map[string]any
codex []map[string]any
vertex []map[string]any
openai []map[string]any
err error
width int
height int
ready bool
cursor int
confirm int // -1 = no deletion pending
status string
// Editing / Adding
editing bool
adding bool
editIdx int
editInput textinput.Model
}
type keysDataMsg struct {
apiKeys []string
gemini []map[string]any
claude []map[string]any
codex []map[string]any
vertex []map[string]any
openai []map[string]any
err error
}
type keyActionMsg struct {
action string
err error
}
func newKeysTabModel(client *Client) keysTabModel {
ti := textinput.New()
ti.CharLimit = 512
ti.Prompt = " Key: "
return keysTabModel{
client: client,
confirm: -1,
editInput: ti,
}
}
func (m keysTabModel) Init() tea.Cmd {
return m.fetchKeys
}
func (m keysTabModel) fetchKeys() tea.Msg {
result := keysDataMsg{}
apiKeys, err := m.client.GetAPIKeys()
if err != nil {
result.err = err
return result
}
result.apiKeys = apiKeys
result.gemini, _ = m.client.GetGeminiKeys()
result.claude, _ = m.client.GetClaudeKeys()
result.codex, _ = m.client.GetCodexKeys()
result.vertex, _ = m.client.GetVertexKeys()
result.openai, _ = m.client.GetOpenAICompat()
return result
}
func (m keysTabModel) Update(msg tea.Msg) (keysTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case keysDataMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.keys = msg.apiKeys
m.gemini = msg.gemini
m.claude = msg.claude
m.codex = msg.codex
m.vertex = msg.vertex
m.openai = msg.openai
if m.cursor >= len(m.keys) {
m.cursor = max(0, len(m.keys)-1)
}
}
m.viewport.SetContent(m.renderContent())
return m, nil
case keyActionMsg:
if msg.err != nil {
m.status = errorStyle.Render("✗ " + msg.err.Error())
} else {
m.status = successStyle.Render("✓ " + msg.action)
}
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, m.fetchKeys
case tea.KeyMsg:
// ---- Editing / Adding mode ----
if m.editing || m.adding {
switch msg.String() {
case "enter":
value := strings.TrimSpace(m.editInput.Value())
if value == "" {
m.editing = false
m.adding = false
m.editInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
}
isAdding := m.adding
editIdx := m.editIdx
m.editing = false
m.adding = false
m.editInput.Blur()
if isAdding {
return m, func() tea.Msg {
err := m.client.AddAPIKey(value)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_added")}
}
}
return m, func() tea.Msg {
err := m.client.EditAPIKey(editIdx, value)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_updated")}
}
case "esc":
m.editing = false
m.adding = false
m.editInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.editInput, cmd = m.editInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
// ---- Delete confirmation ----
if m.confirm >= 0 {
switch msg.String() {
case "y", "Y":
idx := m.confirm
m.confirm = -1
return m, func() tea.Msg {
err := m.client.DeleteAPIKey(idx)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_deleted")}
}
case "n", "N", "esc":
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, nil
}
return m, nil
}
// ---- Normal mode ----
switch msg.String() {
case "j", "down":
if len(m.keys) > 0 {
m.cursor = (m.cursor + 1) % len(m.keys)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "k", "up":
if len(m.keys) > 0 {
m.cursor = (m.cursor - 1 + len(m.keys)) % len(m.keys)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "a":
// Add new key
m.adding = true
m.editing = false
m.editInput.SetValue("")
m.editInput.Prompt = T("new_key_prompt")
m.editInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
case "e":
// Edit selected key
if m.cursor < len(m.keys) {
m.editing = true
m.adding = false
m.editIdx = m.cursor
m.editInput.SetValue(m.keys[m.cursor])
m.editInput.Prompt = T("edit_key_prompt")
m.editInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
}
return m, nil
case "d":
// Delete selected key
if m.cursor < len(m.keys) {
m.confirm = m.cursor
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "c":
// Copy selected key to clipboard
if m.cursor < len(m.keys) {
key := m.keys[m.cursor]
if err := clipboard.WriteAll(key); err != nil {
m.status = errorStyle.Render(T("copy_failed") + ": " + err.Error())
} else {
m.status = successStyle.Render(T("copied"))
}
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "r":
m.status = ""
return m, m.fetchKeys
default:
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *keysTabModel) SetSize(w, h int) {
m.width = w
m.height = h
m.editInput.Width = w - 16
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m keysTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m keysTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("keys_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("keys_help")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", m.width))
sb.WriteString("\n")
if m.err != nil {
sb.WriteString(errorStyle.Render(T("error_prefix") + m.err.Error()))
sb.WriteString("\n")
return sb.String()
}
// ━━━ Access API Keys (interactive) ━━━
sb.WriteString(tableHeaderStyle.Render(fmt.Sprintf(" %s (%d)", T("access_keys"), len(m.keys))))
sb.WriteString("\n")
if len(m.keys) == 0 {
sb.WriteString(subtitleStyle.Render(T("no_keys")))
sb.WriteString("\n")
}
for i, key := range m.keys {
cursor := " "
rowStyle := lipgloss.NewStyle()
if i == m.cursor {
cursor = "▸ "
rowStyle = lipgloss.NewStyle().Bold(true)
}
row := fmt.Sprintf("%s%d. %s", cursor, i+1, maskKey(key))
sb.WriteString(rowStyle.Render(row))
sb.WriteString("\n")
// Delete confirmation
if m.confirm == i {
sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete_key"), maskKey(key))))
sb.WriteString("\n")
}
// Edit input
if m.editing && m.editIdx == i {
sb.WriteString(m.editInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("enter_save_esc")))
sb.WriteString("\n")
}
}
// Add input
if m.adding {
sb.WriteString("\n")
sb.WriteString(m.editInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("enter_add")))
sb.WriteString("\n")
}
sb.WriteString("\n")
// ━━━ Provider Keys (read-only display) ━━━
renderProviderKeys(&sb, "Gemini API Keys", m.gemini)
renderProviderKeys(&sb, "Claude API Keys", m.claude)
renderProviderKeys(&sb, "Codex API Keys", m.codex)
renderProviderKeys(&sb, "Vertex API Keys", m.vertex)
if len(m.openai) > 0 {
renderSection(&sb, "OpenAI Compatibility", len(m.openai))
for i, entry := range m.openai {
name := getString(entry, "name")
baseURL := getString(entry, "base-url")
prefix := getString(entry, "prefix")
info := name
if prefix != "" {
info += " (prefix: " + prefix + ")"
}
if baseURL != "" {
info += " → " + baseURL
}
sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, info))
}
sb.WriteString("\n")
}
if m.status != "" {
sb.WriteString(m.status)
sb.WriteString("\n")
}
return sb.String()
}
func renderSection(sb *strings.Builder, title string, count int) {
header := fmt.Sprintf("%s (%d)", title, count)
sb.WriteString(tableHeaderStyle.Render(" " + header))
sb.WriteString("\n")
}
func renderProviderKeys(sb *strings.Builder, title string, keys []map[string]any) {
if len(keys) == 0 {
return
}
renderSection(sb, title, len(keys))
for i, key := range keys {
apiKey := getString(key, "api-key")
prefix := getString(key, "prefix")
baseURL := getString(key, "base-url")
info := maskKey(apiKey)
if prefix != "" {
info += " (prefix: " + prefix + ")"
}
if baseURL != "" {
info += " → " + baseURL
}
sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, info))
}
sb.WriteString("\n")
}
func maskKey(key string) string {
if len(key) <= 8 {
return strings.Repeat("*", len(key))
}
return key[:4] + strings.Repeat("*", len(key)-8) + key[len(key)-4:]
}

78
internal/tui/loghook.go Normal file
View File

@@ -0,0 +1,78 @@
package tui
import (
"fmt"
"strings"
"sync"
log "github.com/sirupsen/logrus"
)
// LogHook is a logrus hook that captures log entries and sends them to a channel.
type LogHook struct {
ch chan string
formatter log.Formatter
mu sync.Mutex
levels []log.Level
}
// NewLogHook creates a new LogHook with a buffered channel of the given size.
func NewLogHook(bufSize int) *LogHook {
return &LogHook{
ch: make(chan string, bufSize),
formatter: &log.TextFormatter{DisableColors: true, FullTimestamp: true},
levels: log.AllLevels,
}
}
// SetFormatter sets a custom formatter for the hook.
func (h *LogHook) SetFormatter(f log.Formatter) {
h.mu.Lock()
defer h.mu.Unlock()
h.formatter = f
}
// Levels returns the log levels this hook should fire on.
func (h *LogHook) Levels() []log.Level {
return h.levels
}
// Fire is called by logrus when a log entry is fired.
func (h *LogHook) Fire(entry *log.Entry) error {
h.mu.Lock()
f := h.formatter
h.mu.Unlock()
var line string
if f != nil {
b, err := f.Format(entry)
if err == nil {
line = strings.TrimRight(string(b), "\n\r")
} else {
line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
}
} else {
line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
}
// Non-blocking send
select {
case h.ch <- line:
default:
// Drop oldest if full
select {
case <-h.ch:
default:
}
select {
case h.ch <- line:
default:
}
}
return nil
}
// Chan returns the channel to read log lines from.
func (h *LogHook) Chan() <-chan string {
return h.ch
}

261
internal/tui/logs_tab.go Normal file
View File

@@ -0,0 +1,261 @@
package tui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
)
// logsTabModel displays real-time log lines from hook/API source.
type logsTabModel struct {
client *Client
hook *LogHook
viewport viewport.Model
lines []string
maxLines int
autoScroll bool
width int
height int
ready bool
filter string // "", "debug", "info", "warn", "error"
after int64
lastErr error
}
type logsPollMsg struct {
lines []string
latest int64
err error
}
type logsTickMsg struct{}
type logLineMsg string
func newLogsTabModel(client *Client, hook *LogHook) logsTabModel {
return logsTabModel{
client: client,
hook: hook,
maxLines: 5000,
autoScroll: true,
}
}
func (m logsTabModel) Init() tea.Cmd {
if m.hook != nil {
return m.waitForLog
}
return m.fetchLogs
}
func (m logsTabModel) fetchLogs() tea.Msg {
lines, latest, err := m.client.GetLogs(m.after, 200)
return logsPollMsg{
lines: lines,
latest: latest,
err: err,
}
}
func (m logsTabModel) waitForNextPoll() tea.Cmd {
return tea.Tick(2*time.Second, func(_ time.Time) tea.Msg {
return logsTickMsg{}
})
}
func (m logsTabModel) waitForLog() tea.Msg {
if m.hook == nil {
return nil
}
line, ok := <-m.hook.Chan()
if !ok {
return nil
}
return logLineMsg(line)
}
func (m logsTabModel) Update(msg tea.Msg) (logsTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderLogs())
return m, nil
case logsTickMsg:
if m.hook != nil {
return m, nil
}
return m, m.fetchLogs
case logsPollMsg:
if m.hook != nil {
return m, nil
}
if msg.err != nil {
m.lastErr = msg.err
} else {
m.lastErr = nil
m.after = msg.latest
if len(msg.lines) > 0 {
m.lines = append(m.lines, msg.lines...)
if len(m.lines) > m.maxLines {
m.lines = m.lines[len(m.lines)-m.maxLines:]
}
}
}
m.viewport.SetContent(m.renderLogs())
if m.autoScroll {
m.viewport.GotoBottom()
}
return m, m.waitForNextPoll()
case logLineMsg:
m.lines = append(m.lines, string(msg))
if len(m.lines) > m.maxLines {
m.lines = m.lines[len(m.lines)-m.maxLines:]
}
m.viewport.SetContent(m.renderLogs())
if m.autoScroll {
m.viewport.GotoBottom()
}
return m, m.waitForLog
case tea.KeyMsg:
switch msg.String() {
case "a":
m.autoScroll = !m.autoScroll
if m.autoScroll {
m.viewport.GotoBottom()
}
return m, nil
case "c":
m.lines = nil
m.lastErr = nil
m.viewport.SetContent(m.renderLogs())
return m, nil
case "1":
m.filter = ""
m.viewport.SetContent(m.renderLogs())
return m, nil
case "2":
m.filter = "info"
m.viewport.SetContent(m.renderLogs())
return m, nil
case "3":
m.filter = "warn"
m.viewport.SetContent(m.renderLogs())
return m, nil
case "4":
m.filter = "error"
m.viewport.SetContent(m.renderLogs())
return m, nil
default:
wasAtBottom := m.viewport.AtBottom()
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
// If user scrolls up, disable auto-scroll
if !m.viewport.AtBottom() && wasAtBottom {
m.autoScroll = false
}
// If user scrolls to bottom, re-enable auto-scroll
if m.viewport.AtBottom() {
m.autoScroll = true
}
return m, cmd
}
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *logsTabModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderLogs())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m logsTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m logsTabModel) renderLogs() string {
var sb strings.Builder
scrollStatus := successStyle.Render(T("logs_auto_scroll"))
if !m.autoScroll {
scrollStatus = warningStyle.Render(T("logs_paused"))
}
filterLabel := "ALL"
if m.filter != "" {
filterLabel = strings.ToUpper(m.filter) + "+"
}
header := fmt.Sprintf(" %s %s %s: %s %s: %d",
T("logs_title"), scrollStatus, T("logs_filter"), filterLabel, T("logs_lines"), len(m.lines))
sb.WriteString(titleStyle.Render(header))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("logs_help")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", m.width))
sb.WriteString("\n")
if m.lastErr != nil {
sb.WriteString(errorStyle.Render("⚠ Error: " + m.lastErr.Error()))
sb.WriteString("\n")
}
if len(m.lines) == 0 {
sb.WriteString(subtitleStyle.Render(T("logs_waiting")))
return sb.String()
}
for _, line := range m.lines {
if m.filter != "" && !m.matchLevel(line) {
continue
}
styled := m.styleLine(line)
sb.WriteString(styled)
sb.WriteString("\n")
}
return sb.String()
}
func (m logsTabModel) matchLevel(line string) bool {
switch m.filter {
case "error":
return strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]") || strings.Contains(line, "[panic]")
case "warn":
return strings.Contains(line, "[warn") || strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]")
case "info":
return !strings.Contains(line, "[debug]")
default:
return true
}
}
func (m logsTabModel) styleLine(line string) string {
if strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]") {
return logErrorStyle.Render(line)
}
if strings.Contains(line, "[warn") {
return logWarnStyle.Render(line)
}
if strings.Contains(line, "[info") {
return logInfoStyle.Render(line)
}
if strings.Contains(line, "[debug]") {
return logDebugStyle.Render(line)
}
return line
}

473
internal/tui/oauth_tab.go Normal file
View File

@@ -0,0 +1,473 @@
package tui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// oauthProvider represents an OAuth provider option.
type oauthProvider struct {
name string
apiPath string // management API path
emoji string
}
var oauthProviders = []oauthProvider{
{"Gemini CLI", "gemini-cli-auth-url", "🟦"},
{"Claude (Anthropic)", "anthropic-auth-url", "🟧"},
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
{"Antigravity", "antigravity-auth-url", "🟪"},
{"Qwen", "qwen-auth-url", "🟨"},
{"Kimi", "kimi-auth-url", "🟫"},
{"IFlow", "iflow-auth-url", "⬜"},
}
// oauthTabModel handles OAuth login flows.
type oauthTabModel struct {
client *Client
viewport viewport.Model
cursor int
state oauthState
message string
err error
width int
height int
ready bool
// Remote browser mode
authURL string // auth URL to display
authState string // OAuth state parameter
providerName string // current provider name
callbackInput textinput.Model
inputActive bool // true when user is typing callback URL
}
type oauthState int
const (
oauthIdle oauthState = iota
oauthPending
oauthRemote // remote browser mode: waiting for manual callback
oauthSuccess
oauthError
)
// Messages
type oauthStartMsg struct {
url string
state string
providerName string
err error
}
type oauthPollMsg struct {
done bool
message string
err error
}
type oauthCallbackSubmitMsg struct {
err error
}
func newOAuthTabModel(client *Client) oauthTabModel {
ti := textinput.New()
ti.Placeholder = "http://localhost:.../auth/callback?code=...&state=..."
ti.CharLimit = 2048
ti.Prompt = " 回调 URL: "
return oauthTabModel{
client: client,
callbackInput: ti,
}
}
func (m oauthTabModel) Init() tea.Cmd {
return nil
}
func (m oauthTabModel) Update(msg tea.Msg) (oauthTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case oauthStartMsg:
if msg.err != nil {
m.state = oauthError
m.err = msg.err
m.message = errorStyle.Render("✗ " + msg.err.Error())
m.viewport.SetContent(m.renderContent())
return m, nil
}
m.authURL = msg.url
m.authState = msg.state
m.providerName = msg.providerName
m.state = oauthRemote
m.callbackInput.SetValue("")
m.callbackInput.Focus()
m.inputActive = true
m.message = ""
m.viewport.SetContent(m.renderContent())
// Also start polling in the background
return m, tea.Batch(textinput.Blink, m.pollOAuthStatus(msg.state))
case oauthPollMsg:
if msg.err != nil {
m.state = oauthError
m.err = msg.err
m.message = errorStyle.Render("✗ " + msg.err.Error())
m.inputActive = false
m.callbackInput.Blur()
} else if msg.done {
m.state = oauthSuccess
m.message = successStyle.Render("✓ " + msg.message)
m.inputActive = false
m.callbackInput.Blur()
} else {
m.message = warningStyle.Render("⏳ " + msg.message)
}
m.viewport.SetContent(m.renderContent())
return m, nil
case oauthCallbackSubmitMsg:
if msg.err != nil {
m.message = errorStyle.Render(T("oauth_submit_fail") + ": " + msg.err.Error())
} else {
m.message = successStyle.Render(T("oauth_submit_ok"))
}
m.viewport.SetContent(m.renderContent())
return m, nil
case tea.KeyMsg:
// ---- Input active: typing callback URL ----
if m.inputActive {
switch msg.String() {
case "enter":
callbackURL := m.callbackInput.Value()
if callbackURL == "" {
return m, nil
}
m.inputActive = false
m.callbackInput.Blur()
m.message = warningStyle.Render(T("oauth_submitting"))
m.viewport.SetContent(m.renderContent())
return m, m.submitCallback(callbackURL)
case "esc":
m.inputActive = false
m.callbackInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.callbackInput, cmd = m.callbackInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
// ---- Remote mode but not typing ----
if m.state == oauthRemote {
switch msg.String() {
case "c", "C":
// Re-activate input
m.inputActive = true
m.callbackInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
case "esc":
m.state = oauthIdle
m.message = ""
m.authURL = ""
m.authState = ""
m.viewport.SetContent(m.renderContent())
return m, nil
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
// ---- Pending (auto polling) ----
if m.state == oauthPending {
if msg.String() == "esc" {
m.state = oauthIdle
m.message = ""
m.viewport.SetContent(m.renderContent())
}
return m, nil
}
// ---- Idle ----
switch msg.String() {
case "up", "k":
if m.cursor > 0 {
m.cursor--
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "down", "j":
if m.cursor < len(oauthProviders)-1 {
m.cursor++
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "enter":
if m.cursor >= 0 && m.cursor < len(oauthProviders) {
provider := oauthProviders[m.cursor]
m.state = oauthPending
m.message = warningStyle.Render(fmt.Sprintf(T("oauth_initiating"), provider.name))
m.viewport.SetContent(m.renderContent())
return m, m.startOAuth(provider)
}
return m, nil
case "esc":
m.state = oauthIdle
m.message = ""
m.err = nil
m.viewport.SetContent(m.renderContent())
return m, nil
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m oauthTabModel) startOAuth(provider oauthProvider) tea.Cmd {
return func() tea.Msg {
// Call the auth URL endpoint with is_webui=true
data, err := m.client.getJSON("/v0/management/" + provider.apiPath + "?is_webui=true")
if err != nil {
return oauthStartMsg{err: fmt.Errorf("failed to start %s login: %w", provider.name, err)}
}
authURL := getString(data, "url")
state := getString(data, "state")
if authURL == "" {
return oauthStartMsg{err: fmt.Errorf("no auth URL returned for %s", provider.name)}
}
// Try to open browser (best effort)
_ = openBrowser(authURL)
return oauthStartMsg{url: authURL, state: state, providerName: provider.name}
}
}
func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
return func() tea.Msg {
// Determine provider from current context
providerKey := ""
for _, p := range oauthProviders {
if p.name == m.providerName {
// Map provider name to the canonical key the API expects
switch p.apiPath {
case "gemini-cli-auth-url":
providerKey = "gemini"
case "anthropic-auth-url":
providerKey = "anthropic"
case "codex-auth-url":
providerKey = "codex"
case "antigravity-auth-url":
providerKey = "antigravity"
case "qwen-auth-url":
providerKey = "qwen"
case "kimi-auth-url":
providerKey = "kimi"
case "iflow-auth-url":
providerKey = "iflow"
}
break
}
}
body := map[string]string{
"provider": providerKey,
"redirect_url": callbackURL,
"state": m.authState,
}
err := m.client.postJSON("/v0/management/oauth-callback", body)
if err != nil {
return oauthCallbackSubmitMsg{err: err}
}
return oauthCallbackSubmitMsg{}
}
}
func (m oauthTabModel) pollOAuthStatus(state string) tea.Cmd {
return func() tea.Msg {
// Poll session status for up to 5 minutes
deadline := time.Now().Add(5 * time.Minute)
for {
if time.Now().After(deadline) {
return oauthPollMsg{done: false, err: fmt.Errorf("%s", T("oauth_timeout"))}
}
time.Sleep(2 * time.Second)
status, errMsg, err := m.client.GetAuthStatus(state)
if err != nil {
continue // Ignore transient errors
}
switch status {
case "ok":
return oauthPollMsg{
done: true,
message: T("oauth_success"),
}
case "error":
return oauthPollMsg{
done: false,
err: fmt.Errorf("%s: %s", T("oauth_failed"), errMsg),
}
case "wait":
continue
default:
return oauthPollMsg{
done: true,
message: T("oauth_completed"),
}
}
}
}
}
func (m *oauthTabModel) SetSize(w, h int) {
m.width = w
m.height = h
m.callbackInput.Width = w - 16
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m oauthTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m oauthTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("oauth_title")))
sb.WriteString("\n\n")
if m.message != "" {
sb.WriteString(" " + m.message)
sb.WriteString("\n\n")
}
// ---- Remote browser mode ----
if m.state == oauthRemote {
sb.WriteString(m.renderRemoteMode())
return sb.String()
}
if m.state == oauthPending {
sb.WriteString(helpStyle.Render(T("oauth_press_esc")))
return sb.String()
}
sb.WriteString(helpStyle.Render(T("oauth_select")))
sb.WriteString("\n\n")
for i, p := range oauthProviders {
isSelected := i == m.cursor
prefix := " "
if isSelected {
prefix = "▸ "
}
label := fmt.Sprintf("%s %s", p.emoji, p.name)
if isSelected {
label = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#FFFFFF")).Background(colorPrimary).Padding(0, 1).Render(label)
} else {
label = lipgloss.NewStyle().Foreground(colorText).Padding(0, 1).Render(label)
}
sb.WriteString(prefix + label + "\n")
}
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("oauth_help")))
return sb.String()
}
func (m oauthTabModel) renderRemoteMode() string {
var sb strings.Builder
providerStyle := lipgloss.NewStyle().Bold(true).Foreground(colorHighlight)
sb.WriteString(providerStyle.Render(fmt.Sprintf(" ✦ %s OAuth", m.providerName)))
sb.WriteString("\n\n")
// Auth URL section
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorInfo).Render(T("oauth_auth_url")))
sb.WriteString("\n")
// Wrap URL to fit terminal width
urlStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
maxURLWidth := m.width - 6
if maxURLWidth < 40 {
maxURLWidth = 40
}
wrappedURL := wrapText(m.authURL, maxURLWidth)
for _, line := range wrappedURL {
sb.WriteString(" " + urlStyle.Render(line) + "\n")
}
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("oauth_remote_hint")))
sb.WriteString("\n\n")
// Callback URL input
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorInfo).Render(T("oauth_callback_url")))
sb.WriteString("\n")
if m.inputActive {
sb.WriteString(m.callbackInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(" " + T("enter_submit") + " • " + T("esc_cancel")))
} else {
sb.WriteString(helpStyle.Render(T("oauth_press_c")))
}
sb.WriteString("\n\n")
sb.WriteString(warningStyle.Render(T("oauth_waiting")))
return sb.String()
}
// wrapText splits a long string into lines of at most maxWidth characters.
func wrapText(s string, maxWidth int) []string {
if maxWidth <= 0 {
return []string{s}
}
var lines []string
for len(s) > maxWidth {
lines = append(lines, s[:maxWidth])
s = s[maxWidth:]
}
if len(s) > 0 {
lines = append(lines, s)
}
return lines
}

126
internal/tui/styles.go Normal file
View File

@@ -0,0 +1,126 @@
// Package tui provides a terminal-based management interface for CLIProxyAPI.
package tui
import "github.com/charmbracelet/lipgloss"
// Color palette
var (
colorPrimary = lipgloss.Color("#7C3AED") // violet
colorSecondary = lipgloss.Color("#6366F1") // indigo
colorSuccess = lipgloss.Color("#22C55E") // green
colorWarning = lipgloss.Color("#EAB308") // yellow
colorError = lipgloss.Color("#EF4444") // red
colorInfo = lipgloss.Color("#3B82F6") // blue
colorMuted = lipgloss.Color("#6B7280") // gray
colorBg = lipgloss.Color("#1E1E2E") // dark bg
colorSurface = lipgloss.Color("#313244") // slightly lighter
colorText = lipgloss.Color("#CDD6F4") // light text
colorSubtext = lipgloss.Color("#A6ADC8") // dimmer text
colorBorder = lipgloss.Color("#45475A") // border
colorHighlight = lipgloss.Color("#F5C2E7") // pink highlight
)
// Tab bar styles
var (
tabActiveStyle = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#FFFFFF")).
Background(colorPrimary).
Padding(0, 2)
tabInactiveStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Background(colorSurface).
Padding(0, 2)
tabBarStyle = lipgloss.NewStyle().
Background(colorSurface).
PaddingLeft(1).
PaddingBottom(0)
)
// Content styles
var (
titleStyle = lipgloss.NewStyle().
Bold(true).
Foreground(colorHighlight).
MarginBottom(1)
subtitleStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Italic(true)
labelStyle = lipgloss.NewStyle().
Foreground(colorInfo).
Bold(true).
Width(24)
valueStyle = lipgloss.NewStyle().
Foreground(colorText)
sectionStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(colorBorder).
Padding(1, 2)
errorStyle = lipgloss.NewStyle().
Foreground(colorError).
Bold(true)
successStyle = lipgloss.NewStyle().
Foreground(colorSuccess)
warningStyle = lipgloss.NewStyle().
Foreground(colorWarning)
statusBarStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Background(colorSurface).
PaddingLeft(1).
PaddingRight(1)
helpStyle = lipgloss.NewStyle().
Foreground(colorMuted)
)
// Log level styles
var (
logDebugStyle = lipgloss.NewStyle().Foreground(colorMuted)
logInfoStyle = lipgloss.NewStyle().Foreground(colorInfo)
logWarnStyle = lipgloss.NewStyle().Foreground(colorWarning)
logErrorStyle = lipgloss.NewStyle().Foreground(colorError)
)
// Table styles
var (
tableHeaderStyle = lipgloss.NewStyle().
Bold(true).
Foreground(colorHighlight).
BorderBottom(true).
BorderStyle(lipgloss.NormalBorder()).
BorderForeground(colorBorder)
tableCellStyle = lipgloss.NewStyle().
Foreground(colorText).
PaddingRight(2)
tableSelectedStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFFFFF")).
Background(colorPrimary).
Bold(true)
)
func logLevelStyle(level string) lipgloss.Style {
switch level {
case "debug":
return logDebugStyle
case "info":
return logInfoStyle
case "warn", "warning":
return logWarnStyle
case "error", "fatal", "panic":
return logErrorStyle
default:
return logInfoStyle
}
}

364
internal/tui/usage_tab.go Normal file
View File

@@ -0,0 +1,364 @@
package tui
import (
"fmt"
"sort"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// usageTabModel displays usage statistics with charts and breakdowns.
type usageTabModel struct {
client *Client
viewport viewport.Model
usage map[string]any
err error
width int
height int
ready bool
}
type usageDataMsg struct {
usage map[string]any
err error
}
func newUsageTabModel(client *Client) usageTabModel {
return usageTabModel{
client: client,
}
}
func (m usageTabModel) Init() tea.Cmd {
return m.fetchData
}
func (m usageTabModel) fetchData() tea.Msg {
usage, err := m.client.GetUsage()
return usageDataMsg{usage: usage, err: err}
}
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case usageDataMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.usage = msg.usage
}
m.viewport.SetContent(m.renderContent())
return m, nil
case tea.KeyMsg:
if msg.String() == "r" {
return m, m.fetchData
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *usageTabModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m usageTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m usageTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("usage_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("usage_help")))
sb.WriteString("\n\n")
if m.err != nil {
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
sb.WriteString("\n")
return sb.String()
}
if m.usage == nil {
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
sb.WriteString("\n")
return sb.String()
}
usageMap, _ := m.usage["usage"].(map[string]any)
if usageMap == nil {
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
sb.WriteString("\n")
return sb.String()
}
totalReqs := int64(getFloat(usageMap, "total_requests"))
successCnt := int64(getFloat(usageMap, "success_count"))
failureCnt := int64(getFloat(usageMap, "failure_count"))
totalTokens := int64(getFloat(usageMap, "total_tokens"))
// ━━━ Overview Cards ━━━
cardWidth := 20
if m.width > 0 {
cardWidth = (m.width - 6) / 4
if cardWidth < 16 {
cardWidth = 16
}
}
cardStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240")).
Padding(0, 1).
Width(cardWidth).
Height(3)
// Total Requests
card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
))
// Total Tokens
card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
))
// RPM
rpm := float64(0)
if totalReqs > 0 {
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
}
}
card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
))
// TPM
tpm := float64(0)
if totalTokens > 0 {
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
}
}
card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString("\n\n")
// ━━━ Requests by Hour (ASCII bar chart) ━━━
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
sb.WriteString("\n")
}
// ━━━ Tokens by Hour ━━━
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
sb.WriteString("\n")
}
// ━━━ Requests by Day ━━━
if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
sb.WriteString("\n")
}
// ━━━ API Detail Stats ━━━
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for apiName, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
apiReqs := int64(getFloat(apiMap, "total_requests"))
apiToks := int64(getFloat(apiMap, "total_tokens"))
row := fmt.Sprintf(" %-30s %10d %12s",
truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
sb.WriteString("\n")
// Per-model breakdown
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
mReqs := int64(getFloat(stats, "total_requests"))
mToks := int64(getFloat(stats, "total_tokens"))
mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
truncate(model, 28), mReqs, formatLargeNumber(mToks))
sb.WriteString(tableCellStyle.Render(mRow))
sb.WriteString("\n")
// Token type breakdown from details
sb.WriteString(m.renderTokenBreakdown(stats))
}
}
}
}
}
}
sb.WriteString("\n")
return sb.String()
}
// renderTokenBreakdown aggregates input/output/cached/reasoning tokens from model details.
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
details, ok := modelStats["details"]
if !ok {
return ""
}
detailList, ok := details.([]any)
if !ok || len(detailList) == 0 {
return ""
}
var inputTotal, outputTotal, cachedTotal, reasoningTotal int64
for _, d := range detailList {
dm, ok := d.(map[string]any)
if !ok {
continue
}
tokens, ok := dm["tokens"].(map[string]any)
if !ok {
continue
}
inputTotal += int64(getFloat(tokens, "input_tokens"))
outputTotal += int64(getFloat(tokens, "output_tokens"))
cachedTotal += int64(getFloat(tokens, "cached_tokens"))
reasoningTotal += int64(getFloat(tokens, "reasoning_tokens"))
}
if inputTotal == 0 && outputTotal == 0 && cachedTotal == 0 && reasoningTotal == 0 {
return ""
}
parts := []string{}
if inputTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_input"), formatLargeNumber(inputTotal)))
}
if outputTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_output"), formatLargeNumber(outputTotal)))
}
if cachedTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_cached"), formatLargeNumber(cachedTotal)))
}
if reasoningTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_reasoning"), formatLargeNumber(reasoningTotal)))
}
return fmt.Sprintf(" │ %s\n",
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
}
// renderBarChart renders a simple ASCII horizontal bar chart.
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
if maxBarWidth < 10 {
maxBarWidth = 10
}
// Sort keys
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
sort.Strings(keys)
// Find max value
maxVal := float64(0)
for _, k := range keys {
v := getFloat(data, k)
if v > maxVal {
maxVal = v
}
}
if maxVal == 0 {
return ""
}
barStyle := lipgloss.NewStyle().Foreground(barColor)
var sb strings.Builder
labelWidth := 12
barAvail := maxBarWidth - labelWidth - 12
if barAvail < 5 {
barAvail = 5
}
for _, k := range keys {
v := getFloat(data, k)
barLen := int(v / maxVal * float64(barAvail))
if barLen < 1 && v > 0 {
barLen = 1
}
bar := strings.Repeat("█", barLen)
label := k
if len(label) > labelWidth {
label = label[:labelWidth]
}
sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
labelWidth, label,
barStyle.Render(bar),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%.0f", v)),
))
}
return sb.String()
}

View File

@@ -112,12 +112,13 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
modelName := gjson.GetBytes(rawJSON, "model").String()
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -165,7 +166,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
modelName := gjson.GetBytes(rawJSON, "model").String()
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
@@ -194,6 +195,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
}
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -225,7 +227,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
// This allows proper cleanup and cancellation of ongoing requests
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -257,6 +259,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
if !ok {
// Stream closed without data? Send DONE or just headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
flusher.Flush()
cliCancel(nil)
return
@@ -264,6 +267,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
// Success! Set headers now.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write the first chunk
if len(chunk) > 0 {

View File

@@ -159,7 +159,8 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context
modelName := modelResult.String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
@@ -172,12 +173,13 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
modelName := modelResult.String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}

View File

@@ -188,7 +188,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
@@ -223,6 +223,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
if alt == "" {
setSSEHeaders()
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
flusher.Flush()
cliCancel(nil)
return
@@ -232,6 +233,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
if alt == "" {
setSSEHeaders()
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write first chunk
if alt == "" {
@@ -262,12 +264,13 @@ func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, r
c.Header("Content-Type", "application/json")
alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -286,13 +289,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}

View File

@@ -179,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
return retries
}
// PassthroughHeadersEnabled returns whether upstream response headers should be forwarded to clients.
// Default is false.
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
return cfg != nil && cfg.PassthroughHeaders
}
func requestExecutionMetadata(ctx context.Context) map[string]any {
// Idempotency-Key is an optional client-supplied header used to correlate retries.
// It is forwarded as execution metadata; when absent we generate a UUID.
@@ -461,10 +467,10 @@ func appendAPIResponse(c *gin.Context, data []byte) {
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
return nil, errMsg
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
@@ -497,17 +503,20 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
addon = hdr.Clone()
}
}
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
}
return resp.Payload, nil
if !PassthroughHeadersEnabled(h.Cfg) {
return resp.Payload, nil, nil
}
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
}
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
return nil, errMsg
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
@@ -540,20 +549,24 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
addon = hdr.Clone()
}
}
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
}
return resp.Payload, nil
if !PassthroughHeadersEnabled(h.Cfg) {
return resp.Payload, nil, nil
}
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
}
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
// The returned http.Header carries upstream response headers captured before streaming begins.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- errMsg
close(errChan)
return nil, errChan
return nil, nil, errChan
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
@@ -572,7 +585,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
SourceFormat: sdktranslator.FromString(handlerType),
}
opts.Metadata = reqMeta
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
status := http.StatusInternalServerError
@@ -589,8 +602,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
close(errChan)
return nil, errChan
return nil, nil, errChan
}
passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
var upstreamHeaders http.Header
if passthroughHeadersEnabled {
upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
if upstreamHeaders == nil {
upstreamHeaders = make(http.Header)
}
}
chunks := streamResult.Chunks
dataChan := make(chan []byte)
errChan := make(chan *interfaces.ErrorMessage, 1)
go func() {
@@ -664,9 +688,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
if !sentPayload {
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
bootstrapRetries++
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if retryErr == nil {
chunks = retryChunks
if passthroughHeadersEnabled {
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
}
chunks = retryResult.Chunks
continue outer
}
streamErr = retryErr
@@ -689,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return
}
if len(chunk.Payload) > 0 {
if handlerType == "openai-response" {
if err := validateSSEDataJSON(chunk.Payload); err != nil {
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
return
}
}
sentPayload = true
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
return
@@ -697,7 +730,36 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
}
}
}()
return dataChan, errChan
return dataChan, upstreamHeaders, errChan
}
func validateSSEDataJSON(chunk []byte) error {
for _, line := range bytes.Split(chunk, []byte("\n")) {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
data := bytes.TrimSpace(line[5:])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
continue
}
if json.Valid(data) {
continue
}
const max = 512
preview := data
if len(preview) > max {
preview = preview[:max]
}
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
}
return nil
}
func statusFromError(err error) int {
@@ -757,13 +819,33 @@ func cloneBytes(src []byte) []byte {
return dst
}
func cloneHeader(src http.Header) http.Header {
if src == nil {
return nil
}
dst := make(http.Header, len(src))
for key, values := range src {
dst[key] = append([]string(nil), values...)
}
return dst
}
func replaceHeader(dst http.Header, src http.Header) {
for key := range dst {
delete(dst, key)
}
for key, values := range src {
dst[key] = append([]string(nil), values...)
}
}
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
status := http.StatusInternalServerError
if msg != nil && msg.StatusCode > 0 {
status = msg.StatusCode
}
if msg != nil && msg.Addon != nil {
if msg != nil && msg.Addon != nil && PassthroughHeadersEnabled(h.Cfg) {
for key, values := range msg.Addon {
if len(values) == 0 {
continue

View File

@@ -0,0 +1,68 @@
package handlers
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestWriteErrorResponse_AddonHeadersDisabledByDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
handler := NewBaseAPIHandlers(nil, nil)
handler.WriteErrorResponse(c, &interfaces.ErrorMessage{
StatusCode: http.StatusTooManyRequests,
Error: errors.New("rate limit"),
Addon: http.Header{
"Retry-After": {"30"},
"X-Request-Id": {"req-1"},
},
})
if recorder.Code != http.StatusTooManyRequests {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusTooManyRequests)
}
if got := recorder.Header().Get("Retry-After"); got != "" {
t.Fatalf("Retry-After should be empty when passthrough is disabled, got %q", got)
}
if got := recorder.Header().Get("X-Request-Id"); got != "" {
t.Fatalf("X-Request-Id should be empty when passthrough is disabled, got %q", got)
}
}
func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Writer.Header().Set("X-Request-Id", "old-value")
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{PassthroughHeaders: true}, nil)
handler.WriteErrorResponse(c, &interfaces.ErrorMessage{
StatusCode: http.StatusTooManyRequests,
Error: errors.New("rate limit"),
Addon: http.Header{
"Retry-After": {"30"},
"X-Request-Id": {"new-1", "new-2"},
},
})
if recorder.Code != http.StatusTooManyRequests {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusTooManyRequests)
}
if got := recorder.Header().Get("Retry-After"); got != "30" {
t.Fatalf("Retry-After = %q, want %q", got, "30")
}
if got := recorder.Header().Values("X-Request-Id"); !reflect.DeepEqual(got, []string{"new-1", "new-2"}) {
t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"})
}
}

View File

@@ -23,7 +23,7 @@ func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreex
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.mu.Lock()
e.calls++
call := e.calls
@@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth,
},
}
close(ch)
return ch, nil
return &coreexecutor.StreamResult{
Headers: http.Header{"X-Upstream-Attempt": {"1"}},
Chunks: ch,
}, nil
}
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch)
return ch, nil
return &coreexecutor.StreamResult{
Headers: http.Header{"X-Upstream-Attempt": {"2"}},
Chunks: ch,
}, nil
}
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -81,7 +87,7 @@ func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.mu.Lock()
e.calls++
e.mu.Unlock()
@@ -97,7 +103,7 @@ func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreaut
},
}
close(ch)
return ch, nil
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -128,13 +134,44 @@ type authAwareStreamExecutor struct {
authIDs []string
}
type invalidJSONStreamExecutor struct{}
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
ch := make(chan coreexecutor.StreamChunk, 1)
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
close(ch)
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) {
_ = ctx
_ = req
_ = opts
@@ -160,12 +197,12 @@ func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *corea
},
}
close(ch)
return ch, nil
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch)
return ch, nil
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -231,11 +268,12 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
PassthroughHeaders: true,
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
@@ -257,6 +295,70 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
if executor.Calls() != 2 {
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
}
upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
if upstreamAttemptHeader != "2" {
t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
}
}
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
}
}
if string(got) != "ok" {
t.Fatalf("expected payload ok, got %q", string(got))
}
if upstreamHeaders != nil {
t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
}
}
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
@@ -296,7 +398,7 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
BootstrapRetries: 1,
},
}, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
@@ -367,7 +469,7 @@ func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T)
},
}, manager)
ctx := WithPinnedAuthID(context.Background(), "auth1")
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
@@ -431,7 +533,7 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
selectedAuthID = authID
})
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
@@ -453,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
}
}
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
executor := &invalidJSONStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
if len(got) != 0 {
t.Fatalf("expected empty payload, got %q", string(got))
}
gotErr := false
for msg := range errChan {
if msg == nil {
continue
}
if msg.StatusCode != http.StatusBadGateway {
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
}
if msg.Error == nil {
t.Fatalf("expected error")
}
gotErr = true
}
if !gotErr {
t.Fatalf("expected terminal error")
}
}

View File

@@ -0,0 +1,80 @@
package handlers
import (
"net/http"
"strings"
)
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
// be forwarded by proxies, plus security-sensitive headers that should not leak.
var hopByHopHeaders = map[string]struct{}{
// RFC 7230 hop-by-hop
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Te": {},
"Trailer": {},
"Transfer-Encoding": {},
"Upgrade": {},
// Security-sensitive
"Set-Cookie": {},
// CPA-managed (set by handlers, not upstream)
"Content-Length": {},
"Content-Encoding": {},
}
// FilterUpstreamHeaders returns a copy of src with hop-by-hop and security-sensitive
// headers removed. Returns nil if src is nil or empty after filtering.
func FilterUpstreamHeaders(src http.Header) http.Header {
if src == nil {
return nil
}
connectionScoped := connectionScopedHeaders(src)
dst := make(http.Header)
for key, values := range src {
canonicalKey := http.CanonicalHeaderKey(key)
if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
continue
}
if _, scoped := connectionScoped[canonicalKey]; scoped {
continue
}
dst[key] = values
}
if len(dst) == 0 {
return nil
}
return dst
}
func connectionScopedHeaders(src http.Header) map[string]struct{} {
scoped := make(map[string]struct{})
for _, rawValue := range src.Values("Connection") {
for _, token := range strings.Split(rawValue, ",") {
headerName := strings.TrimSpace(token)
if headerName == "" {
continue
}
scoped[http.CanonicalHeaderKey(headerName)] = struct{}{}
}
}
return scoped
}
// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer.
// Headers already set by CPA (e.g., Content-Type) are NOT overwritten.
func WriteUpstreamHeaders(dst http.Header, src http.Header) {
if src == nil {
return
}
for key, values := range src {
// Don't overwrite headers already set by CPA handlers
if dst.Get(key) != "" {
continue
}
for _, v := range values {
dst.Add(key, v)
}
}
}

View File

@@ -0,0 +1,55 @@
package handlers
import (
"net/http"
"testing"
)
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
src := http.Header{}
src.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
src.Add("Connection", "x-hop-c")
src.Set("Keep-Alive", "timeout=5")
src.Set("X-Hop-A", "a")
src.Set("X-Hop-B", "b")
src.Set("X-Hop-C", "c")
src.Set("X-Request-Id", "req-1")
src.Set("Set-Cookie", "session=secret")
filtered := FilterUpstreamHeaders(src)
if filtered == nil {
t.Fatalf("expected filtered headers, got nil")
}
requestID := filtered.Get("X-Request-Id")
if requestID != "req-1" {
t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID)
}
blockedHeaderKeys := []string{
"Connection",
"Keep-Alive",
"X-Hop-A",
"X-Hop-B",
"X-Hop-C",
"Set-Cookie",
}
for _, key := range blockedHeaderKeys {
value := filtered.Get(key)
if value != "" {
t.Fatalf("expected %s to be removed, got %q", key, value)
}
}
}
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
src := http.Header{}
src.Add("Connection", "x-hop-a")
src.Set("X-Hop-A", "a")
src.Set("Set-Cookie", "session=secret")
filtered := FilterUpstreamHeaders(src)
if filtered != nil {
t.Fatalf("expected nil when all headers are filtered, got %#v", filtered)
}
}

View File

@@ -332,6 +332,7 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
// Check if this chunk has any meaningful content
hasContent := false
hasUsage := root.Get("usage").Exists()
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
// Check if delta has content or finish_reason
@@ -350,8 +351,8 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
})
}
// If no meaningful content, return nil to indicate this chunk should be skipped
if !hasContent {
// If no meaningful content and no usage, return nil to indicate this chunk should be skipped
if !hasContent && !hasUsage {
return nil
}
@@ -410,6 +411,11 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
}
// Copy usage if present
if usage := root.Get("usage"); usage.Exists() {
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
}
return []byte(out)
}
@@ -425,12 +431,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -457,7 +464,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
@@ -490,6 +497,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
if !ok {
// Stream closed without data? Send DONE or just headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel(nil)
@@ -498,6 +506,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
// Success! Commit to streaming headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
flusher.Flush()
@@ -525,13 +534,14 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
completionsResp := convertChatCompletionsResponseToCompletions(resp)
_, _ = c.Writer.Write(completionsResp)
cliCancel()
@@ -562,7 +572,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
@@ -593,6 +603,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
case chunk, ok := <-dataChan:
if !ok {
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel(nil)
@@ -601,6 +612,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
// Success! Set headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write the first chunk
converted := convertChatCompletionsStreamChunkToCompletions(chunk)

View File

@@ -31,7 +31,7 @@ func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Aut
return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
}
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
return nil, errors.New("not implemented")
}

View File

@@ -124,13 +124,14 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) {
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -149,13 +150,14 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -183,7 +185,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
// New core execution path
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
@@ -216,6 +218,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
if !ok {
// Stream closed without data? Send headers and done.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
cliCancel(nil)
@@ -224,6 +227,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
// Success! Set headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write first chunk logic (matching forwardResponsesStream)
if bytes.HasPrefix(chunk, []byte("event:")) {
@@ -261,8 +265,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
},
WriteDone: func() {
_, _ = c.Writer.Write([]byte("\n"))

View File

@@ -0,0 +1,43 @@
package openai
import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
gin.SetMode(gin.TestMode)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
h := NewOpenAIResponsesAPIHandler(base)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
t.Fatalf("expected gin writer to implement http.Flusher")
}
data := make(chan []byte)
errs := make(chan *interfaces.ErrorMessage, 1)
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
body := recorder.Body.String()
if !strings.Contains(body, `"type":"error"`) {
t.Fatalf("expected responses error chunk, got: %q", body)
}
if strings.Contains(body, `"error":{`) {
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
}
}

View File

@@ -153,7 +153,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
pinnedAuthID = strings.TrimSpace(authID)
})
}
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if errForward != nil {

View File

@@ -0,0 +1,119 @@
package handlers
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
type openAIResponsesStreamErrorChunk struct {
Type string `json:"type"`
Code string `json:"code"`
Message string `json:"message"`
SequenceNumber int `json:"sequence_number"`
}
func openAIResponsesStreamErrorCode(status int) string {
switch status {
case http.StatusUnauthorized:
return "invalid_api_key"
case http.StatusForbidden:
return "insufficient_quota"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "model_not_found"
case http.StatusRequestTimeout:
return "request_timeout"
default:
if status >= http.StatusInternalServerError {
return "internal_server_error"
}
if status >= http.StatusBadRequest {
return "invalid_request_error"
}
return "unknown_error"
}
}
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
//
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
// of chunks that requires a top-level `type` field.
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
if status <= 0 {
status = http.StatusInternalServerError
}
if sequenceNumber < 0 {
sequenceNumber = 0
}
message := strings.TrimSpace(errText)
if message == "" {
message = http.StatusText(status)
}
code := openAIResponsesStreamErrorCode(status)
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := payload["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
sequenceNumber = int(v)
}
}
if e, ok := payload["error"].(map[string]any); ok {
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := e["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
}
}
}
if strings.TrimSpace(code) == "" {
code = "unknown_error"
}
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: code,
Message: message,
SequenceNumber: sequenceNumber,
})
if err == nil {
return data
}
// Extremely defensive fallback.
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: "internal_server_error",
Message: message,
SequenceNumber: sequenceNumber,
})
if len(data) > 0 {
return data
}
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
}

View File

@@ -0,0 +1,48 @@
package handlers
import (
"encoding/json"
"net/http"
"testing"
)
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
var payload map[string]any
if err := json.Unmarshal(chunk, &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if payload["type"] != "error" {
t.Fatalf("type = %v, want %q", payload["type"], "error")
}
if payload["code"] != "internal_server_error" {
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
}
if payload["message"] != "unexpected EOF" {
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
}
if payload["sequence_number"] != float64(0) {
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
}
}
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
chunk := BuildOpenAIResponsesStreamErrorChunk(
http.StatusInternalServerError,
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
0,
)
var payload map[string]any
if err := json.Unmarshal(chunk, &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if payload["type"] != "error" {
t.Fatalf("type = %v, want %q", payload["type"], "error")
}
if payload["code"] != "internal_server_error" {
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
}
if payload["message"] != "oops" {
t.Fatalf("message = %v, want %q", payload["message"], "oops")
}
}

View File

@@ -2,8 +2,6 @@ package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
@@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
opts = &LoginOptions{}
}
if shouldUseCodexDeviceFlow(opts) {
return a.loginWithDeviceFlow(ctx, cfg, opts)
}
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
@@ -186,39 +188,5 @@ waitForCallback:
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
}
tokenStorage := authSvc.CreateTokenStorage(authBundle)
if tokenStorage == nil || tokenStorage.Email == "" {
return nil, fmt.Errorf("codex token storage missing account information")
}
planType := ""
hashAccountID := ""
if tokenStorage.IDToken != "" {
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
}
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Codex authentication successful")
if authBundle.APIKey != "" {
fmt.Println("Codex API key obtained and stored")
}
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}, nil
return a.buildAuthRecord(authSvc, authBundle)
}

291
sdk/auth/codex_device.go Normal file
View File

@@ -0,0 +1,291 @@
package auth
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
codexDeviceVerificationURL = "https://auth.openai.com/codex/device"
codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback"
codexDeviceTimeout = 15 * time.Minute
codexDeviceDefaultPollIntervalSeconds = 5
)
type codexDeviceUserCodeRequest struct {
ClientID string `json:"client_id"`
}
type codexDeviceUserCodeResponse struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
UserCodeAlt string `json:"usercode"`
Interval json.RawMessage `json:"interval"`
}
type codexDeviceTokenRequest struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
}
type codexDeviceTokenResponse struct {
AuthorizationCode string `json:"authorization_code"`
CodeVerifier string `json:"code_verifier"`
CodeChallenge string `json:"code_challenge"`
}
func shouldUseCodexDeviceFlow(opts *LoginOptions) bool {
if opts == nil || opts.Metadata == nil {
return false
}
return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice)
}
func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if ctx == nil {
ctx = context.Background()
}
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient)
if err != nil {
return nil, err
}
deviceCode := strings.TrimSpace(userCodeResp.UserCode)
if deviceCode == "" {
deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt)
}
deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID)
if deviceCode == "" || deviceAuthID == "" {
return nil, fmt.Errorf("codex device flow did not return required fields")
}
pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval)
fmt.Println("Starting Codex device authentication...")
fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL)
fmt.Printf("Codex device code: %s\n", deviceCode)
if !opts.NoBrowser {
if !browser.IsAvailable() {
log.Warn("No browser available; please open the device URL manually")
} else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
}
}
tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval)
if err != nil {
return nil, err
}
authCode := strings.TrimSpace(tokenResp.AuthorizationCode)
codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier)
codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge)
if authCode == "" || codeVerifier == "" || codeChallenge == "" {
return nil, fmt.Errorf("codex device flow token response missing required fields")
}
authSvc := codex.NewCodexAuth(cfg)
authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect(
ctx,
authCode,
codexDeviceTokenExchangeRedirectURI,
&codex.PKCECodes{
CodeVerifier: codeVerifier,
CodeChallenge: codeChallenge,
},
)
if err != nil {
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
}
return a.buildAuthRecord(authSvc, authBundle)
}
func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) {
body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID})
if err != nil {
return nil, fmt.Errorf("failed to encode codex device request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create codex device request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to request codex device code: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read codex device code response: %w", err)
}
if !codexDeviceIsSuccessStatus(resp.StatusCode) {
trimmed := strings.TrimSpace(string(respBody))
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode)
}
if trimmed == "" {
trimmed = "empty response body"
}
return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed)
}
var parsed codexDeviceUserCodeResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("failed to decode codex device code response: %w", err)
}
return &parsed, nil
}
func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) {
deadline := time.Now().Add(codexDeviceTimeout)
for {
if time.Now().After(deadline) {
return nil, fmt.Errorf("codex device authentication timed out after 15 minutes")
}
body, err := json.Marshal(codexDeviceTokenRequest{
DeviceAuthID: deviceAuthID,
UserCode: userCode,
})
if err != nil {
return nil, fmt.Errorf("failed to encode codex device poll request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create codex device poll request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to poll codex device token: %w", err)
}
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr)
}
switch {
case codexDeviceIsSuccessStatus(resp.StatusCode):
var parsed codexDeviceTokenResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("failed to decode codex device token response: %w", err)
}
return &parsed, nil
case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound:
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(interval):
continue
}
default:
trimmed := strings.TrimSpace(string(respBody))
if trimmed == "" {
trimmed = "empty response body"
}
return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed)
}
}
}
func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration {
defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second
if len(raw) == 0 {
return defaultInterval
}
var asString string
if err := json.Unmarshal(raw, &asString); err == nil {
if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
var asInt int
if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 {
return time.Duration(asInt) * time.Second
}
return defaultInterval
}
func codexDeviceIsSuccessStatus(code int) bool {
return code >= 200 && code < 300
}
func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) {
tokenStorage := authSvc.CreateTokenStorage(authBundle)
if tokenStorage == nil || tokenStorage.Email == "" {
return nil, fmt.Errorf("codex token storage missing account information")
}
planType := ""
hashAccountID := ""
if tokenStorage.IDToken != "" {
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
}
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Codex authentication successful")
if authBundle.APIKey != "" {
fmt.Println("Codex API key obtained and stored")
}
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}, nil
}

View File

@@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
}
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
type metadataSetter interface {
SetMetadata(map[string]any)
}
switch {
case auth.Storage != nil:
if setter, ok := auth.Storage.(metadataSetter); ok {
setter.SetMetadata(auth.Metadata)
}
if err = auth.Storage.SaveTokenToFile(path); err != nil {
return "", err
}

View File

@@ -30,8 +30,9 @@ type ProviderExecutor interface {
Identifier() string
// Execute handles non-streaming execution and returns the provider response payload.
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
// ExecuteStream handles streaming execution and returns a channel of provider chunks.
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error)
// ExecuteStream handles streaming execution and returns a StreamResult containing
// upstream headers and a channel of provider chunks.
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error)
// Refresh attempts to refresh provider credentials and returns the updated auth state.
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
// CountTokens returns the token count for the given request.
@@ -558,7 +559,7 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
// ExecuteStream performs a streaming execution using the configured selector and executor.
// It supports multiple providers for the same model and round-robins the starting provider per model.
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
normalized := m.normalizeProviders(providers)
if len(normalized) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
@@ -568,9 +569,9 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
var lastErr error
for attempt := 0; ; attempt++ {
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
if errStream == nil {
return chunks, nil
return result, nil
}
lastErr = errStream
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
@@ -699,7 +700,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
}
}
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
@@ -730,7 +731,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
@@ -778,8 +779,11 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, chunks)
return out, nil
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
return &cliproxyexecutor.StreamResult{
Headers: streamResult.Headers,
Chunks: out,
}, nil
}
}
@@ -1824,9 +1828,7 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
// every few seconds and triggers refresh operations when required.
// Only one loop is kept alive; starting a new one cancels the previous run.
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
if interval <= 0 || interval > refreshCheckInterval {
interval = refreshCheckInterval
} else {
if interval <= 0 {
interval = refreshCheckInterval
}
if m.refreshCancel != nil {

View File

@@ -24,10 +24,10 @@ func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.
return cliproxyexecutor.Response{}, nil
}
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
ch := make(chan cliproxyexecutor.StreamChunk)
close(ch)
return ch, nil
return &cliproxyexecutor.StreamResult{Chunks: ch}, nil
}
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
@@ -89,7 +89,11 @@ func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
if !okResolved {
t.Fatal("expected registered executor to be found")
}
if resolved != current {
resolvedExecutor, okResolvedExecutor := resolved.(*replaceAwareExecutor)
if !okResolvedExecutor {
t.Fatalf("expected resolved executor type %T, got %T", current, resolved)
}
if resolvedExecutor != current {
t.Fatal("expected resolved executor to match registered executor")
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"math"
"math/rand/v2"
"net/http"
"sort"
"strconv"
@@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
}
// Pick selects the next available auth for the provider in a round-robin manner.
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
// a two-level round-robin is used: first cycling across credential groups (parent
// accounts), then cycling within each group's project auths.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = opts
now := time.Now()
@@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
if limit <= 0 {
limit = 4096
}
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
s.cursors = make(map[string]int)
}
index := s.cursors[key]
// Check if any available auth has gemini_virtual_parent attribute,
// indicating gemini-cli virtual auths that should use credential-level polling.
groups, parentOrder := groupByVirtualParent(available)
if len(parentOrder) > 1 {
// Two-level round-robin: first select a credential group, then pick within it.
groupKey := key + "::group"
s.ensureCursorKey(groupKey, limit)
if _, exists := s.cursors[groupKey]; !exists {
// Seed with a random initial offset so the starting credential is randomized.
s.cursors[groupKey] = rand.IntN(len(parentOrder))
}
groupIndex := s.cursors[groupKey]
if groupIndex >= 2_147_483_640 {
groupIndex = 0
}
s.cursors[groupKey] = groupIndex + 1
selectedParent := parentOrder[groupIndex%len(parentOrder)]
group := groups[selectedParent]
// Second level: round-robin within the selected credential group.
innerKey := key + "::cred:" + selectedParent
s.ensureCursorKey(innerKey, limit)
innerIndex := s.cursors[innerKey]
if innerIndex >= 2_147_483_640 {
innerIndex = 0
}
s.cursors[innerKey] = innerIndex + 1
s.mu.Unlock()
return group[innerIndex%len(group)], nil
}
// Flat round-robin for non-grouped auths (original behavior).
s.ensureCursorKey(key, limit)
index := s.cursors[key]
if index >= 2_147_483_640 {
index = 0
}
s.cursors[key] = index + 1
s.mu.Unlock()
// log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available))
return available[index%len(available)], nil
}
// ensureCursorKey ensures the cursor map has capacity for the given key.
// Must be called with s.mu held.
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
s.cursors = make(map[string]int)
}
}
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
if len(auths) == 0 {
return nil, nil
}
groups := make(map[string][]*Auth)
for _, a := range auths {
parent := ""
if a.Attributes != nil {
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
}
if parent == "" {
// Non-virtual auth present; fall back to flat round-robin.
return nil, nil
}
groups[parent] = append(groups[parent], a)
}
// Collect parent IDs in sorted order for stable cursor indexing.
parentOrder := make([]string, 0, len(groups))
for p := range groups {
parentOrder = append(parentOrder, p)
}
sort.Strings(parentOrder)
return groups, parentOrder
}
// Pick selects the first available auth for the provider in a deterministic manner.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = opts

View File

@@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
}
}
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// Simulate two gemini-cli credentials, each with multiple projects:
// Credential A (parent = "cred-a.json") has 3 projects
// Credential B (parent = "cred-b.json") has 2 projects
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
}
// Two-level round-robin: consecutive picks must alternate between credentials.
// Credential group order is randomized, but within each call the group cursor
// advances by 1, so consecutive picks should cycle through different parents.
picks := make([]string, 6)
parents := make([]string, 6)
for i := 0; i < 6; i++ {
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
picks[i] = got.ID
parents[i] = got.Attributes["gemini_virtual_parent"]
}
// Verify property: consecutive picks must alternate between credential groups.
for i := 1; i < len(parents); i++ {
if parents[i] == parents[i-1] {
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
i-1, i, parents[i], picks[i-1], picks[i])
}
}
// Verify property: each credential's projects are picked in sequence (round-robin within group).
credPicks := map[string][]string{}
for i, id := range picks {
credPicks[parents[i]] = append(credPicks[parents[i]], id)
}
for parent, ids := range credPicks {
for i := 1; i < len(ids); i++ {
if ids[i] == ids[i-1] {
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
}
}
}
}
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// All auths from the same parent - should fall back to flat round-robin
// because there's only one credential group (no benefit from two-level).
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
}
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
// Sorted by ID: proj-a1, proj-a2, proj-a3
want := []string{
"cred-a.json::proj-a1",
"cred-a.json::proj-a2",
"cred-a.json::proj-a3",
"cred-a.json::proj-a1",
}
for i, expectedID := range want {
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != expectedID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
}
}
}
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
// alongside virtual ones). Should fall back to flat round-robin.
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-regular.json"}, // no gemini_virtual_parent
}
// groupByVirtualParent returns nil when any auth lacks the attribute,
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
want := []string{
"cred-a.json::proj-a1",
"cred-regular.json",
"cred-a.json::proj-a1",
}
for i, expectedID := range want {
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != expectedID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
}
}
}

View File

@@ -1,9 +1,12 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
@@ -12,6 +15,33 @@ import (
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
)
// PostAuthHook defines a function that is called after an Auth record is created
// but before it is persisted to storage. This allows for modification of the
// Auth record (e.g., injecting metadata) based on external context.
type PostAuthHook func(context.Context, *Auth) error
// RequestInfo holds information extracted from the HTTP request.
// It is injected into the context passed to PostAuthHook.
type RequestInfo struct {
Query url.Values
Headers http.Header
}
type requestInfoKey struct{}
// WithRequestInfo returns a new context with the given RequestInfo attached.
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, info)
}
// GetRequestInfo retrieves the RequestInfo from the context, if present.
func GetRequestInfo(ctx context.Context) *RequestInfo {
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
return val
}
return nil
}
// Auth encapsulates the runtime state and metadata associated with a single credential.
type Auth struct {
// ID uniquely identifies the auth record across restarts.
@@ -213,6 +243,23 @@ func (a *Auth) DisableCoolingOverride() (bool, bool) {
return false, false
}
// ToolPrefixDisabled returns whether the proxy_ tool name prefix should be
// skipped for this auth. When true, tool names are sent to Anthropic unchanged.
// The value is read from metadata key "tool_prefix_disabled" (or "tool-prefix-disabled").
func (a *Auth) ToolPrefixDisabled() bool {
if a == nil || a.Metadata == nil {
return false
}
for _, key := range []string{"tool_prefix_disabled", "tool-prefix-disabled"} {
if val, ok := a.Metadata[key]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed
}
}
}
return false
}
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
// The value is read from metadata key "request_retry" (or legacy "request-retry").
func (a *Auth) RequestRetryOverride() (int, bool) {

View File

@@ -0,0 +1,35 @@
package auth
import "testing"
func TestToolPrefixDisabled(t *testing.T) {
var a *Auth
if a.ToolPrefixDisabled() {
t.Error("nil auth should return false")
}
a = &Auth{}
if a.ToolPrefixDisabled() {
t.Error("empty auth should return false")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": true}}
if !a.ToolPrefixDisabled() {
t.Error("should return true when set to true")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": "true"}}
if !a.ToolPrefixDisabled() {
t.Error("should return true when set to string 'true'")
}
a = &Auth{Metadata: map[string]any{"tool-prefix-disabled": true}}
if !a.ToolPrefixDisabled() {
t.Error("should return true with kebab-case key")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": false}}
if a.ToolPrefixDisabled() {
t.Error("should return false when set to false")
}
}

View File

@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
return b
}
// WithPostAuthHook registers a hook to be called after an Auth record is created
// but before it is persisted to storage.
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
if hook == nil {
return b
}
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
return b
}
// Build validates inputs, applies defaults, and returns a ready-to-run service.
func (b *Builder) Build() (*Service, error) {
if b.cfg == nil {

View File

@@ -57,6 +57,8 @@ type Response struct {
Payload []byte
// Metadata exposes optional structured data for translators.
Metadata map[string]any
// Headers carries upstream HTTP response headers for passthrough to clients.
Headers http.Header
}
// StreamChunk represents a single streaming payload unit emitted by provider executors.
@@ -67,6 +69,15 @@ type StreamChunk struct {
Err error
}
// StreamResult wraps the streaming response, providing both the chunk channel
// and the upstream HTTP response headers captured before streaming begins.
type StreamResult struct {
// Headers carries upstream HTTP response headers from the initial connection.
Headers http.Header
// Chunks is the channel of streaming payload units.
Chunks <-chan StreamChunk
}
// StatusError represents an error that carries an HTTP-like status code.
// Provider executors should implement this when possible to enable
// better auth state updates on failures (e.g., 401/402/429).