Compare commits

...

32 Commits

Author SHA1 Message Date
Luis Pater
9c040445af Merge pull request #1635 from thebtf/fix/openai-translator-tool-streaming
fix: handle tool call argument streaming in Codex→OpenAI translator
2026-02-19 04:22:12 +08:00
Luis Pater
fff866424e Merge pull request #1628 from thebtf/fix/masquerading-headers
fix: update Claude masquerading headers and configurable defaults
2026-02-19 04:19:59 +08:00
Luis Pater
2d12becfd6 Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping
fix: clamp reasoning_effort to valid OpenAI-format values
2026-02-19 04:15:19 +08:00
Luis Pater
252f7e0751 Merge pull request #1625 from thebtf/feat/tool-prefix-config
feat: add per-auth tool_prefix_disabled option
2026-02-19 04:07:22 +08:00
Luis Pater
b2b17528cb Merge branch 'pr-1624' into dev
# Conflicts:
#	internal/runtime/executor/claude_executor.go
#	internal/runtime/executor/claude_executor_test.go
2026-02-19 04:05:04 +08:00
Luis Pater
55f938164b Merge pull request #1618 from alexey-yanchenko/fix/completions-usage
Fix empty usage in /v1/completions
2026-02-19 03:57:11 +08:00
Luis Pater
76294f0c59 Merge pull request #1608 from thebtf/fix/tool-reference-proxy-prefix-mainline
fix: add proxy_ prefix handling for tool_reference content blocks
2026-02-19 03:50:34 +08:00
Luis Pater
2bcee78c6e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:19:18 +08:00
Luis Pater
7fe8246a9f Merge branch 'tui' into dev 2026-02-19 03:18:24 +08:00
Luis Pater
93fe58e31e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:18:08 +08:00
Luis Pater
e5b5dc870f chore(executor): remove unused Openai-Beta header from Codex executor 2026-02-19 02:19:48 +08:00
Luis Pater
a54877c023 Merge branch 'dev' 2026-02-19 02:03:41 +08:00
Luis Pater
bb86a0c0c4 feat(logging, executor): add request logging tests and WebSocket-based Codex executor
- Introduced unit tests for request logging middleware to enhance coverage.
- Added WebSocket-based Codex executor to support Responses API upgrade.
- Updated middleware logic to selectively capture request bodies for memory efficiency.
- Enhanced Codex configuration handling with new WebSocket attributes.
2026-02-19 01:57:02 +08:00
Kirill Turanskiy
5fa23c7f41 fix: handle tool call argument streaming in Codex→OpenAI translator
The OpenAI Chat Completions translator was silently dropping
response.function_call_arguments.delta and
response.function_call_arguments.done Codex SSE events, meaning
tool call arguments were never streamed incrementally to clients.

Add proper handling mirroring the proven Claude translator pattern:

- response.output_item.added: announce tool call (id, name, empty args)
- response.function_call_arguments.delta: stream argument chunks
- response.function_call_arguments.done: emit full args if no deltas
- response.output_item.done: defensive fallback for backward compat

State tracking via HasReceivedArgumentsDelta and HasToolCallAnnounced
ensures no duplicate argument emission and correct behavior for models
like codex-spark that skip delta events entirely.
2026-02-18 19:09:05 +03:00
Kirill Turanskiy
73dc0b10b8 fix: update Claude masquerading headers and make them configurable
Update hardcoded X-Stainless-* and User-Agent defaults to match
Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (verified via
diagnostic proxy capture 2026-02-17).

Changes:
- X-Stainless-Os/Arch: dynamic via runtime.GOOS/GOARCH
- X-Stainless-Package-Version: 0.55.1 → 0.74.0
- X-Stainless-Timeout: 60 → 600
- User-Agent: claude-cli/1.0.83 (external, cli) → claude-cli/2.1.44 (external, sdk-cli)

Add claude-header-defaults config section so values can be updated
without recompilation when Claude Code releases new versions.
2026-02-18 03:38:51 +03:00
Kirill Turanskiy
2ea95266e3 fix: clamp reasoning_effort to valid OpenAI-format values
CPA-internal thinking levels like 'xhigh' and 'minimal' are not accepted
by OpenAI-format providers (MiniMax, etc.). The OpenAI applier now maps
non-standard levels to the nearest valid reasoning_effort value before
writing to the request body:

  xhigh   → high
  minimal → low
  auto    → medium
2026-02-18 03:36:42 +03:00
Kirill Turanskiy
9261b0c20b feat: add per-auth tool_prefix_disabled option
Allow disabling the proxy_ tool name prefix on a per-account basis.
Users who route their own Anthropic account through CPA can set
"tool_prefix_disabled": true in their OAuth auth JSON to send tool
names unchanged to Anthropic.

Default behavior is fully preserved — prefix is applied unless
explicitly disabled.

Changes:
- Add ToolPrefixDisabled() accessor to Auth (reads metadata key
  "tool_prefix_disabled" or "tool-prefix-disabled")
- Gate all 6 prefix apply/strip points with the new flag
- Add unit tests for the accessor
2026-02-17 21:48:19 +03:00
Kirill Turanskiy
7cc725496e fix: skip proxy_ prefix for built-in tools in message history
The proxy_ prefix logic correctly skips built-in tools (those with a
non-empty "type" field) in tools[] definitions but does not skip them
in messages[].content[] tool_use blocks or tool_choice. This causes
web_search in conversation history to become proxy_web_search, which
Anthropic does not recognize.

Fix: collect built-in tool names from tools[] into a set and also
maintain a hardcoded fallback set (web_search, code_execution,
text_editor, computer) for cases where the built-in tool appears in
history but not in the current request's tools[] array. Skip prefixing
in messages and tool_choice when name matches a built-in.
2026-02-17 21:42:32 +03:00
Alexey Yanchenko
709d999f9f Add usage to /v1/completions 2026-02-17 17:21:03 +07:00
Kirill Turanskiy
24c18614f0 fix: skip built-in tools in tool_reference prefix + refactor to switch
- Collect built-in tool names (those with a "type" field like
  web_search, code_execution) and skip prefixing tool_reference
  blocks that reference them, preventing name mismatch.
- Refactor if-else if chains to switch statements in all three
  prefix functions for idiomatic Go style.
2026-02-16 19:37:11 +03:00
Kirill Turanskiy
603f06a762 fix: handle tool_reference nested inside tool_result.content[]
tool_reference blocks can appear nested inside tool_result.content[]
arrays, not just at the top level of messages[].content[]. The prefix
logic now iterates into tool_result blocks with array content to find
and prefix/strip nested tool_reference.tool_name fields.
2026-02-16 19:06:24 +03:00
Kirill Turanskiy
98f0a3e3bd fix: add proxy_ prefix handling for tool_reference content blocks (#1)
applyClaudeToolPrefix, stripClaudeToolPrefixFromResponse, and
stripClaudeToolPrefixFromStreamLine now handle "tool_reference" blocks
(field "tool_name") in addition to "tool_use" blocks (field "name").

Without this fix, tool_reference blocks in conversation history retain
their original unprefixed names while tool definitions carry the proxy_
prefix, causing Anthropic API 400 errors: "Tool reference 'X' not found
in available tools."

Co-authored-by: Kirill Turanskiy <kt@novamedia.ru>
2026-02-16 19:06:24 +03:00
Luis Pater
453aaf8774 chore(runtime): update Qwen executor user agent and headers for compatibility with new runtime standards 2026-02-16 23:29:47 +08:00
Supra4E8C
1b1ab1fb9b Merge pull request #1606 from router-for-me/add-qwen-3.5
feat(registry): add Qwen 3.5 Plus model definitions
2026-02-16 23:10:53 +08:00
Supra4E8C
a9d0bb72da feat(registry): add Qwen 3.5 Plus model definitions 2026-02-16 22:55:37 +08:00
lhpqaq
2c8821891c fix(tui): update with review 2026-02-16 00:24:25 +08:00
haopeng
0a2555b0f3 Update internal/tui/auth_tab.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-16 00:11:31 +08:00
lhpqaq
020df41efe chore(tui): update readme, fix usage 2026-02-16 00:04:04 +08:00
lhpqaq
f31f7f701a feat(tui): add i18n 2026-02-15 15:42:59 +08:00
Supra4E8C
b5fe78eb70 Merge pull request #1597 from router-for-me/kimi-fix
feat(registry): add support for 'kimi' channel in model definitions
2026-02-15 15:35:17 +08:00
Supra4E8C
d1f667cf8d feat(registry): add support for 'kimi' channel in model definitions 2026-02-15 15:21:33 +08:00
lhpqaq
54ad7c1b6b feat(tui): add manager tui 2026-02-15 14:52:40 +08:00
51 changed files with 8406 additions and 119 deletions

View File

@@ -8,6 +8,7 @@ import (
"errors"
"flag"
"fmt"
"io"
"io/fs"
"net/url"
"os"
@@ -25,6 +26,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -68,6 +70,8 @@ func main() {
var vertexImport string
var configPath string
var password string
var tuiMode bool
var standalone bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
@@ -84,6 +88,8 @@ func main() {
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output()
@@ -479,8 +485,83 @@ func main() {
cmd.WaitForCloudDeploy()
return
}
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
cmd.StartService(cfg, configFilePath, password)
if tuiMode {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
origStdout := os.Stdout
origStderr := os.Stderr
origLogOutput := log.StandardLogger().Out
log.SetOutput(io.Discard)
devNull, errOpenDevNull := os.Open(os.DevNull)
if errOpenDevNull == nil {
os.Stdout = devNull
os.Stderr = devNull
}
restoreIO := func() {
os.Stdout = origStdout
os.Stderr = origStderr
log.SetOutput(origLogOutput)
if devNull != nil {
_ = devNull.Close()
}
}
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
if password == "" {
password = localMgmtPassword
}
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
client := tui.NewClient(cfg.Port, password)
ready := false
backoff := 100 * time.Millisecond
for i := 0; i < 30; i++ {
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
ready = true
break
}
time.Sleep(backoff)
if backoff < time.Second {
backoff = time.Duration(float64(backoff) * 1.5)
}
}
if !ready {
restoreIO()
cancel()
<-done
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
return
}
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
restoreIO()
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
} else {
restoreIO()
}
cancel()
<-done
} else {
// Default TUI mode: pure management client.
// The proxy server must already be running.
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
cmd.StartService(cfg, configFilePath, password)
}
}
}

View File

@@ -156,6 +156,14 @@ nonstream-keepalive-interval: 0
# - "API"
# - "proxy"
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# These are used as fallbacks when the client does not send its own headers.
# claude-header-defaults:
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
# package-version: "0.74.0"
# runtime-version: "v24.3.0"
# timeout: "600"
# OpenAI compatibility providers
# openai-compatibility:
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.

21
go.mod
View File

@@ -4,6 +4,10 @@ go 1.26.0
require (
github.com/andybalholm/brotli v1.0.6
github.com/atotto/clipboard v0.1.4
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.1
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
@@ -31,8 +35,16 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
@@ -40,6 +52,7 @@ require (
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-git/gcfg/v2 v2.0.2 // indirect
@@ -56,19 +69,27 @@ require (
github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sergi/go-diff v1.4.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect

45
go.sum
View File

@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -99,8 +125,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
@@ -112,6 +144,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
@@ -120,6 +158,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
@@ -159,17 +199,22 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=

View File

@@ -808,6 +808,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
var req struct {
Name string `json:"name"`
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Priority *int `json:"priority"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
name := strings.TrimSpace(req.Name)
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
ctx := c.Request.Context()
// Find auth by name or ID
var targetAuth *coreauth.Auth
if auth, ok := h.authManager.GetByID(name); ok {
targetAuth = auth
} else {
auths := h.authManager.List()
for _, auth := range auths {
if auth.FileName == name {
targetAuth = auth
break
}
}
}
if targetAuth == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
return
}
changed := false
if req.Prefix != nil {
targetAuth.Prefix = *req.Prefix
changed = true
}
if req.ProxyURL != nil {
targetAuth.ProxyURL = *req.ProxyURL
changed = true
}
if req.Priority != nil {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
}
changed = true
}
if !changed {
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
return
}
targetAuth.UpdatedAt = time.Now()
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h == nil || h.authManager == nil {
return

View File

@@ -15,10 +15,12 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
// It captures detailed information about the request and response, including headers and body,
// and uses the provided RequestLogger to record this data. When logging is disabled in the
// logger, it still captures data so that upstream errors can be persisted.
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return func(c *gin.Context) {
if logger == nil {
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
if c.Request.Method == http.MethodGet {
if shouldSkipMethodForRequestLogging(c.Request) {
c.Next()
return
}
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
loggerEnabled := logger.IsEnabled()
// Capture request information
requestInfo, err := captureRequestInfo(c)
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
if err != nil {
// Log error but continue processing
// In a real implementation, you might want to use a proper logger here
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
// Create response writer wrapper
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
if !logger.IsEnabled() {
if !loggerEnabled {
wrapper.logOnErrorOnly = true
}
c.Writer = wrapper
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
}
}
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
if req == nil {
return true
}
if req.Method != http.MethodGet {
return false
}
return !isResponsesWebsocketUpgrade(req)
}
func isResponsesWebsocketUpgrade(req *http.Request) bool {
if req == nil || req.URL == nil {
return false
}
if req.URL.Path != "/v1/responses" {
return false
}
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
}
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
if loggerEnabled {
return true
}
if req == nil || req.Body == nil {
return false
}
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
if strings.HasPrefix(contentType, "multipart/form-data") {
return false
}
if req.ContentLength <= 0 {
return false
}
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
}
// captureRequestInfo extracts relevant information from the incoming HTTP request.
// It captures the URL, method, headers, and body. The request body is read and then
// restored so that it can be processed by subsequent handlers.
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
// Capture URL with sensitive query parameters masked
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
url := c.Request.URL.Path
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
// Capture request body
var body []byte
if c.Request.Body != nil {
if captureBody && c.Request.Body != nil {
// Read the body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {

View File

@@ -0,0 +1,138 @@
package middleware
import (
"io"
"net/http"
"net/url"
"strings"
"testing"
)
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
tests := []struct {
name string
req *http.Request
skip bool
}{
{
name: "nil request",
req: nil,
skip: true,
},
{
name: "post request should not skip",
req: &http.Request{
Method: http.MethodPost,
URL: &url.URL{Path: "/v1/responses"},
},
skip: false,
},
{
name: "plain get should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/models"},
Header: http.Header{},
},
skip: true,
},
{
name: "responses websocket upgrade should not skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{"Upgrade": []string{"websocket"}},
},
skip: false,
},
{
name: "responses get without upgrade should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{},
},
skip: true,
},
}
for i := range tests {
got := shouldSkipMethodForRequestLogging(tests[i].req)
if got != tests[i].skip {
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
}
}
}
func TestShouldCaptureRequestBody(t *testing.T) {
tests := []struct {
name string
loggerEnabled bool
req *http.Request
want bool
}{
{
name: "logger enabled always captures",
loggerEnabled: true,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("{}")),
ContentLength: -1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: true,
},
{
name: "nil request",
loggerEnabled: false,
req: nil,
want: false,
},
{
name: "small known size json in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("{}")),
ContentLength: 2,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: true,
},
{
name: "large known size skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: false,
},
{
name: "unknown size skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: -1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: false,
},
{
name: "multipart skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: 1,
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
},
want: false,
},
}
for i := range tests {
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
if got != tests[i].want {
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
}
}
}

View File

@@ -14,6 +14,8 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
// Only fall back to request payload hints when Content-Type is not set yet.
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
bodyStr := string(w.requestInfo.Body)
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
}
return false
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
return time.Time{}
}
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
}
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
var requestBody []byte
if len(w.requestInfo.Body) > 0 {
requestBody = w.requestInfo.Body
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {

View File

@@ -0,0 +1,43 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{
requestInfo: &RequestInfo{Body: []byte("original-body")},
}
body := wrapper.extractRequestBody(c)
if string(body) != "original-body" {
t.Fatalf("request body = %q, want %q", string(body), "original-body")
}
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
body = wrapper.extractRequestBody(c)
if string(body) != "override-body" {
t.Fatalf("request body = %q, want %q", string(body), "override-body")
}
}
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c)
if string(body) != "override-as-string" {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
}
}

View File

@@ -284,8 +284,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
optionState.routerConfigurator(engine, s.handlers, cfg)
}
// Register management routes when configuration or environment secrets are available.
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
// Register management routes when configuration or environment secrets are available,
// or when a local management password is provided (e.g. TUI mode).
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
s.managementRoutesEnabled.Store(hasManagementSecret)
if hasManagementSecret {
s.registerManagementRoutes()
@@ -323,6 +324,7 @@ func (s *Server) setupRoutes() {
v1.POST("/completions", openaiHandlers.Completions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
v1.POST("/responses", openaiResponsesHandlers.Responses)
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
}
@@ -616,6 +618,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)

View File

@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
}
}
// StartServiceBackground starts the proxy service in a background goroutine
// and returns a cancel function for shutdown and a done channel.
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
builder := cliproxy.NewBuilder().
WithConfig(cfg).
WithConfigPath(configPath).
WithLocalManagementPassword(localPassword)
ctx, cancelFn := context.WithCancel(context.Background())
doneCh := make(chan struct{})
service, err := builder.Build()
if err != nil {
log.Errorf("failed to build proxy service: %v", err)
close(doneCh)
return cancelFn, doneCh
}
go func() {
defer close(doneCh)
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("proxy service exited with error: %v", err)
}
}()
return cancelFn, doneCh
}
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
// when no configuration file is available.
func WaitForCloudDeploy() {

View File

@@ -90,6 +90,10 @@ type Config struct {
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
// ClaudeHeaderDefaults configures default header values for Claude API requests.
// These are used as fallbacks when the client does not send its own headers.
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
@@ -117,6 +121,15 @@ type Config struct {
legacyMigrationPending bool `yaml:"-" json:"-"`
}
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
// when the client does not send them. Update these when Claude Code releases a new version.
type ClaudeHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
PackageVersion string `yaml:"package-version" json:"package-version"`
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
Timeout string `yaml:"timeout" json:"timeout"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
@@ -355,6 +368,9 @@ type CodexKey struct {
// If empty, the default Codex API URL will be used.
BaseURL string `yaml:"base-url" json:"base-url"`
// Websockets enables the Responses API websocket transport for this credential.
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`

View File

@@ -19,6 +19,7 @@ import (
// - codex
// - qwen
// - iflow
// - kimi
// - antigravity (returns static overrides only)
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
key := strings.ToLower(strings.TrimSpace(channel))
@@ -39,6 +40,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetQwenModels()
case "iflow":
return GetIFlowModels()
case "kimi":
return GetKimiModels()
case "antigravity":
cfg := GetAntigravityModelConfig()
if len(cfg) == 0 {
@@ -83,6 +86,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
}
for _, models := range allModels {
for _, m := range models {

View File

@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771372800, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-6",
Object: "model",
@@ -788,6 +799,19 @@ func GetQwenModels() []*ModelInfo {
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "coder-model",
Object: "model",
Created: 1771171200,
OwnedBy: "qwen",
Type: "qwen",
Version: "3.5",
DisplayName: "Qwen 3.5 Plus",
Description: "efficient hybrid model with leading coding performance",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "vision-model",
Object: "model",
@@ -884,6 +908,8 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"runtime"
"strings"
"time"
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if err != nil {
return resp, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.publish(ctx, parseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
var param any
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if err != nil {
return nil, err
}
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -348,7 +349,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
// Forward the line as-is to preserve SSE format
@@ -375,7 +376,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
chunks := sdktranslator.TranslateStream(
@@ -423,7 +424,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -432,7 +433,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
if err != nil {
return cliproxyexecutor.Response{}, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -638,7 +639,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
return body, nil
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
switch runtime.GOOS {
case "darwin":
return "MacOS"
case "windows":
return "Windows"
case "linux":
return "Linux"
case "freebsd":
return "FreeBSD"
default:
return "Other::" + runtime.GOOS
}
}
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "arm64":
return "arm64"
case "386":
return "x86"
default:
return "other::" + runtime.GOARCH
}
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
hdrDefault := func(cfgVal, fallback string) string {
if cfgVal != "" {
return cfgVal
}
return fallback
}
var hd config.ClaudeHeaderDefaults
if cfg != nil {
hd = cfg.ClaudeHeaderDefaults
}
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
@@ -685,16 +728,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
r.Header.Set("Connection", "keep-alive")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
if stream {
@@ -702,6 +746,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
} else {
r.Header.Set("Accept", "application/json")
}
// Keep OS/Arch mapping dynamic (not configurable).
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
@@ -753,11 +799,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return body
}
// Collect built-in tool names (those with a non-empty "type" field) so we can
// skip them consistently in both tools and message history.
builtinTools := map[string]bool{}
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
builtinTools[name] = true
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
// Skip built-in tools (web_search, code_execution, etc.) which have
// a "type" field and require their name to remain unchanged.
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
if n := tool.Get("name").String(); n != "" {
builtinTools[n] = true
}
return true
}
name := tool.Get("name").String()
@@ -772,7 +828,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
name := gjson.GetBytes(body, "tool_choice.name").String()
if name != "" && !strings.HasPrefix(name, prefix) {
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
}
}
@@ -784,15 +840,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return true
}
content.ForEach(func(contentIndex, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
case "tool_reference":
toolName := part.Get("tool_name").String()
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+toolName)
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
}
}
return true
})
}
}
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
return true
@@ -811,15 +890,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
return body
}
content.ForEach(func(index, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
case "tool_reference":
toolName := part.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return true
}
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if strings.HasPrefix(nestedToolName, prefix) {
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
}
}
return true
})
}
}
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
return true
})
return body
@@ -834,15 +936,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
return line
}
contentBlock := gjson.GetBytes(payload, "content_block")
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
if !contentBlock.Exists() {
return line
}
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
blockType := contentBlock.Get("type").String()
var updated []byte
var err error
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
return line
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
if err != nil {
return line
}
default:
return line
}

View File

@@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
}
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
@@ -37,6 +49,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
}
}
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
body := []byte(`{
"tools": [
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
}
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
body := []byte(`{
"tools": [{"name": "Read"}, {"name": "Write"}],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
}
}
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"name": "Read"}
],
"tool_choice": {"type": "tool", "name": "web_search"}
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -49,6 +152,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
}
}
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
}
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
}
}
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
@@ -61,3 +176,53 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
}
}
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
payload := bytes.TrimSpace(out)
if bytes.HasPrefix(payload, []byte("data:")) {
payload = bytes.TrimSpace(payload[len("data:"):])
}
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
}
}
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "proxy_mcp__nia__manage_resource" {
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
}
}
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
if got != "mcp__nia__manage_resource" {
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
}
}
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
// tool_result.content can be a string - should not be processed
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
if got != "plain string result" {
t.Fatalf("string content should remain unchanged = %q", got)
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "web_search" {
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
}
}

View File

@@ -643,7 +643,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)

File diff suppressed because it is too large Load Diff

View File

@@ -22,9 +22,7 @@ import (
)
const (
qwenUserAgent = "google-api-nodejs-client/9.15.1"
qwenXGoogAPIClient = "gl-node/22.17.0"
qwenClientMetadataValue = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
)
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
@@ -344,8 +342,18 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("User-Agent", qwenUserAgent)
r.Header.Set("X-Goog-Api-Client", qwenXGoogAPIClient)
r.Header.Set("Client-Metadata", qwenClientMetadataValue)
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
r.Header.Set("Sec-Fetch-Mode", "cors")
r.Header.Set("X-Stainless-Lang", "js")
r.Header.Set("X-Stainless-Arch", "arm64")
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
r.Header.Set("X-Stainless-Retry-Count", "0")
r.Header.Set("X-Stainless-Os", "MacOS")
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
r.Header.Set("X-Stainless-Runtime", "node")
if stream {
r.Header.Set("Accept", "text/event-stream")
return

View File

@@ -10,10 +10,53 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// validReasoningEffortLevels contains the standard values accepted by the
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
// auto) are NOT in this set and must be clamped before use.
var validReasoningEffortLevels = map[string]struct{}{
"none": {},
"low": {},
"medium": {},
"high": {},
}
// clampReasoningEffort maps any thinking level string to a value that is safe
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
// mapped to the nearest standard equivalent.
//
// Mapping rules:
// - none / low / medium / high → returned as-is (already valid)
// - xhigh → "high" (nearest lower standard level)
// - minimal → "low" (nearest higher standard level)
// - auto → "medium" (reasonable default)
// - anything else → "medium" (safe default)
func clampReasoningEffort(level string) string {
if _, ok := validReasoningEffortLevels[level]; ok {
return level
}
var clamped string
switch level {
case string(thinking.LevelXHigh):
clamped = string(thinking.LevelHigh)
case string(thinking.LevelMinimal):
clamped = string(thinking.LevelLow)
case string(thinking.LevelAuto):
clamped = string(thinking.LevelMedium)
default:
clamped = string(thinking.LevelMedium)
}
log.WithFields(log.Fields{
"original": level,
"clamped": clamped,
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
return clamped
}
// Applier implements thinking.ProviderApplier for OpenAI models.
//
// OpenAI-specific behavior:
@@ -58,7 +101,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
}
if config.Mode == thinking.ModeLevel {
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
return result, nil
}
@@ -79,7 +122,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
return result, nil
}
@@ -114,7 +157,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
return result, nil
}

View File

@@ -20,10 +20,12 @@ var (
// ConvertCliToOpenAIParams holds parameters for response conversion.
type ConvertCliToOpenAIParams struct {
ResponseID string
CreatedAt int64
Model string
FunctionCallIndex int
ResponseID string
CreatedAt int64
Model string
FunctionCallIndex int
HasReceivedArgumentsDelta bool
HasToolCallAnnounced bool
}
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &ConvertCliToOpenAIParams{
Model: modelName,
CreatedAt: 0,
ResponseID: "",
FunctionCallIndex: -1,
Model: modelName,
CreatedAt: 0,
ResponseID: "",
FunctionCallIndex: -1,
HasReceivedArgumentsDelta: false,
HasToolCallAnnounced: false,
}
}
@@ -118,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
}
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
} else if dataType == "response.output_item.done" {
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
} else if dataType == "response.output_item.added" {
itemResult := rootResult.Get("item")
if itemResult.Exists() {
if itemResult.Get("type").String() != "function_call" {
return []string{}
}
// set the index
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened
name := itemResult.Get("name").String()
// Build reverse map on demand from original request tools
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
}
// Increment index for this new function call item.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.delta" {
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
deltaValue := rootResult.Get("delta").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.done" {
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
// Arguments were already streamed via delta events; nothing to emit.
return []string{}
}
// Fallback: no delta events were received, emit the full arguments as a single chunk.
fullArgs := rootResult.Get("arguments").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.output_item.done" {
itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
}
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
// Tool call was already announced via output_item.added; skip emission.
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
return []string{}
}
// Fallback path: model skipped output_item.added, so emit complete tool call now.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else {
return []string{}
}

542
internal/tui/app.go Normal file
View File

@@ -0,0 +1,542 @@
package tui
import (
"fmt"
"io"
"os"
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// Tab identifiers
const (
tabDashboard = iota
tabConfig
tabAuthFiles
tabAPIKeys
tabOAuth
tabUsage
tabLogs
)
// App is the root bubbletea model that contains all tab sub-models.
type App struct {
activeTab int
tabs []string
standalone bool
logsEnabled bool
authenticated bool
authInput textinput.Model
authError string
authConnecting bool
dashboard dashboardModel
config configTabModel
auth authTabModel
keys keysTabModel
oauth oauthTabModel
usage usageTabModel
logs logsTabModel
client *Client
width int
height int
ready bool
// Track which tabs have been initialized (fetched data)
initialized [7]bool
}
type authConnectMsg struct {
cfg map[string]any
err error
}
// NewApp creates the root TUI application model.
func NewApp(port int, secretKey string, hook *LogHook) App {
standalone := hook != nil
authRequired := !standalone
ti := textinput.New()
ti.CharLimit = 512
ti.EchoMode = textinput.EchoPassword
ti.EchoCharacter = '*'
ti.SetValue(strings.TrimSpace(secretKey))
ti.Focus()
client := NewClient(port, secretKey)
app := App{
activeTab: tabDashboard,
standalone: standalone,
logsEnabled: true,
authenticated: !authRequired,
authInput: ti,
dashboard: newDashboardModel(client),
config: newConfigTabModel(client),
auth: newAuthTabModel(client),
keys: newKeysTabModel(client),
oauth: newOAuthTabModel(client),
usage: newUsageTabModel(client),
logs: newLogsTabModel(client, hook),
client: client,
initialized: [7]bool{
tabDashboard: true,
tabLogs: true,
},
}
app.refreshTabs()
if authRequired {
app.initialized = [7]bool{}
}
app.setAuthInputPrompt()
return app
}
func (a App) Init() tea.Cmd {
if !a.authenticated {
return textinput.Blink
}
cmds := []tea.Cmd{a.dashboard.Init()}
if a.logsEnabled {
cmds = append(cmds, a.logs.Init())
}
return tea.Batch(cmds...)
}
func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
a.width = msg.Width
a.height = msg.Height
a.ready = true
if a.width > 0 {
a.authInput.Width = a.width - 6
}
contentH := a.height - 4 // tab bar + status bar
if contentH < 1 {
contentH = 1
}
contentW := a.width
a.dashboard.SetSize(contentW, contentH)
a.config.SetSize(contentW, contentH)
a.auth.SetSize(contentW, contentH)
a.keys.SetSize(contentW, contentH)
a.oauth.SetSize(contentW, contentH)
a.usage.SetSize(contentW, contentH)
a.logs.SetSize(contentW, contentH)
return a, nil
case authConnectMsg:
a.authConnecting = false
if msg.err != nil {
a.authError = fmt.Sprintf(T("auth_gate_connect_fail"), msg.err.Error())
return a, nil
}
a.authError = ""
a.authenticated = true
a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
a.refreshTabs()
a.initialized = [7]bool{}
a.initialized[tabDashboard] = true
cmds := []tea.Cmd{a.dashboard.Init()}
if a.logsEnabled {
a.initialized[tabLogs] = true
cmds = append(cmds, a.logs.Init())
}
return a, tea.Batch(cmds...)
case configUpdateMsg:
var cmdLogs tea.Cmd
if !a.standalone && msg.err == nil && msg.path == "logging-to-file" {
logsEnabledConfig, okConfig := msg.value.(bool)
if okConfig {
logsEnabledBefore := a.logsEnabled
a.logsEnabled = logsEnabledConfig
if logsEnabledBefore != a.logsEnabled {
a.refreshTabs()
}
if !a.logsEnabled {
a.initialized[tabLogs] = false
}
if !logsEnabledBefore && a.logsEnabled {
a.initialized[tabLogs] = true
cmdLogs = a.logs.Init()
}
}
}
var cmdConfig tea.Cmd
a.config, cmdConfig = a.config.Update(msg)
if cmdConfig != nil && cmdLogs != nil {
return a, tea.Batch(cmdConfig, cmdLogs)
}
if cmdConfig != nil {
return a, cmdConfig
}
return a, cmdLogs
case tea.KeyMsg:
if !a.authenticated {
switch msg.String() {
case "ctrl+c", "q":
return a, tea.Quit
case "L":
ToggleLocale()
a.refreshTabs()
a.setAuthInputPrompt()
return a, nil
case "enter":
if a.authConnecting {
return a, nil
}
password := strings.TrimSpace(a.authInput.Value())
if password == "" {
a.authError = T("auth_gate_password_required")
return a, nil
}
a.authError = ""
a.authConnecting = true
return a, a.connectWithPassword(password)
default:
var cmd tea.Cmd
a.authInput, cmd = a.authInput.Update(msg)
return a, cmd
}
}
switch msg.String() {
case "ctrl+c":
return a, tea.Quit
case "q":
// Only quit if not in logs tab (where 'q' might be useful)
if !a.logsEnabled || a.activeTab != tabLogs {
return a, tea.Quit
}
case "L":
ToggleLocale()
a.refreshTabs()
return a.broadcastToAllTabs(localeChangedMsg{})
case "tab":
if len(a.tabs) == 0 {
return a, nil
}
prevTab := a.activeTab
a.activeTab = (a.activeTab + 1) % len(a.tabs)
return a, a.initTabIfNeeded(prevTab)
case "shift+tab":
if len(a.tabs) == 0 {
return a, nil
}
prevTab := a.activeTab
a.activeTab = (a.activeTab - 1 + len(a.tabs)) % len(a.tabs)
return a, a.initTabIfNeeded(prevTab)
}
}
if !a.authenticated {
var cmd tea.Cmd
a.authInput, cmd = a.authInput.Update(msg)
return a, cmd
}
// Route msg to active tab
var cmd tea.Cmd
switch a.activeTab {
case tabDashboard:
a.dashboard, cmd = a.dashboard.Update(msg)
case tabConfig:
a.config, cmd = a.config.Update(msg)
case tabAuthFiles:
a.auth, cmd = a.auth.Update(msg)
case tabAPIKeys:
a.keys, cmd = a.keys.Update(msg)
case tabOAuth:
a.oauth, cmd = a.oauth.Update(msg)
case tabUsage:
a.usage, cmd = a.usage.Update(msg)
case tabLogs:
a.logs, cmd = a.logs.Update(msg)
}
// Keep logs polling alive even when logs tab is not active.
if a.logsEnabled && a.activeTab != tabLogs {
switch msg.(type) {
case logsPollMsg, logsTickMsg, logLineMsg:
var logCmd tea.Cmd
a.logs, logCmd = a.logs.Update(msg)
if logCmd != nil {
cmd = logCmd
}
}
}
return a, cmd
}
// localeChangedMsg is broadcast to all tabs when the user toggles locale.
type localeChangedMsg struct{}
func (a *App) refreshTabs() {
names := TabNames()
if a.logsEnabled {
a.tabs = names
} else {
filtered := make([]string, 0, len(names)-1)
for idx, name := range names {
if idx == tabLogs {
continue
}
filtered = append(filtered, name)
}
a.tabs = filtered
}
if len(a.tabs) == 0 {
a.activeTab = tabDashboard
return
}
if a.activeTab >= len(a.tabs) {
a.activeTab = len(a.tabs) - 1
}
}
func (a *App) initTabIfNeeded(_ int) tea.Cmd {
if a.initialized[a.activeTab] {
return nil
}
a.initialized[a.activeTab] = true
switch a.activeTab {
case tabDashboard:
return a.dashboard.Init()
case tabConfig:
return a.config.Init()
case tabAuthFiles:
return a.auth.Init()
case tabAPIKeys:
return a.keys.Init()
case tabOAuth:
return a.oauth.Init()
case tabUsage:
return a.usage.Init()
case tabLogs:
if !a.logsEnabled {
return nil
}
return a.logs.Init()
}
return nil
}
func (a App) View() string {
if !a.authenticated {
return a.renderAuthView()
}
if !a.ready {
return T("initializing_tui")
}
var sb strings.Builder
// Tab bar
sb.WriteString(a.renderTabBar())
sb.WriteString("\n")
// Content
switch a.activeTab {
case tabDashboard:
sb.WriteString(a.dashboard.View())
case tabConfig:
sb.WriteString(a.config.View())
case tabAuthFiles:
sb.WriteString(a.auth.View())
case tabAPIKeys:
sb.WriteString(a.keys.View())
case tabOAuth:
sb.WriteString(a.oauth.View())
case tabUsage:
sb.WriteString(a.usage.View())
case tabLogs:
if a.logsEnabled {
sb.WriteString(a.logs.View())
}
}
// Status bar
sb.WriteString("\n")
sb.WriteString(a.renderStatusBar())
return sb.String()
}
func (a App) renderAuthView() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("auth_gate_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("auth_gate_help")))
sb.WriteString("\n\n")
if a.authConnecting {
sb.WriteString(warningStyle.Render(T("auth_gate_connecting")))
sb.WriteString("\n\n")
}
if strings.TrimSpace(a.authError) != "" {
sb.WriteString(errorStyle.Render(a.authError))
sb.WriteString("\n\n")
}
sb.WriteString(a.authInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("auth_gate_enter")))
return sb.String()
}
func (a App) renderTabBar() string {
var tabs []string
for i, name := range a.tabs {
if i == a.activeTab {
tabs = append(tabs, tabActiveStyle.Render(name))
} else {
tabs = append(tabs, tabInactiveStyle.Render(name))
}
}
tabBar := lipgloss.JoinHorizontal(lipgloss.Top, tabs...)
return tabBarStyle.Width(a.width).Render(tabBar)
}
func (a App) renderStatusBar() string {
left := strings.TrimRight(T("status_left"), " ")
right := strings.TrimRight(T("status_right"), " ")
width := a.width
if width < 1 {
width = 1
}
// statusBarStyle has left/right padding(1), so content area is width-2.
contentWidth := width - 2
if contentWidth < 0 {
contentWidth = 0
}
if lipgloss.Width(left) > contentWidth {
left = fitStringWidth(left, contentWidth)
right = ""
}
remaining := contentWidth - lipgloss.Width(left)
if remaining < 0 {
remaining = 0
}
if lipgloss.Width(right) > remaining {
right = fitStringWidth(right, remaining)
}
gap := contentWidth - lipgloss.Width(left) - lipgloss.Width(right)
if gap < 0 {
gap = 0
}
return statusBarStyle.Width(width).Render(left + strings.Repeat(" ", gap) + right)
}
func fitStringWidth(text string, maxWidth int) string {
if maxWidth <= 0 {
return ""
}
if lipgloss.Width(text) <= maxWidth {
return text
}
out := ""
for _, r := range text {
next := out + string(r)
if lipgloss.Width(next) > maxWidth {
break
}
out = next
}
return out
}
func isLogsEnabledFromConfig(cfg map[string]any) bool {
if cfg == nil {
return true
}
value, ok := cfg["logging-to-file"]
if !ok {
return true
}
enabled, ok := value.(bool)
if !ok {
return true
}
return enabled
}
func (a *App) setAuthInputPrompt() {
if a == nil {
return
}
a.authInput.Prompt = fmt.Sprintf(" %s: ", T("auth_gate_password"))
}
func (a App) connectWithPassword(password string) tea.Cmd {
return func() tea.Msg {
a.client.SetSecretKey(password)
cfg, errGetConfig := a.client.GetConfig()
return authConnectMsg{cfg: cfg, err: errGetConfig}
}
}
// Run starts the TUI application.
// output specifies where bubbletea renders. If nil, defaults to os.Stdout.
func Run(port int, secretKey string, hook *LogHook, output io.Writer) error {
if output == nil {
output = os.Stdout
}
app := NewApp(port, secretKey, hook)
p := tea.NewProgram(app, tea.WithAltScreen(), tea.WithOutput(output))
_, err := p.Run()
return err
}
func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
var cmd tea.Cmd
a.dashboard, cmd = a.dashboard.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.config, cmd = a.config.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.auth, cmd = a.auth.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.keys, cmd = a.keys.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.oauth, cmd = a.oauth.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.usage, cmd = a.usage.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.logs, cmd = a.logs.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
return a, tea.Batch(cmds...)
}

456
internal/tui/auth_tab.go Normal file
View File

@@ -0,0 +1,456 @@
package tui
import (
"fmt"
"strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// editableField represents an editable field on an auth file.
type editableField struct {
label string
key string // API field key: "prefix", "proxy_url", "priority"
}
var authEditableFields = []editableField{
{label: "Prefix", key: "prefix"},
{label: "Proxy URL", key: "proxy_url"},
{label: "Priority", key: "priority"},
}
// authTabModel displays auth credential files with interactive management.
type authTabModel struct {
client *Client
viewport viewport.Model
files []map[string]any
err error
width int
height int
ready bool
cursor int
expanded int // -1 = none expanded, >=0 = expanded index
confirm int // -1 = no confirmation, >=0 = confirm delete for index
status string
// Editing state
editing bool // true when editing a field
editField int // index into authEditableFields
editInput textinput.Model // text input for editing
editFileName string // name of file being edited
}
type authFilesMsg struct {
files []map[string]any
err error
}
type authActionMsg struct {
action string // "deleted", "toggled", "updated"
err error
}
func newAuthTabModel(client *Client) authTabModel {
ti := textinput.New()
ti.CharLimit = 256
return authTabModel{
client: client,
expanded: -1,
confirm: -1,
editInput: ti,
}
}
func (m authTabModel) Init() tea.Cmd {
return m.fetchFiles
}
func (m authTabModel) fetchFiles() tea.Msg {
files, err := m.client.GetAuthFiles()
return authFilesMsg{files: files, err: err}
}
func (m authTabModel) Update(msg tea.Msg) (authTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case authFilesMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.files = msg.files
if m.cursor >= len(m.files) {
m.cursor = max(0, len(m.files)-1)
}
m.status = ""
}
m.viewport.SetContent(m.renderContent())
return m, nil
case authActionMsg:
if msg.err != nil {
m.status = errorStyle.Render("✗ " + msg.err.Error())
} else {
m.status = successStyle.Render("✓ " + msg.action)
}
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, m.fetchFiles
case tea.KeyMsg:
// ---- Editing mode ----
if m.editing {
return m.handleEditInput(msg)
}
// ---- Delete confirmation mode ----
if m.confirm >= 0 {
return m.handleConfirmInput(msg)
}
// ---- Normal mode ----
return m.handleNormalInput(msg)
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
// startEdit activates inline editing for a field on the currently selected auth file.
func (m *authTabModel) startEdit(fieldIdx int) tea.Cmd {
if m.cursor >= len(m.files) {
return nil
}
f := m.files[m.cursor]
m.editFileName = getString(f, "name")
m.editField = fieldIdx
m.editing = true
// Pre-populate with current value
key := authEditableFields[fieldIdx].key
currentVal := getAnyString(f, key)
m.editInput.SetValue(currentVal)
m.editInput.Focus()
m.editInput.Prompt = fmt.Sprintf(" %s: ", authEditableFields[fieldIdx].label)
m.viewport.SetContent(m.renderContent())
return textinput.Blink
}
func (m *authTabModel) SetSize(w, h int) {
m.width = w
m.height = h
m.editInput.Width = w - 20
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m authTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m authTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("auth_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("auth_help1")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("auth_help2")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", m.width))
sb.WriteString("\n")
if m.err != nil {
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
sb.WriteString("\n")
return sb.String()
}
if len(m.files) == 0 {
sb.WriteString(subtitleStyle.Render(T("no_auth_files")))
sb.WriteString("\n")
return sb.String()
}
for i, f := range m.files {
name := getString(f, "name")
channel := getString(f, "channel")
email := getString(f, "email")
disabled := getBool(f, "disabled")
statusIcon := successStyle.Render("●")
statusText := T("status_active")
if disabled {
statusIcon = lipgloss.NewStyle().Foreground(colorMuted).Render("○")
statusText = T("status_disabled")
}
cursor := " "
rowStyle := lipgloss.NewStyle()
if i == m.cursor {
cursor = "▸ "
rowStyle = lipgloss.NewStyle().Bold(true)
}
displayName := name
if len(displayName) > 24 {
displayName = displayName[:21] + "..."
}
displayEmail := email
if len(displayEmail) > 28 {
displayEmail = displayEmail[:25] + "..."
}
row := fmt.Sprintf("%s%s %-24s %-12s %-28s %s",
cursor, statusIcon, displayName, channel, displayEmail, statusText)
sb.WriteString(rowStyle.Render(row))
sb.WriteString("\n")
// Delete confirmation
if m.confirm == i {
sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete"), name)))
sb.WriteString("\n")
}
// Inline edit input
if m.editing && i == m.cursor {
sb.WriteString(m.editInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(" " + T("enter_save") + " • " + T("esc_cancel")))
sb.WriteString("\n")
}
// Expanded detail view
if m.expanded == i {
sb.WriteString(m.renderDetail(f))
}
}
if m.status != "" {
sb.WriteString("\n")
sb.WriteString(m.status)
sb.WriteString("\n")
}
return sb.String()
}
func (m authTabModel) renderDetail(f map[string]any) string {
var sb strings.Builder
labelStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("111")).
Bold(true)
valueStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("252"))
editableMarker := lipgloss.NewStyle().
Foreground(lipgloss.Color("214")).
Render(" ✎")
sb.WriteString(" ┌─────────────────────────────────────────────\n")
fields := []struct {
label string
key string
editable bool
}{
{"Name", "name", false},
{"Channel", "channel", false},
{"Email", "email", false},
{"Status", "status", false},
{"Status Msg", "status_message", false},
{"File Name", "file_name", false},
{"Auth Type", "auth_type", false},
{"Prefix", "prefix", true},
{"Proxy URL", "proxy_url", true},
{"Priority", "priority", true},
{"Project ID", "project_id", false},
{"Disabled", "disabled", false},
{"Created", "created_at", false},
{"Updated", "updated_at", false},
}
for _, field := range fields {
val := getAnyString(f, field.key)
if val == "" || val == "<nil>" {
if field.editable {
val = T("not_set")
} else {
continue
}
}
editMark := ""
if field.editable {
editMark = editableMarker
}
line := fmt.Sprintf(" │ %s %s%s",
labelStyle.Render(fmt.Sprintf("%-12s:", field.label)),
valueStyle.Render(val),
editMark)
sb.WriteString(line)
sb.WriteString("\n")
}
sb.WriteString(" └─────────────────────────────────────────────\n")
return sb.String()
}
// getAnyString converts any value to its string representation.
func getAnyString(m map[string]any, key string) string {
v, ok := m[key]
if !ok || v == nil {
return ""
}
return fmt.Sprintf("%v", v)
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func (m authTabModel) handleEditInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
switch msg.String() {
case "enter":
value := m.editInput.Value()
fieldKey := authEditableFields[m.editField].key
fileName := m.editFileName
m.editing = false
m.editInput.Blur()
fields := map[string]any{}
if fieldKey == "priority" {
p, err := strconv.Atoi(value)
if err != nil {
return m, func() tea.Msg {
return authActionMsg{err: fmt.Errorf("%s: %s", T("invalid_int"), value)}
}
}
fields[fieldKey] = p
} else {
fields[fieldKey] = value
}
return m, func() tea.Msg {
err := m.client.PatchAuthFileFields(fileName, fields)
if err != nil {
return authActionMsg{err: err}
}
return authActionMsg{action: fmt.Sprintf(T("updated_field"), fieldKey, fileName)}
}
case "esc":
m.editing = false
m.editInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.editInput, cmd = m.editInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
func (m authTabModel) handleConfirmInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
switch msg.String() {
case "y", "Y":
idx := m.confirm
m.confirm = -1
if idx < len(m.files) {
name := getString(m.files[idx], "name")
return m, func() tea.Msg {
err := m.client.DeleteAuthFile(name)
if err != nil {
return authActionMsg{err: err}
}
return authActionMsg{action: fmt.Sprintf(T("deleted"), name)}
}
}
m.viewport.SetContent(m.renderContent())
return m, nil
case "n", "N", "esc":
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, nil
}
return m, nil
}
func (m authTabModel) handleNormalInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
switch msg.String() {
case "j", "down":
if len(m.files) > 0 {
m.cursor = (m.cursor + 1) % len(m.files)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "k", "up":
if len(m.files) > 0 {
m.cursor = (m.cursor - 1 + len(m.files)) % len(m.files)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "enter", " ":
if m.expanded == m.cursor {
m.expanded = -1
} else {
m.expanded = m.cursor
}
m.viewport.SetContent(m.renderContent())
return m, nil
case "d", "D":
if m.cursor < len(m.files) {
m.confirm = m.cursor
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "e", "E":
if m.cursor < len(m.files) {
f := m.files[m.cursor]
name := getString(f, "name")
disabled := getBool(f, "disabled")
newDisabled := !disabled
return m, func() tea.Msg {
err := m.client.ToggleAuthFile(name, newDisabled)
if err != nil {
return authActionMsg{err: err}
}
action := T("enabled")
if newDisabled {
action = T("disabled")
}
return authActionMsg{action: fmt.Sprintf("%s %s", action, name)}
}
}
return m, nil
case "1":
return m, m.startEdit(0) // prefix
case "2":
return m, m.startEdit(1) // proxy_url
case "3":
return m, m.startEdit(2) // priority
case "r":
m.status = ""
return m, m.fetchFiles
default:
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
}

20
internal/tui/browser.go Normal file
View File

@@ -0,0 +1,20 @@
package tui
import (
"os/exec"
"runtime"
)
// openBrowser opens the specified URL in the user's default browser.
func openBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
default:
return exec.Command("xdg-open", url).Start()
}
}

400
internal/tui/client.go Normal file
View File

@@ -0,0 +1,400 @@
package tui
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// Client wraps HTTP calls to the management API.
type Client struct {
baseURL string
secretKey string
http *http.Client
}
// NewClient creates a new management API client.
func NewClient(port int, secretKey string) *Client {
return &Client{
baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
secretKey: strings.TrimSpace(secretKey),
http: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// SetSecretKey updates management API bearer token used by this client.
func (c *Client) SetSecretKey(secretKey string) {
c.secretKey = strings.TrimSpace(secretKey)
}
func (c *Client) doRequest(method, path string, body io.Reader) ([]byte, int, error) {
url := c.baseURL + path
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, 0, err
}
if c.secretKey != "" {
req.Header.Set("Authorization", "Bearer "+c.secretKey)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.http.Do(req)
if err != nil {
return nil, 0, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, resp.StatusCode, err
}
return data, resp.StatusCode, nil
}
func (c *Client) get(path string) ([]byte, error) {
data, code, err := c.doRequest("GET", path, nil)
if err != nil {
return nil, err
}
if code >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
}
return data, nil
}
func (c *Client) put(path string, body io.Reader) ([]byte, error) {
data, code, err := c.doRequest("PUT", path, body)
if err != nil {
return nil, err
}
if code >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
}
return data, nil
}
func (c *Client) patch(path string, body io.Reader) ([]byte, error) {
data, code, err := c.doRequest("PATCH", path, body)
if err != nil {
return nil, err
}
if code >= 400 {
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
}
return data, nil
}
// getJSON fetches a path and unmarshals JSON into a generic map.
func (c *Client) getJSON(path string) (map[string]any, error) {
data, err := c.get(path)
if err != nil {
return nil, err
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return result, nil
}
// postJSON sends a JSON body via POST and checks for errors.
func (c *Client) postJSON(path string, body any) error {
jsonBody, err := json.Marshal(body)
if err != nil {
return err
}
_, code, err := c.doRequest("POST", path, strings.NewReader(string(jsonBody)))
if err != nil {
return err
}
if code >= 400 {
return fmt.Errorf("HTTP %d", code)
}
return nil
}
// GetConfig fetches the parsed config.
func (c *Client) GetConfig() (map[string]any, error) {
return c.getJSON("/v0/management/config")
}
// GetConfigYAML fetches the raw config.yaml content.
func (c *Client) GetConfigYAML() (string, error) {
data, err := c.get("/v0/management/config.yaml")
if err != nil {
return "", err
}
return string(data), nil
}
// PutConfigYAML uploads new config.yaml content.
func (c *Client) PutConfigYAML(yamlContent string) error {
_, err := c.put("/v0/management/config.yaml", strings.NewReader(yamlContent))
return err
}
// GetUsage fetches usage statistics.
func (c *Client) GetUsage() (map[string]any, error) {
return c.getJSON("/v0/management/usage")
}
// GetAuthFiles lists auth credential files.
// API returns {"files": [...]}.
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
wrapper, err := c.getJSON("/v0/management/auth-files")
if err != nil {
return nil, err
}
return extractList(wrapper, "files")
}
// DeleteAuthFile deletes a single auth file by name.
func (c *Client) DeleteAuthFile(name string) error {
query := url.Values{}
query.Set("name", name)
path := "/v0/management/auth-files?" + query.Encode()
_, code, err := c.doRequest("DELETE", path, nil)
if err != nil {
return err
}
if code >= 400 {
return fmt.Errorf("delete failed (HTTP %d)", code)
}
return nil
}
// ToggleAuthFile enables or disables an auth file.
func (c *Client) ToggleAuthFile(name string, disabled bool) error {
body, _ := json.Marshal(map[string]any{"name": name, "disabled": disabled})
_, err := c.patch("/v0/management/auth-files/status", strings.NewReader(string(body)))
return err
}
// PatchAuthFileFields updates editable fields on an auth file.
func (c *Client) PatchAuthFileFields(name string, fields map[string]any) error {
fields["name"] = name
body, _ := json.Marshal(fields)
_, err := c.patch("/v0/management/auth-files/fields", strings.NewReader(string(body)))
return err
}
// GetLogs fetches log lines from the server.
func (c *Client) GetLogs(after int64, limit int) ([]string, int64, error) {
query := url.Values{}
if limit > 0 {
query.Set("limit", strconv.Itoa(limit))
}
if after > 0 {
query.Set("after", strconv.FormatInt(after, 10))
}
path := "/v0/management/logs"
encodedQuery := query.Encode()
if encodedQuery != "" {
path += "?" + encodedQuery
}
wrapper, err := c.getJSON(path)
if err != nil {
return nil, after, err
}
lines := []string{}
if rawLines, ok := wrapper["lines"]; ok && rawLines != nil {
rawJSON, errMarshal := json.Marshal(rawLines)
if errMarshal != nil {
return nil, after, errMarshal
}
if errUnmarshal := json.Unmarshal(rawJSON, &lines); errUnmarshal != nil {
return nil, after, errUnmarshal
}
}
latest := after
if rawLatest, ok := wrapper["latest-timestamp"]; ok {
switch value := rawLatest.(type) {
case float64:
latest = int64(value)
case json.Number:
if parsed, errParse := value.Int64(); errParse == nil {
latest = parsed
}
case int64:
latest = value
case int:
latest = int64(value)
}
}
if latest < after {
latest = after
}
return lines, latest, nil
}
// GetAPIKeys fetches the list of API keys.
// API returns {"api-keys": [...]}.
func (c *Client) GetAPIKeys() ([]string, error) {
wrapper, err := c.getJSON("/v0/management/api-keys")
if err != nil {
return nil, err
}
arr, ok := wrapper["api-keys"]
if !ok {
return nil, nil
}
raw, err := json.Marshal(arr)
if err != nil {
return nil, err
}
var result []string
if err := json.Unmarshal(raw, &result); err != nil {
return nil, err
}
return result, nil
}
// AddAPIKey adds a new API key by sending old=nil, new=key which appends.
func (c *Client) AddAPIKey(key string) error {
body := map[string]any{"old": nil, "new": key}
jsonBody, _ := json.Marshal(body)
_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
return err
}
// EditAPIKey replaces an API key at the given index.
func (c *Client) EditAPIKey(index int, newValue string) error {
body := map[string]any{"index": index, "value": newValue}
jsonBody, _ := json.Marshal(body)
_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
return err
}
// DeleteAPIKey deletes an API key by index.
func (c *Client) DeleteAPIKey(index int) error {
_, code, err := c.doRequest("DELETE", fmt.Sprintf("/v0/management/api-keys?index=%d", index), nil)
if err != nil {
return err
}
if code >= 400 {
return fmt.Errorf("delete failed (HTTP %d)", code)
}
return nil
}
// GetGeminiKeys fetches Gemini API keys.
// API returns {"gemini-api-key": [...]}.
func (c *Client) GetGeminiKeys() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/gemini-api-key", "gemini-api-key")
}
// GetClaudeKeys fetches Claude API keys.
func (c *Client) GetClaudeKeys() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/claude-api-key", "claude-api-key")
}
// GetCodexKeys fetches Codex API keys.
func (c *Client) GetCodexKeys() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/codex-api-key", "codex-api-key")
}
// GetVertexKeys fetches Vertex API keys.
func (c *Client) GetVertexKeys() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/vertex-api-key", "vertex-api-key")
}
// GetOpenAICompat fetches OpenAI compatibility entries.
func (c *Client) GetOpenAICompat() ([]map[string]any, error) {
return c.getWrappedKeyList("/v0/management/openai-compatibility", "openai-compatibility")
}
// getWrappedKeyList fetches a wrapped list from the API.
func (c *Client) getWrappedKeyList(path, key string) ([]map[string]any, error) {
wrapper, err := c.getJSON(path)
if err != nil {
return nil, err
}
return extractList(wrapper, key)
}
// extractList pulls an array of maps from a wrapper object by key.
func extractList(wrapper map[string]any, key string) ([]map[string]any, error) {
arr, ok := wrapper[key]
if !ok || arr == nil {
return nil, nil
}
raw, err := json.Marshal(arr)
if err != nil {
return nil, err
}
var result []map[string]any
if err := json.Unmarshal(raw, &result); err != nil {
return nil, err
}
return result, nil
}
// GetDebug fetches the current debug setting.
func (c *Client) GetDebug() (bool, error) {
wrapper, err := c.getJSON("/v0/management/debug")
if err != nil {
return false, err
}
if v, ok := wrapper["debug"]; ok {
if b, ok := v.(bool); ok {
return b, nil
}
}
return false, nil
}
// GetAuthStatus polls the OAuth session status.
// Returns status ("wait", "ok", "error") and optional error message.
func (c *Client) GetAuthStatus(state string) (string, string, error) {
query := url.Values{}
query.Set("state", state)
path := "/v0/management/get-auth-status?" + query.Encode()
wrapper, err := c.getJSON(path)
if err != nil {
return "", "", err
}
status := getString(wrapper, "status")
errMsg := getString(wrapper, "error")
return status, errMsg, nil
}
// ----- Config field update methods -----
// PutBoolField updates a boolean config field.
func (c *Client) PutBoolField(path string, value bool) error {
body, _ := json.Marshal(map[string]any{"value": value})
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
return err
}
// PutIntField updates an integer config field.
func (c *Client) PutIntField(path string, value int) error {
body, _ := json.Marshal(map[string]any{"value": value})
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
return err
}
// PutStringField updates a string config field.
func (c *Client) PutStringField(path string, value string) error {
body, _ := json.Marshal(map[string]any{"value": value})
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
return err
}
// DeleteField sends a DELETE request for a config field.
func (c *Client) DeleteField(path string) error {
_, _, err := c.doRequest("DELETE", "/v0/management/"+path, nil)
return err
}

413
internal/tui/config_tab.go Normal file
View File

@@ -0,0 +1,413 @@
package tui
import (
"fmt"
"strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// configField represents a single editable config field.
type configField struct {
label string
apiPath string // management API path (e.g. "debug", "proxy-url")
kind string // "bool", "int", "string", "readonly"
value string // current display value
rawValue any // raw value from API
}
// configTabModel displays parsed config with interactive editing.
type configTabModel struct {
client *Client
viewport viewport.Model
fields []configField
cursor int
editing bool
textInput textinput.Model
err error
message string // status message (success/error)
width int
height int
ready bool
}
type configDataMsg struct {
config map[string]any
err error
}
type configUpdateMsg struct {
path string
value any
err error
}
func newConfigTabModel(client *Client) configTabModel {
ti := textinput.New()
ti.CharLimit = 256
return configTabModel{
client: client,
textInput: ti,
}
}
func (m configTabModel) Init() tea.Cmd {
return m.fetchConfig
}
func (m configTabModel) fetchConfig() tea.Msg {
cfg, err := m.client.GetConfig()
return configDataMsg{config: cfg, err: err}
}
func (m configTabModel) Update(msg tea.Msg) (configTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case configDataMsg:
if msg.err != nil {
m.err = msg.err
m.fields = nil
} else {
m.err = nil
m.fields = m.parseConfig(msg.config)
}
m.viewport.SetContent(m.renderContent())
return m, nil
case configUpdateMsg:
if msg.err != nil {
m.message = errorStyle.Render("✗ " + msg.err.Error())
} else {
m.message = successStyle.Render(T("updated_ok"))
}
m.viewport.SetContent(m.renderContent())
// Refresh config from server
return m, m.fetchConfig
case tea.KeyMsg:
if m.editing {
return m.handleEditingKey(msg)
}
return m.handleNormalKey(msg)
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m configTabModel) handleNormalKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
switch msg.String() {
case "r":
m.message = ""
return m, m.fetchConfig
case "up", "k":
if m.cursor > 0 {
m.cursor--
m.viewport.SetContent(m.renderContent())
// Ensure cursor is visible
m.ensureCursorVisible()
}
return m, nil
case "down", "j":
if m.cursor < len(m.fields)-1 {
m.cursor++
m.viewport.SetContent(m.renderContent())
m.ensureCursorVisible()
}
return m, nil
case "enter", " ":
if m.cursor >= 0 && m.cursor < len(m.fields) {
f := m.fields[m.cursor]
if f.kind == "readonly" {
return m, nil
}
if f.kind == "bool" {
// Toggle directly
return m, m.toggleBool(m.cursor)
}
// Start editing for int/string
m.editing = true
m.textInput.SetValue(configFieldEditValue(f))
m.textInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
}
return m, nil
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m configTabModel) handleEditingKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
switch msg.String() {
case "enter":
m.editing = false
m.textInput.Blur()
return m, m.submitEdit(m.cursor, m.textInput.Value())
case "esc":
m.editing = false
m.textInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.textInput, cmd = m.textInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
func (m configTabModel) toggleBool(idx int) tea.Cmd {
return func() tea.Msg {
f := m.fields[idx]
current := f.value == "true"
newValue := !current
errPutBool := m.client.PutBoolField(f.apiPath, newValue)
return configUpdateMsg{
path: f.apiPath,
value: newValue,
err: errPutBool,
}
}
}
func (m configTabModel) submitEdit(idx int, newValue string) tea.Cmd {
return func() tea.Msg {
f := m.fields[idx]
var err error
var value any
switch f.kind {
case "int":
valueInt, errAtoi := strconv.Atoi(newValue)
if errAtoi != nil {
return configUpdateMsg{
path: f.apiPath,
err: fmt.Errorf("%s: %s", T("invalid_int"), newValue),
}
}
value = valueInt
err = m.client.PutIntField(f.apiPath, valueInt)
case "string":
value = newValue
err = m.client.PutStringField(f.apiPath, newValue)
}
return configUpdateMsg{
path: f.apiPath,
value: value,
err: err,
}
}
}
func configFieldEditValue(f configField) string {
if rawString, ok := f.rawValue.(string); ok {
return rawString
}
return f.value
}
func (m *configTabModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m *configTabModel) ensureCursorVisible() {
// Each field takes ~1 line, header takes ~4 lines
targetLine := m.cursor + 5
if targetLine < m.viewport.YOffset {
m.viewport.SetYOffset(targetLine)
}
if targetLine >= m.viewport.YOffset+m.viewport.Height {
m.viewport.SetYOffset(targetLine - m.viewport.Height + 1)
}
}
func (m configTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m configTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("config_title")))
sb.WriteString("\n")
if m.message != "" {
sb.WriteString(" " + m.message)
sb.WriteString("\n")
}
sb.WriteString(helpStyle.Render(T("config_help1")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("config_help2")))
sb.WriteString("\n\n")
if m.err != nil {
sb.WriteString(errorStyle.Render(" ⚠ Error: " + m.err.Error()))
return sb.String()
}
if len(m.fields) == 0 {
sb.WriteString(subtitleStyle.Render(T("no_config")))
return sb.String()
}
currentSection := ""
for i, f := range m.fields {
// Section headers
section := fieldSection(f.apiPath)
if section != currentSection {
currentSection = section
sb.WriteString("\n")
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(" ── " + section + " "))
sb.WriteString("\n")
}
isSelected := i == m.cursor
prefix := " "
if isSelected {
prefix = "▸ "
}
labelStr := lipgloss.NewStyle().
Foreground(colorInfo).
Bold(isSelected).
Width(32).
Render(f.label)
var valueStr string
if m.editing && isSelected {
valueStr = m.textInput.View()
} else {
switch f.kind {
case "bool":
if f.value == "true" {
valueStr = successStyle.Render("● ON")
} else {
valueStr = lipgloss.NewStyle().Foreground(colorMuted).Render("○ OFF")
}
case "readonly":
valueStr = lipgloss.NewStyle().Foreground(colorSubtext).Render(f.value)
default:
valueStr = valueStyle.Render(f.value)
}
}
line := prefix + labelStr + " " + valueStr
if isSelected && !m.editing {
line = lipgloss.NewStyle().Background(colorSurface).Render(line)
}
sb.WriteString(line + "\n")
}
return sb.String()
}
func (m configTabModel) parseConfig(cfg map[string]any) []configField {
var fields []configField
// Server settings
fields = append(fields, configField{"Port", "port", "readonly", fmt.Sprintf("%.0f", getFloat(cfg, "port")), nil})
fields = append(fields, configField{"Host", "host", "readonly", getString(cfg, "host"), nil})
fields = append(fields, configField{"Debug", "debug", "bool", fmt.Sprintf("%v", getBool(cfg, "debug")), nil})
fields = append(fields, configField{"Proxy URL", "proxy-url", "string", getString(cfg, "proxy-url"), nil})
fields = append(fields, configField{"Request Retry", "request-retry", "int", fmt.Sprintf("%.0f", getFloat(cfg, "request-retry")), nil})
fields = append(fields, configField{"Max Retry Interval (s)", "max-retry-interval", "int", fmt.Sprintf("%.0f", getFloat(cfg, "max-retry-interval")), nil})
fields = append(fields, configField{"Force Model Prefix", "force-model-prefix", "string", getString(cfg, "force-model-prefix"), nil})
// Logging
fields = append(fields, configField{"Logging to File", "logging-to-file", "bool", fmt.Sprintf("%v", getBool(cfg, "logging-to-file")), nil})
fields = append(fields, configField{"Logs Max Total Size (MB)", "logs-max-total-size-mb", "int", fmt.Sprintf("%.0f", getFloat(cfg, "logs-max-total-size-mb")), nil})
fields = append(fields, configField{"Error Logs Max Files", "error-logs-max-files", "int", fmt.Sprintf("%.0f", getFloat(cfg, "error-logs-max-files")), nil})
fields = append(fields, configField{"Usage Stats Enabled", "usage-statistics-enabled", "bool", fmt.Sprintf("%v", getBool(cfg, "usage-statistics-enabled")), nil})
fields = append(fields, configField{"Request Log", "request-log", "bool", fmt.Sprintf("%v", getBool(cfg, "request-log")), nil})
// Quota exceeded
fields = append(fields, configField{"Switch Project on Quota", "quota-exceeded/switch-project", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-project")), nil})
fields = append(fields, configField{"Switch Preview Model", "quota-exceeded/switch-preview-model", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-preview-model")), nil})
// Routing
if routing, ok := cfg["routing"].(map[string]any); ok {
fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", getString(routing, "strategy"), nil})
} else {
fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", "", nil})
}
// WebSocket auth
fields = append(fields, configField{"WebSocket Auth", "ws-auth", "bool", fmt.Sprintf("%v", getBool(cfg, "ws-auth")), nil})
// AMP settings
if amp, ok := cfg["ampcode"].(map[string]any); ok {
upstreamURL := getString(amp, "upstream-url")
upstreamAPIKey := getString(amp, "upstream-api-key")
fields = append(fields, configField{"AMP Upstream URL", "ampcode/upstream-url", "string", upstreamURL, upstreamURL})
fields = append(fields, configField{"AMP Upstream API Key", "ampcode/upstream-api-key", "string", maskIfNotEmpty(upstreamAPIKey), upstreamAPIKey})
fields = append(fields, configField{"AMP Restrict Mgmt Localhost", "ampcode/restrict-management-to-localhost", "bool", fmt.Sprintf("%v", getBool(amp, "restrict-management-to-localhost")), nil})
}
return fields
}
func fieldSection(apiPath string) string {
if strings.HasPrefix(apiPath, "ampcode/") {
return T("section_ampcode")
}
if strings.HasPrefix(apiPath, "quota-exceeded/") {
return T("section_quota")
}
if strings.HasPrefix(apiPath, "routing/") {
return T("section_routing")
}
switch apiPath {
case "port", "host", "debug", "proxy-url", "request-retry", "max-retry-interval", "force-model-prefix":
return T("section_server")
case "logging-to-file", "logs-max-total-size-mb", "error-logs-max-files", "usage-statistics-enabled", "request-log":
return T("section_logging")
case "ws-auth":
return T("section_websocket")
default:
return T("section_other")
}
}
func getBoolNested(m map[string]any, keys ...string) bool {
current := m
for i, key := range keys {
if i == len(keys)-1 {
return getBool(current, key)
}
if nested, ok := current[key].(map[string]any); ok {
current = nested
} else {
return false
}
}
return false
}
func maskIfNotEmpty(s string) string {
if s == "" {
return T("not_set")
}
return maskKey(s)
}

360
internal/tui/dashboard.go Normal file
View File

@@ -0,0 +1,360 @@
package tui
import (
"encoding/json"
"fmt"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// dashboardModel displays server info, stats cards, and config overview.
type dashboardModel struct {
client *Client
viewport viewport.Model
content string
err error
width int
height int
ready bool
// Cached data for re-rendering on locale change
lastConfig map[string]any
lastUsage map[string]any
lastAuthFiles []map[string]any
lastAPIKeys []string
}
type dashboardDataMsg struct {
config map[string]any
usage map[string]any
authFiles []map[string]any
apiKeys []string
err error
}
func newDashboardModel(client *Client) dashboardModel {
return dashboardModel{
client: client,
}
}
func (m dashboardModel) Init() tea.Cmd {
return m.fetchData
}
func (m dashboardModel) fetchData() tea.Msg {
cfg, cfgErr := m.client.GetConfig()
usage, usageErr := m.client.GetUsage()
authFiles, authErr := m.client.GetAuthFiles()
apiKeys, keysErr := m.client.GetAPIKeys()
var err error
for _, e := range []error{cfgErr, usageErr, authErr, keysErr} {
if e != nil {
err = e
break
}
}
return dashboardDataMsg{config: cfg, usage: usage, authFiles: authFiles, apiKeys: apiKeys, err: err}
}
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
// Re-render immediately with cached data using new locale
m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
m.viewport.SetContent(m.content)
// Also fetch fresh data in background
return m, m.fetchData
case dashboardDataMsg:
if msg.err != nil {
m.err = msg.err
m.content = errorStyle.Render("⚠ Error: " + msg.err.Error())
} else {
m.err = nil
// Cache data for locale switching
m.lastConfig = msg.config
m.lastUsage = msg.usage
m.lastAuthFiles = msg.authFiles
m.lastAPIKeys = msg.apiKeys
m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
}
m.viewport.SetContent(m.content)
return m, nil
case tea.KeyMsg:
if msg.String() == "r" {
return m, m.fetchData
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *dashboardModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.content)
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m dashboardModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("dashboard_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("dashboard_help")))
sb.WriteString("\n\n")
// ━━━ Connection Status ━━━
connStyle := lipgloss.NewStyle().Bold(true).Foreground(colorSuccess)
sb.WriteString(connStyle.Render(T("connected")))
sb.WriteString(fmt.Sprintf(" %s", m.client.baseURL))
sb.WriteString("\n\n")
// ━━━ Stats Cards ━━━
cardWidth := 25
if m.width > 0 {
cardWidth = (m.width - 6) / 4
if cardWidth < 18 {
cardWidth = 18
}
}
cardStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240")).
Padding(0, 1).
Width(cardWidth).
Height(2)
// Card 1: API Keys
keyCount := len(apiKeys)
card1 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("🔑 %d", keyCount)),
lipgloss.NewStyle().Foreground(colorMuted).Render(T("mgmt_keys")),
))
// Card 2: Auth Files
authCount := len(authFiles)
activeAuth := 0
for _, f := range authFiles {
if !getBool(f, "disabled") {
activeAuth++
}
}
card2 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("📄 %d", authCount)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
))
// Card 3: Total Requests
totalReqs := int64(0)
successReqs := int64(0)
failedReqs := int64(0)
totalTokens := int64(0)
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
totalReqs = int64(getFloat(usageMap, "total_requests"))
successReqs = int64(getFloat(usageMap, "success_count"))
failedReqs = int64(getFloat(usageMap, "failure_count"))
totalTokens = int64(getFloat(usageMap, "total_tokens"))
}
}
card3 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
))
// Card 4: Total Tokens
tokenStr := formatLargeNumber(totalTokens)
card4 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString("\n\n")
// ━━━ Current Config ━━━
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("current_config")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
if cfg != nil {
debug := getBool(cfg, "debug")
retry := getFloat(cfg, "request-retry")
proxyURL := getString(cfg, "proxy-url")
loggingToFile := getBool(cfg, "logging-to-file")
usageEnabled := true
if v, ok := cfg["usage-statistics-enabled"]; ok {
if b, ok2 := v.(bool); ok2 {
usageEnabled = b
}
}
configItems := []struct {
label string
value string
}{
{T("debug_mode"), boolEmoji(debug)},
{T("usage_stats"), boolEmoji(usageEnabled)},
{T("log_to_file"), boolEmoji(loggingToFile)},
{T("retry_count"), fmt.Sprintf("%.0f", retry)},
}
if proxyURL != "" {
configItems = append(configItems, struct {
label string
value string
}{T("proxy_url"), proxyURL})
}
// Render config items as a compact row
for _, item := range configItems {
sb.WriteString(fmt.Sprintf(" %s %s\n",
labelStyle.Render(item.label+":"),
valueStyle.Render(item.value)))
}
// Routing strategy
strategy := "round-robin"
if routing, ok := cfg["routing"].(map[string]any); ok {
if s := getString(routing, "strategy"); s != "" {
strategy = s
}
}
sb.WriteString(fmt.Sprintf(" %s %s\n",
labelStyle.Render(T("routing_strategy")+":"),
valueStyle.Render(strategy)))
}
sb.WriteString("\n")
// ━━━ Per-Model Usage ━━━
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for _, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
reqs := int64(getFloat(stats, "total_requests"))
toks := int64(getFloat(stats, "total_tokens"))
row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
sb.WriteString(tableCellStyle.Render(row))
sb.WriteString("\n")
}
}
}
}
}
}
}
}
return sb.String()
}
func formatKV(key, value string) string {
return fmt.Sprintf(" %s %s\n", labelStyle.Render(key+":"), valueStyle.Render(value))
}
func getString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func getFloat(m map[string]any, key string) float64 {
if v, ok := m[key]; ok {
switch n := v.(type) {
case float64:
return n
case json.Number:
f, _ := n.Float64()
return f
}
}
return 0
}
func getBool(m map[string]any, key string) bool {
if v, ok := m[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
return false
}
func boolEmoji(b bool) string {
if b {
return T("bool_yes")
}
return T("bool_no")
}
func formatLargeNumber(n int64) string {
if n >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
}
if n >= 1_000 {
return fmt.Sprintf("%.1fK", float64(n)/1_000)
}
return fmt.Sprintf("%d", n)
}
func truncate(s string, maxLen int) string {
if len(s) > maxLen {
return s[:maxLen-3] + "..."
}
return s
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}

364
internal/tui/i18n.go Normal file
View File

@@ -0,0 +1,364 @@
package tui
// i18n provides a simple internationalization system for the TUI.
// Supported locales: "zh" (Chinese, default), "en" (English).
var currentLocale = "en"
// SetLocale changes the active locale.
func SetLocale(locale string) {
if _, ok := locales[locale]; ok {
currentLocale = locale
}
}
// CurrentLocale returns the active locale code.
func CurrentLocale() string {
return currentLocale
}
// ToggleLocale switches between zh and en.
func ToggleLocale() {
if currentLocale == "zh" {
currentLocale = "en"
} else {
currentLocale = "zh"
}
}
// T returns the translated string for the given key.
func T(key string) string {
if m, ok := locales[currentLocale]; ok {
if v, ok := m[key]; ok {
return v
}
}
// Fallback to English
if m, ok := locales["en"]; ok {
if v, ok := m[key]; ok {
return v
}
}
return key
}
var locales = map[string]map[string]string{
"zh": zhStrings,
"en": enStrings,
}
// ──────────────────────────────────────────
// Tab names
// ──────────────────────────────────────────
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
// TabNames returns tab names in the current locale.
func TabNames() []string {
if currentLocale == "zh" {
return zhTabNames
}
return enTabNames
}
var zhStrings = map[string]string{
// ── Common ──
"loading": "加载中...",
"refresh": "刷新",
"save": "保存",
"cancel": "取消",
"confirm": "确认",
"yes": "是",
"no": "否",
"error": "错误",
"success": "成功",
"navigate": "导航",
"scroll": "滚动",
"enter_save": "Enter: 保存",
"esc_cancel": "Esc: 取消",
"enter_submit": "Enter: 提交",
"press_r": "[r] 刷新",
"press_scroll": "[↑↓] 滚动",
"not_set": "(未设置)",
"error_prefix": "⚠ 错误: ",
// ── Status bar ──
"status_left": " CLIProxyAPI 管理终端",
"status_right": "Tab/Shift+Tab: 切换 • L: 语言 • q/Ctrl+C: 退出 ",
"initializing_tui": "正在初始化...",
"auth_gate_title": "🔐 连接管理 API",
"auth_gate_help": " 请输入管理密码并按 Enter 连接",
"auth_gate_password": "密码",
"auth_gate_enter": " Enter: 连接 • q/Ctrl+C: 退出 • L: 语言",
"auth_gate_connecting": "正在连接...",
"auth_gate_connect_fail": "连接失败:%s",
"auth_gate_password_required": "请输入密码",
// ── Dashboard ──
"dashboard_title": "📊 仪表盘",
"dashboard_help": " [r] 刷新 • [↑↓] 滚动",
"connected": "● 已连接",
"mgmt_keys": "管理密钥",
"auth_files_label": "认证文件",
"active_suffix": "活跃",
"total_requests": "请求",
"success_label": "成功",
"failure_label": "失败",
"total_tokens": "总 Tokens",
"current_config": "当前配置",
"debug_mode": "启用调试模式",
"usage_stats": "启用使用统计",
"log_to_file": "启用日志记录到文件",
"retry_count": "重试次数",
"proxy_url": "代理 URL",
"routing_strategy": "路由策略",
"model_stats": "模型统计",
"model": "模型",
"requests": "请求数",
"tokens": "Tokens",
"bool_yes": "是 ✓",
"bool_no": "否",
// ── Config ──
"config_title": "⚙ 配置",
"config_help1": " [↑↓/jk] 导航 • [Enter/Space] 编辑 • [r] 刷新",
"config_help2": " 布尔: Enter 切换 • 文本/数字: Enter 输入, Enter 确认, Esc 取消",
"updated_ok": "✓ 更新成功",
"no_config": " 未加载配置",
"invalid_int": "无效整数",
"section_server": "服务器",
"section_logging": "日志与统计",
"section_quota": "配额超限处理",
"section_routing": "路由",
"section_websocket": "WebSocket",
"section_ampcode": "AMP Code",
"section_other": "其他",
// ── Auth Files ──
"auth_title": "🔑 认证文件",
"auth_help1": " [↑↓/jk] 导航 • [Enter] 展开 • [e] 启用/停用 • [d] 删除 • [r] 刷新",
"auth_help2": " [1] 编辑 prefix • [2] 编辑 proxy_url • [3] 编辑 priority",
"no_auth_files": " 无认证文件",
"confirm_delete": "⚠ 删除 %s? [y/n]",
"deleted": "已删除 %s",
"enabled": "已启用",
"disabled": "已停用",
"updated_field": "已更新 %s 的 %s",
"status_active": "活跃",
"status_disabled": "已停用",
// ── API Keys ──
"keys_title": "🔐 API 密钥",
"keys_help": " [↑↓/jk] 导航 • [a] 添加 • [e] 编辑 • [d] 删除 • [c] 复制 • [r] 刷新",
"no_keys": " 无 API Key按 [a] 添加",
"access_keys": "Access API Keys",
"confirm_delete_key": "⚠ 确认删除 %s? [y/n]",
"key_added": "已添加 API Key",
"key_updated": "已更新 API Key",
"key_deleted": "已删除 API Key",
"copied": "✓ 已复制到剪贴板",
"copy_failed": "✗ 复制失败",
"new_key_prompt": " New Key: ",
"edit_key_prompt": " Edit Key: ",
"enter_add": " Enter: 添加 • Esc: 取消",
"enter_save_esc": " Enter: 保存 • Esc: 取消",
// ── OAuth ──
"oauth_title": "🔐 OAuth 登录",
"oauth_select": " 选择提供商并按 [Enter] 开始 OAuth 登录:",
"oauth_help": " [↑↓/jk] 导航 • [Enter] 登录 • [Esc] 清除状态",
"oauth_initiating": "⏳ 正在初始化 %s 登录...",
"oauth_success": "认证成功! 请刷新 Auth Files 标签查看新凭证。",
"oauth_completed": "认证流程已完成。",
"oauth_failed": "认证失败",
"oauth_timeout": "OAuth 流程超时 (5 分钟)",
"oauth_press_esc": " 按 [Esc] 取消",
"oauth_auth_url": " 授权链接:",
"oauth_remote_hint": " 远程浏览器模式:在浏览器中打开上述链接完成授权后,将回调 URL 粘贴到下方。",
"oauth_callback_url": " 回调 URL:",
"oauth_press_c": " 按 [c] 输入回调 URL • [Esc] 返回",
"oauth_submitting": "⏳ 提交回调中...",
"oauth_submit_ok": "✓ 回调已提交,等待处理...",
"oauth_submit_fail": "✗ 提交回调失败",
"oauth_waiting": " 等待认证中...",
// ── Usage ──
"usage_title": "📈 使用统计",
"usage_help": " [r] 刷新 • [↑↓] 滚动",
"usage_no_data": " 使用数据不可用",
"usage_total_reqs": "总请求数",
"usage_total_tokens": "总 Token 数",
"usage_success": "成功",
"usage_failure": "失败",
"usage_total_token_l": "总Token",
"usage_rpm": "RPM",
"usage_tpm": "TPM",
"usage_req_by_hour": "请求趋势 (按小时)",
"usage_tok_by_hour": "Token 使用趋势 (按小时)",
"usage_req_by_day": "请求趋势 (按天)",
"usage_api_detail": "API 详细统计",
"usage_input": "输入",
"usage_output": "输出",
"usage_cached": "缓存",
"usage_reasoning": "思考",
// ── Logs ──
"logs_title": "📋 日志",
"logs_auto_scroll": "● 自动滚动",
"logs_paused": "○ 已暂停",
"logs_filter": "过滤",
"logs_lines": "行数",
"logs_help": " [a] 自动滚动 • [c] 清除 • [1] 全部 [2] info+ [3] warn+ [4] error • [↑↓] 滚动",
"logs_waiting": " 等待日志输出...",
}
var enStrings = map[string]string{
// ── Common ──
"loading": "Loading...",
"refresh": "Refresh",
"save": "Save",
"cancel": "Cancel",
"confirm": "Confirm",
"yes": "Yes",
"no": "No",
"error": "Error",
"success": "Success",
"navigate": "Navigate",
"scroll": "Scroll",
"enter_save": "Enter: Save",
"esc_cancel": "Esc: Cancel",
"enter_submit": "Enter: Submit",
"press_r": "[r] Refresh",
"press_scroll": "[↑↓] Scroll",
"not_set": "(not set)",
"error_prefix": "⚠ Error: ",
// ── Status bar ──
"status_left": " CLIProxyAPI Management TUI",
"status_right": "Tab/Shift+Tab: switch • L: lang • q/Ctrl+C: quit ",
"initializing_tui": "Initializing...",
"auth_gate_title": "🔐 Connect Management API",
"auth_gate_help": " Enter management password and press Enter to connect",
"auth_gate_password": "Password",
"auth_gate_enter": " Enter: connect • q/Ctrl+C: quit • L: lang",
"auth_gate_connecting": "Connecting...",
"auth_gate_connect_fail": "Connection failed: %s",
"auth_gate_password_required": "password is required",
// ── Dashboard ──
"dashboard_title": "📊 Dashboard",
"dashboard_help": " [r] Refresh • [↑↓] Scroll",
"connected": "● Connected",
"mgmt_keys": "Mgmt Keys",
"auth_files_label": "Auth Files",
"active_suffix": "active",
"total_requests": "Requests",
"success_label": "Success",
"failure_label": "Failed",
"total_tokens": "Total Tokens",
"current_config": "Current Config",
"debug_mode": "Debug Mode",
"usage_stats": "Usage Statistics",
"log_to_file": "Log to File",
"retry_count": "Retry Count",
"proxy_url": "Proxy URL",
"routing_strategy": "Routing Strategy",
"model_stats": "Model Stats",
"model": "Model",
"requests": "Requests",
"tokens": "Tokens",
"bool_yes": "Yes ✓",
"bool_no": "No",
// ── Config ──
"config_title": "⚙ Configuration",
"config_help1": " [↑↓/jk] Navigate • [Enter/Space] Edit • [r] Refresh",
"config_help2": " Bool: Enter to toggle • String/Int: Enter to type, Enter to confirm, Esc to cancel",
"updated_ok": "✓ Updated successfully",
"no_config": " No configuration loaded",
"invalid_int": "invalid integer",
"section_server": "Server",
"section_logging": "Logging & Stats",
"section_quota": "Quota Exceeded Handling",
"section_routing": "Routing",
"section_websocket": "WebSocket",
"section_ampcode": "AMP Code",
"section_other": "Other",
// ── Auth Files ──
"auth_title": "🔑 Auth Files",
"auth_help1": " [↑↓/jk] Navigate • [Enter] Expand • [e] Enable/Disable • [d] Delete • [r] Refresh",
"auth_help2": " [1] Edit prefix • [2] Edit proxy_url • [3] Edit priority",
"no_auth_files": " No auth files found",
"confirm_delete": "⚠ Delete %s? [y/n]",
"deleted": "Deleted %s",
"enabled": "Enabled",
"disabled": "Disabled",
"updated_field": "Updated %s on %s",
"status_active": "active",
"status_disabled": "disabled",
// ── API Keys ──
"keys_title": "🔐 API Keys",
"keys_help": " [↑↓/jk] Navigate • [a] Add • [e] Edit • [d] Delete • [c] Copy • [r] Refresh",
"no_keys": " No API Keys. Press [a] to add",
"access_keys": "Access API Keys",
"confirm_delete_key": "⚠ Delete %s? [y/n]",
"key_added": "API Key added",
"key_updated": "API Key updated",
"key_deleted": "API Key deleted",
"copied": "✓ Copied to clipboard",
"copy_failed": "✗ Copy failed",
"new_key_prompt": " New Key: ",
"edit_key_prompt": " Edit Key: ",
"enter_add": " Enter: Add • Esc: Cancel",
"enter_save_esc": " Enter: Save • Esc: Cancel",
// ── OAuth ──
"oauth_title": "🔐 OAuth Login",
"oauth_select": " Select a provider and press [Enter] to start OAuth login:",
"oauth_help": " [↑↓/jk] Navigate • [Enter] Login • [Esc] Clear status",
"oauth_initiating": "⏳ Initiating %s login...",
"oauth_success": "Authentication successful! Refresh Auth Files tab to see the new credential.",
"oauth_completed": "Authentication flow completed.",
"oauth_failed": "Authentication failed",
"oauth_timeout": "OAuth flow timed out (5 minutes)",
"oauth_press_esc": " Press [Esc] to cancel",
"oauth_auth_url": " Authorization URL:",
"oauth_remote_hint": " Remote browser mode: Open the URL above in browser, paste the callback URL below after authorization.",
"oauth_callback_url": " Callback URL:",
"oauth_press_c": " Press [c] to enter callback URL • [Esc] to go back",
"oauth_submitting": "⏳ Submitting callback...",
"oauth_submit_ok": "✓ Callback submitted, waiting...",
"oauth_submit_fail": "✗ Callback submission failed",
"oauth_waiting": " Waiting for authentication...",
// ── Usage ──
"usage_title": "📈 Usage Statistics",
"usage_help": " [r] Refresh • [↑↓] Scroll",
"usage_no_data": " Usage data not available",
"usage_total_reqs": "Total Requests",
"usage_total_tokens": "Total Tokens",
"usage_success": "Success",
"usage_failure": "Failed",
"usage_total_token_l": "Total Tokens",
"usage_rpm": "RPM",
"usage_tpm": "TPM",
"usage_req_by_hour": "Requests by Hour",
"usage_tok_by_hour": "Token Usage by Hour",
"usage_req_by_day": "Requests by Day",
"usage_api_detail": "API Detail Statistics",
"usage_input": "Input",
"usage_output": "Output",
"usage_cached": "Cached",
"usage_reasoning": "Reasoning",
// ── Logs ──
"logs_title": "📋 Logs",
"logs_auto_scroll": "● AUTO-SCROLL",
"logs_paused": "○ PAUSED",
"logs_filter": "Filter",
"logs_lines": "Lines",
"logs_help": " [a] Auto-scroll • [c] Clear • [1] All [2] info+ [3] warn+ [4] error • [↑↓] Scroll",
"logs_waiting": " Waiting for log output...",
}

405
internal/tui/keys_tab.go Normal file
View File

@@ -0,0 +1,405 @@
package tui
import (
"fmt"
"strings"
"github.com/atotto/clipboard"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// keysTabModel displays and manages API keys.
type keysTabModel struct {
client *Client
viewport viewport.Model
keys []string
gemini []map[string]any
claude []map[string]any
codex []map[string]any
vertex []map[string]any
openai []map[string]any
err error
width int
height int
ready bool
cursor int
confirm int // -1 = no deletion pending
status string
// Editing / Adding
editing bool
adding bool
editIdx int
editInput textinput.Model
}
type keysDataMsg struct {
apiKeys []string
gemini []map[string]any
claude []map[string]any
codex []map[string]any
vertex []map[string]any
openai []map[string]any
err error
}
type keyActionMsg struct {
action string
err error
}
func newKeysTabModel(client *Client) keysTabModel {
ti := textinput.New()
ti.CharLimit = 512
ti.Prompt = " Key: "
return keysTabModel{
client: client,
confirm: -1,
editInput: ti,
}
}
func (m keysTabModel) Init() tea.Cmd {
return m.fetchKeys
}
func (m keysTabModel) fetchKeys() tea.Msg {
result := keysDataMsg{}
apiKeys, err := m.client.GetAPIKeys()
if err != nil {
result.err = err
return result
}
result.apiKeys = apiKeys
result.gemini, _ = m.client.GetGeminiKeys()
result.claude, _ = m.client.GetClaudeKeys()
result.codex, _ = m.client.GetCodexKeys()
result.vertex, _ = m.client.GetVertexKeys()
result.openai, _ = m.client.GetOpenAICompat()
return result
}
func (m keysTabModel) Update(msg tea.Msg) (keysTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case keysDataMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.keys = msg.apiKeys
m.gemini = msg.gemini
m.claude = msg.claude
m.codex = msg.codex
m.vertex = msg.vertex
m.openai = msg.openai
if m.cursor >= len(m.keys) {
m.cursor = max(0, len(m.keys)-1)
}
}
m.viewport.SetContent(m.renderContent())
return m, nil
case keyActionMsg:
if msg.err != nil {
m.status = errorStyle.Render("✗ " + msg.err.Error())
} else {
m.status = successStyle.Render("✓ " + msg.action)
}
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, m.fetchKeys
case tea.KeyMsg:
// ---- Editing / Adding mode ----
if m.editing || m.adding {
switch msg.String() {
case "enter":
value := strings.TrimSpace(m.editInput.Value())
if value == "" {
m.editing = false
m.adding = false
m.editInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
}
isAdding := m.adding
editIdx := m.editIdx
m.editing = false
m.adding = false
m.editInput.Blur()
if isAdding {
return m, func() tea.Msg {
err := m.client.AddAPIKey(value)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_added")}
}
}
return m, func() tea.Msg {
err := m.client.EditAPIKey(editIdx, value)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_updated")}
}
case "esc":
m.editing = false
m.adding = false
m.editInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.editInput, cmd = m.editInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
// ---- Delete confirmation ----
if m.confirm >= 0 {
switch msg.String() {
case "y", "Y":
idx := m.confirm
m.confirm = -1
return m, func() tea.Msg {
err := m.client.DeleteAPIKey(idx)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_deleted")}
}
case "n", "N", "esc":
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, nil
}
return m, nil
}
// ---- Normal mode ----
switch msg.String() {
case "j", "down":
if len(m.keys) > 0 {
m.cursor = (m.cursor + 1) % len(m.keys)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "k", "up":
if len(m.keys) > 0 {
m.cursor = (m.cursor - 1 + len(m.keys)) % len(m.keys)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "a":
// Add new key
m.adding = true
m.editing = false
m.editInput.SetValue("")
m.editInput.Prompt = T("new_key_prompt")
m.editInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
case "e":
// Edit selected key
if m.cursor < len(m.keys) {
m.editing = true
m.adding = false
m.editIdx = m.cursor
m.editInput.SetValue(m.keys[m.cursor])
m.editInput.Prompt = T("edit_key_prompt")
m.editInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
}
return m, nil
case "d":
// Delete selected key
if m.cursor < len(m.keys) {
m.confirm = m.cursor
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "c":
// Copy selected key to clipboard
if m.cursor < len(m.keys) {
key := m.keys[m.cursor]
if err := clipboard.WriteAll(key); err != nil {
m.status = errorStyle.Render(T("copy_failed") + ": " + err.Error())
} else {
m.status = successStyle.Render(T("copied"))
}
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "r":
m.status = ""
return m, m.fetchKeys
default:
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *keysTabModel) SetSize(w, h int) {
m.width = w
m.height = h
m.editInput.Width = w - 16
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m keysTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m keysTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("keys_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("keys_help")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", m.width))
sb.WriteString("\n")
if m.err != nil {
sb.WriteString(errorStyle.Render(T("error_prefix") + m.err.Error()))
sb.WriteString("\n")
return sb.String()
}
// ━━━ Access API Keys (interactive) ━━━
sb.WriteString(tableHeaderStyle.Render(fmt.Sprintf(" %s (%d)", T("access_keys"), len(m.keys))))
sb.WriteString("\n")
if len(m.keys) == 0 {
sb.WriteString(subtitleStyle.Render(T("no_keys")))
sb.WriteString("\n")
}
for i, key := range m.keys {
cursor := " "
rowStyle := lipgloss.NewStyle()
if i == m.cursor {
cursor = "▸ "
rowStyle = lipgloss.NewStyle().Bold(true)
}
row := fmt.Sprintf("%s%d. %s", cursor, i+1, maskKey(key))
sb.WriteString(rowStyle.Render(row))
sb.WriteString("\n")
// Delete confirmation
if m.confirm == i {
sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete_key"), maskKey(key))))
sb.WriteString("\n")
}
// Edit input
if m.editing && m.editIdx == i {
sb.WriteString(m.editInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("enter_save_esc")))
sb.WriteString("\n")
}
}
// Add input
if m.adding {
sb.WriteString("\n")
sb.WriteString(m.editInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("enter_add")))
sb.WriteString("\n")
}
sb.WriteString("\n")
// ━━━ Provider Keys (read-only display) ━━━
renderProviderKeys(&sb, "Gemini API Keys", m.gemini)
renderProviderKeys(&sb, "Claude API Keys", m.claude)
renderProviderKeys(&sb, "Codex API Keys", m.codex)
renderProviderKeys(&sb, "Vertex API Keys", m.vertex)
if len(m.openai) > 0 {
renderSection(&sb, "OpenAI Compatibility", len(m.openai))
for i, entry := range m.openai {
name := getString(entry, "name")
baseURL := getString(entry, "base-url")
prefix := getString(entry, "prefix")
info := name
if prefix != "" {
info += " (prefix: " + prefix + ")"
}
if baseURL != "" {
info += " → " + baseURL
}
sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, info))
}
sb.WriteString("\n")
}
if m.status != "" {
sb.WriteString(m.status)
sb.WriteString("\n")
}
return sb.String()
}
func renderSection(sb *strings.Builder, title string, count int) {
header := fmt.Sprintf("%s (%d)", title, count)
sb.WriteString(tableHeaderStyle.Render(" " + header))
sb.WriteString("\n")
}
func renderProviderKeys(sb *strings.Builder, title string, keys []map[string]any) {
if len(keys) == 0 {
return
}
renderSection(sb, title, len(keys))
for i, key := range keys {
apiKey := getString(key, "api-key")
prefix := getString(key, "prefix")
baseURL := getString(key, "base-url")
info := maskKey(apiKey)
if prefix != "" {
info += " (prefix: " + prefix + ")"
}
if baseURL != "" {
info += " → " + baseURL
}
sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, info))
}
sb.WriteString("\n")
}
func maskKey(key string) string {
if len(key) <= 8 {
return strings.Repeat("*", len(key))
}
return key[:4] + strings.Repeat("*", len(key)-8) + key[len(key)-4:]
}

78
internal/tui/loghook.go Normal file
View File

@@ -0,0 +1,78 @@
package tui
import (
"fmt"
"strings"
"sync"
log "github.com/sirupsen/logrus"
)
// LogHook is a logrus hook that captures log entries and sends them to a channel.
type LogHook struct {
ch chan string
formatter log.Formatter
mu sync.Mutex
levels []log.Level
}
// NewLogHook creates a new LogHook with a buffered channel of the given size.
func NewLogHook(bufSize int) *LogHook {
return &LogHook{
ch: make(chan string, bufSize),
formatter: &log.TextFormatter{DisableColors: true, FullTimestamp: true},
levels: log.AllLevels,
}
}
// SetFormatter sets a custom formatter for the hook.
func (h *LogHook) SetFormatter(f log.Formatter) {
h.mu.Lock()
defer h.mu.Unlock()
h.formatter = f
}
// Levels returns the log levels this hook should fire on.
func (h *LogHook) Levels() []log.Level {
return h.levels
}
// Fire is called by logrus when a log entry is fired.
func (h *LogHook) Fire(entry *log.Entry) error {
h.mu.Lock()
f := h.formatter
h.mu.Unlock()
var line string
if f != nil {
b, err := f.Format(entry)
if err == nil {
line = strings.TrimRight(string(b), "\n\r")
} else {
line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
}
} else {
line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
}
// Non-blocking send
select {
case h.ch <- line:
default:
// Drop oldest if full
select {
case <-h.ch:
default:
}
select {
case h.ch <- line:
default:
}
}
return nil
}
// Chan returns the channel to read log lines from.
func (h *LogHook) Chan() <-chan string {
return h.ch
}

261
internal/tui/logs_tab.go Normal file
View File

@@ -0,0 +1,261 @@
package tui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
)
// logsTabModel displays real-time log lines from hook/API source.
type logsTabModel struct {
client *Client
hook *LogHook
viewport viewport.Model
lines []string
maxLines int
autoScroll bool
width int
height int
ready bool
filter string // "", "debug", "info", "warn", "error"
after int64
lastErr error
}
type logsPollMsg struct {
lines []string
latest int64
err error
}
type logsTickMsg struct{}
type logLineMsg string
func newLogsTabModel(client *Client, hook *LogHook) logsTabModel {
return logsTabModel{
client: client,
hook: hook,
maxLines: 5000,
autoScroll: true,
}
}
func (m logsTabModel) Init() tea.Cmd {
if m.hook != nil {
return m.waitForLog
}
return m.fetchLogs
}
func (m logsTabModel) fetchLogs() tea.Msg {
lines, latest, err := m.client.GetLogs(m.after, 200)
return logsPollMsg{
lines: lines,
latest: latest,
err: err,
}
}
func (m logsTabModel) waitForNextPoll() tea.Cmd {
return tea.Tick(2*time.Second, func(_ time.Time) tea.Msg {
return logsTickMsg{}
})
}
func (m logsTabModel) waitForLog() tea.Msg {
if m.hook == nil {
return nil
}
line, ok := <-m.hook.Chan()
if !ok {
return nil
}
return logLineMsg(line)
}
func (m logsTabModel) Update(msg tea.Msg) (logsTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderLogs())
return m, nil
case logsTickMsg:
if m.hook != nil {
return m, nil
}
return m, m.fetchLogs
case logsPollMsg:
if m.hook != nil {
return m, nil
}
if msg.err != nil {
m.lastErr = msg.err
} else {
m.lastErr = nil
m.after = msg.latest
if len(msg.lines) > 0 {
m.lines = append(m.lines, msg.lines...)
if len(m.lines) > m.maxLines {
m.lines = m.lines[len(m.lines)-m.maxLines:]
}
}
}
m.viewport.SetContent(m.renderLogs())
if m.autoScroll {
m.viewport.GotoBottom()
}
return m, m.waitForNextPoll()
case logLineMsg:
m.lines = append(m.lines, string(msg))
if len(m.lines) > m.maxLines {
m.lines = m.lines[len(m.lines)-m.maxLines:]
}
m.viewport.SetContent(m.renderLogs())
if m.autoScroll {
m.viewport.GotoBottom()
}
return m, m.waitForLog
case tea.KeyMsg:
switch msg.String() {
case "a":
m.autoScroll = !m.autoScroll
if m.autoScroll {
m.viewport.GotoBottom()
}
return m, nil
case "c":
m.lines = nil
m.lastErr = nil
m.viewport.SetContent(m.renderLogs())
return m, nil
case "1":
m.filter = ""
m.viewport.SetContent(m.renderLogs())
return m, nil
case "2":
m.filter = "info"
m.viewport.SetContent(m.renderLogs())
return m, nil
case "3":
m.filter = "warn"
m.viewport.SetContent(m.renderLogs())
return m, nil
case "4":
m.filter = "error"
m.viewport.SetContent(m.renderLogs())
return m, nil
default:
wasAtBottom := m.viewport.AtBottom()
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
// If user scrolls up, disable auto-scroll
if !m.viewport.AtBottom() && wasAtBottom {
m.autoScroll = false
}
// If user scrolls to bottom, re-enable auto-scroll
if m.viewport.AtBottom() {
m.autoScroll = true
}
return m, cmd
}
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *logsTabModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderLogs())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m logsTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m logsTabModel) renderLogs() string {
var sb strings.Builder
scrollStatus := successStyle.Render(T("logs_auto_scroll"))
if !m.autoScroll {
scrollStatus = warningStyle.Render(T("logs_paused"))
}
filterLabel := "ALL"
if m.filter != "" {
filterLabel = strings.ToUpper(m.filter) + "+"
}
header := fmt.Sprintf(" %s %s %s: %s %s: %d",
T("logs_title"), scrollStatus, T("logs_filter"), filterLabel, T("logs_lines"), len(m.lines))
sb.WriteString(titleStyle.Render(header))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("logs_help")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", m.width))
sb.WriteString("\n")
if m.lastErr != nil {
sb.WriteString(errorStyle.Render("⚠ Error: " + m.lastErr.Error()))
sb.WriteString("\n")
}
if len(m.lines) == 0 {
sb.WriteString(subtitleStyle.Render(T("logs_waiting")))
return sb.String()
}
for _, line := range m.lines {
if m.filter != "" && !m.matchLevel(line) {
continue
}
styled := m.styleLine(line)
sb.WriteString(styled)
sb.WriteString("\n")
}
return sb.String()
}
func (m logsTabModel) matchLevel(line string) bool {
switch m.filter {
case "error":
return strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]") || strings.Contains(line, "[panic]")
case "warn":
return strings.Contains(line, "[warn") || strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]")
case "info":
return !strings.Contains(line, "[debug]")
default:
return true
}
}
func (m logsTabModel) styleLine(line string) string {
if strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]") {
return logErrorStyle.Render(line)
}
if strings.Contains(line, "[warn") {
return logWarnStyle.Render(line)
}
if strings.Contains(line, "[info") {
return logInfoStyle.Render(line)
}
if strings.Contains(line, "[debug]") {
return logDebugStyle.Render(line)
}
return line
}

473
internal/tui/oauth_tab.go Normal file
View File

@@ -0,0 +1,473 @@
package tui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// oauthProvider represents an OAuth provider option.
type oauthProvider struct {
name string
apiPath string // management API path
emoji string
}
var oauthProviders = []oauthProvider{
{"Gemini CLI", "gemini-cli-auth-url", "🟦"},
{"Claude (Anthropic)", "anthropic-auth-url", "🟧"},
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
{"Antigravity", "antigravity-auth-url", "🟪"},
{"Qwen", "qwen-auth-url", "🟨"},
{"Kimi", "kimi-auth-url", "🟫"},
{"IFlow", "iflow-auth-url", "⬜"},
}
// oauthTabModel handles OAuth login flows.
type oauthTabModel struct {
client *Client
viewport viewport.Model
cursor int
state oauthState
message string
err error
width int
height int
ready bool
// Remote browser mode
authURL string // auth URL to display
authState string // OAuth state parameter
providerName string // current provider name
callbackInput textinput.Model
inputActive bool // true when user is typing callback URL
}
type oauthState int
const (
oauthIdle oauthState = iota
oauthPending
oauthRemote // remote browser mode: waiting for manual callback
oauthSuccess
oauthError
)
// Messages
type oauthStartMsg struct {
url string
state string
providerName string
err error
}
type oauthPollMsg struct {
done bool
message string
err error
}
type oauthCallbackSubmitMsg struct {
err error
}
func newOAuthTabModel(client *Client) oauthTabModel {
ti := textinput.New()
ti.Placeholder = "http://localhost:.../auth/callback?code=...&state=..."
ti.CharLimit = 2048
ti.Prompt = " 回调 URL: "
return oauthTabModel{
client: client,
callbackInput: ti,
}
}
func (m oauthTabModel) Init() tea.Cmd {
return nil
}
func (m oauthTabModel) Update(msg tea.Msg) (oauthTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case oauthStartMsg:
if msg.err != nil {
m.state = oauthError
m.err = msg.err
m.message = errorStyle.Render("✗ " + msg.err.Error())
m.viewport.SetContent(m.renderContent())
return m, nil
}
m.authURL = msg.url
m.authState = msg.state
m.providerName = msg.providerName
m.state = oauthRemote
m.callbackInput.SetValue("")
m.callbackInput.Focus()
m.inputActive = true
m.message = ""
m.viewport.SetContent(m.renderContent())
// Also start polling in the background
return m, tea.Batch(textinput.Blink, m.pollOAuthStatus(msg.state))
case oauthPollMsg:
if msg.err != nil {
m.state = oauthError
m.err = msg.err
m.message = errorStyle.Render("✗ " + msg.err.Error())
m.inputActive = false
m.callbackInput.Blur()
} else if msg.done {
m.state = oauthSuccess
m.message = successStyle.Render("✓ " + msg.message)
m.inputActive = false
m.callbackInput.Blur()
} else {
m.message = warningStyle.Render("⏳ " + msg.message)
}
m.viewport.SetContent(m.renderContent())
return m, nil
case oauthCallbackSubmitMsg:
if msg.err != nil {
m.message = errorStyle.Render(T("oauth_submit_fail") + ": " + msg.err.Error())
} else {
m.message = successStyle.Render(T("oauth_submit_ok"))
}
m.viewport.SetContent(m.renderContent())
return m, nil
case tea.KeyMsg:
// ---- Input active: typing callback URL ----
if m.inputActive {
switch msg.String() {
case "enter":
callbackURL := m.callbackInput.Value()
if callbackURL == "" {
return m, nil
}
m.inputActive = false
m.callbackInput.Blur()
m.message = warningStyle.Render(T("oauth_submitting"))
m.viewport.SetContent(m.renderContent())
return m, m.submitCallback(callbackURL)
case "esc":
m.inputActive = false
m.callbackInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.callbackInput, cmd = m.callbackInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
// ---- Remote mode but not typing ----
if m.state == oauthRemote {
switch msg.String() {
case "c", "C":
// Re-activate input
m.inputActive = true
m.callbackInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
case "esc":
m.state = oauthIdle
m.message = ""
m.authURL = ""
m.authState = ""
m.viewport.SetContent(m.renderContent())
return m, nil
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
// ---- Pending (auto polling) ----
if m.state == oauthPending {
if msg.String() == "esc" {
m.state = oauthIdle
m.message = ""
m.viewport.SetContent(m.renderContent())
}
return m, nil
}
// ---- Idle ----
switch msg.String() {
case "up", "k":
if m.cursor > 0 {
m.cursor--
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "down", "j":
if m.cursor < len(oauthProviders)-1 {
m.cursor++
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "enter":
if m.cursor >= 0 && m.cursor < len(oauthProviders) {
provider := oauthProviders[m.cursor]
m.state = oauthPending
m.message = warningStyle.Render(fmt.Sprintf(T("oauth_initiating"), provider.name))
m.viewport.SetContent(m.renderContent())
return m, m.startOAuth(provider)
}
return m, nil
case "esc":
m.state = oauthIdle
m.message = ""
m.err = nil
m.viewport.SetContent(m.renderContent())
return m, nil
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m oauthTabModel) startOAuth(provider oauthProvider) tea.Cmd {
return func() tea.Msg {
// Call the auth URL endpoint with is_webui=true
data, err := m.client.getJSON("/v0/management/" + provider.apiPath + "?is_webui=true")
if err != nil {
return oauthStartMsg{err: fmt.Errorf("failed to start %s login: %w", provider.name, err)}
}
authURL := getString(data, "url")
state := getString(data, "state")
if authURL == "" {
return oauthStartMsg{err: fmt.Errorf("no auth URL returned for %s", provider.name)}
}
// Try to open browser (best effort)
_ = openBrowser(authURL)
return oauthStartMsg{url: authURL, state: state, providerName: provider.name}
}
}
func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
return func() tea.Msg {
// Determine provider from current context
providerKey := ""
for _, p := range oauthProviders {
if p.name == m.providerName {
// Map provider name to the canonical key the API expects
switch p.apiPath {
case "gemini-cli-auth-url":
providerKey = "gemini"
case "anthropic-auth-url":
providerKey = "anthropic"
case "codex-auth-url":
providerKey = "codex"
case "antigravity-auth-url":
providerKey = "antigravity"
case "qwen-auth-url":
providerKey = "qwen"
case "kimi-auth-url":
providerKey = "kimi"
case "iflow-auth-url":
providerKey = "iflow"
}
break
}
}
body := map[string]string{
"provider": providerKey,
"redirect_url": callbackURL,
"state": m.authState,
}
err := m.client.postJSON("/v0/management/oauth-callback", body)
if err != nil {
return oauthCallbackSubmitMsg{err: err}
}
return oauthCallbackSubmitMsg{}
}
}
func (m oauthTabModel) pollOAuthStatus(state string) tea.Cmd {
return func() tea.Msg {
// Poll session status for up to 5 minutes
deadline := time.Now().Add(5 * time.Minute)
for {
if time.Now().After(deadline) {
return oauthPollMsg{done: false, err: fmt.Errorf("%s", T("oauth_timeout"))}
}
time.Sleep(2 * time.Second)
status, errMsg, err := m.client.GetAuthStatus(state)
if err != nil {
continue // Ignore transient errors
}
switch status {
case "ok":
return oauthPollMsg{
done: true,
message: T("oauth_success"),
}
case "error":
return oauthPollMsg{
done: false,
err: fmt.Errorf("%s: %s", T("oauth_failed"), errMsg),
}
case "wait":
continue
default:
return oauthPollMsg{
done: true,
message: T("oauth_completed"),
}
}
}
}
}
func (m *oauthTabModel) SetSize(w, h int) {
m.width = w
m.height = h
m.callbackInput.Width = w - 16
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m oauthTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m oauthTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("oauth_title")))
sb.WriteString("\n\n")
if m.message != "" {
sb.WriteString(" " + m.message)
sb.WriteString("\n\n")
}
// ---- Remote browser mode ----
if m.state == oauthRemote {
sb.WriteString(m.renderRemoteMode())
return sb.String()
}
if m.state == oauthPending {
sb.WriteString(helpStyle.Render(T("oauth_press_esc")))
return sb.String()
}
sb.WriteString(helpStyle.Render(T("oauth_select")))
sb.WriteString("\n\n")
for i, p := range oauthProviders {
isSelected := i == m.cursor
prefix := " "
if isSelected {
prefix = "▸ "
}
label := fmt.Sprintf("%s %s", p.emoji, p.name)
if isSelected {
label = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#FFFFFF")).Background(colorPrimary).Padding(0, 1).Render(label)
} else {
label = lipgloss.NewStyle().Foreground(colorText).Padding(0, 1).Render(label)
}
sb.WriteString(prefix + label + "\n")
}
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("oauth_help")))
return sb.String()
}
func (m oauthTabModel) renderRemoteMode() string {
var sb strings.Builder
providerStyle := lipgloss.NewStyle().Bold(true).Foreground(colorHighlight)
sb.WriteString(providerStyle.Render(fmt.Sprintf(" ✦ %s OAuth", m.providerName)))
sb.WriteString("\n\n")
// Auth URL section
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorInfo).Render(T("oauth_auth_url")))
sb.WriteString("\n")
// Wrap URL to fit terminal width
urlStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
maxURLWidth := m.width - 6
if maxURLWidth < 40 {
maxURLWidth = 40
}
wrappedURL := wrapText(m.authURL, maxURLWidth)
for _, line := range wrappedURL {
sb.WriteString(" " + urlStyle.Render(line) + "\n")
}
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("oauth_remote_hint")))
sb.WriteString("\n\n")
// Callback URL input
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorInfo).Render(T("oauth_callback_url")))
sb.WriteString("\n")
if m.inputActive {
sb.WriteString(m.callbackInput.View())
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(" " + T("enter_submit") + " • " + T("esc_cancel")))
} else {
sb.WriteString(helpStyle.Render(T("oauth_press_c")))
}
sb.WriteString("\n\n")
sb.WriteString(warningStyle.Render(T("oauth_waiting")))
return sb.String()
}
// wrapText splits a long string into lines of at most maxWidth characters.
func wrapText(s string, maxWidth int) []string {
if maxWidth <= 0 {
return []string{s}
}
var lines []string
for len(s) > maxWidth {
lines = append(lines, s[:maxWidth])
s = s[maxWidth:]
}
if len(s) > 0 {
lines = append(lines, s)
}
return lines
}

126
internal/tui/styles.go Normal file
View File

@@ -0,0 +1,126 @@
// Package tui provides a terminal-based management interface for CLIProxyAPI.
package tui
import "github.com/charmbracelet/lipgloss"
// Color palette
var (
colorPrimary = lipgloss.Color("#7C3AED") // violet
colorSecondary = lipgloss.Color("#6366F1") // indigo
colorSuccess = lipgloss.Color("#22C55E") // green
colorWarning = lipgloss.Color("#EAB308") // yellow
colorError = lipgloss.Color("#EF4444") // red
colorInfo = lipgloss.Color("#3B82F6") // blue
colorMuted = lipgloss.Color("#6B7280") // gray
colorBg = lipgloss.Color("#1E1E2E") // dark bg
colorSurface = lipgloss.Color("#313244") // slightly lighter
colorText = lipgloss.Color("#CDD6F4") // light text
colorSubtext = lipgloss.Color("#A6ADC8") // dimmer text
colorBorder = lipgloss.Color("#45475A") // border
colorHighlight = lipgloss.Color("#F5C2E7") // pink highlight
)
// Tab bar styles
var (
tabActiveStyle = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#FFFFFF")).
Background(colorPrimary).
Padding(0, 2)
tabInactiveStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Background(colorSurface).
Padding(0, 2)
tabBarStyle = lipgloss.NewStyle().
Background(colorSurface).
PaddingLeft(1).
PaddingBottom(0)
)
// Content styles
var (
titleStyle = lipgloss.NewStyle().
Bold(true).
Foreground(colorHighlight).
MarginBottom(1)
subtitleStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Italic(true)
labelStyle = lipgloss.NewStyle().
Foreground(colorInfo).
Bold(true).
Width(24)
valueStyle = lipgloss.NewStyle().
Foreground(colorText)
sectionStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(colorBorder).
Padding(1, 2)
errorStyle = lipgloss.NewStyle().
Foreground(colorError).
Bold(true)
successStyle = lipgloss.NewStyle().
Foreground(colorSuccess)
warningStyle = lipgloss.NewStyle().
Foreground(colorWarning)
statusBarStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Background(colorSurface).
PaddingLeft(1).
PaddingRight(1)
helpStyle = lipgloss.NewStyle().
Foreground(colorMuted)
)
// Log level styles
var (
logDebugStyle = lipgloss.NewStyle().Foreground(colorMuted)
logInfoStyle = lipgloss.NewStyle().Foreground(colorInfo)
logWarnStyle = lipgloss.NewStyle().Foreground(colorWarning)
logErrorStyle = lipgloss.NewStyle().Foreground(colorError)
)
// Table styles
var (
tableHeaderStyle = lipgloss.NewStyle().
Bold(true).
Foreground(colorHighlight).
BorderBottom(true).
BorderStyle(lipgloss.NormalBorder()).
BorderForeground(colorBorder)
tableCellStyle = lipgloss.NewStyle().
Foreground(colorText).
PaddingRight(2)
tableSelectedStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFFFFF")).
Background(colorPrimary).
Bold(true)
)
func logLevelStyle(level string) lipgloss.Style {
switch level {
case "debug":
return logDebugStyle
case "info":
return logInfoStyle
case "warn", "warning":
return logWarnStyle
case "error", "fatal", "panic":
return logErrorStyle
default:
return logInfoStyle
}
}

364
internal/tui/usage_tab.go Normal file
View File

@@ -0,0 +1,364 @@
package tui
import (
"fmt"
"sort"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// usageTabModel displays usage statistics with charts and breakdowns.
type usageTabModel struct {
client *Client
viewport viewport.Model
usage map[string]any
err error
width int
height int
ready bool
}
type usageDataMsg struct {
usage map[string]any
err error
}
func newUsageTabModel(client *Client) usageTabModel {
return usageTabModel{
client: client,
}
}
func (m usageTabModel) Init() tea.Cmd {
return m.fetchData
}
func (m usageTabModel) fetchData() tea.Msg {
usage, err := m.client.GetUsage()
return usageDataMsg{usage: usage, err: err}
}
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case usageDataMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.usage = msg.usage
}
m.viewport.SetContent(m.renderContent())
return m, nil
case tea.KeyMsg:
if msg.String() == "r" {
return m, m.fetchData
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *usageTabModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m usageTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m usageTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("usage_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("usage_help")))
sb.WriteString("\n\n")
if m.err != nil {
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
sb.WriteString("\n")
return sb.String()
}
if m.usage == nil {
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
sb.WriteString("\n")
return sb.String()
}
usageMap, _ := m.usage["usage"].(map[string]any)
if usageMap == nil {
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
sb.WriteString("\n")
return sb.String()
}
totalReqs := int64(getFloat(usageMap, "total_requests"))
successCnt := int64(getFloat(usageMap, "success_count"))
failureCnt := int64(getFloat(usageMap, "failure_count"))
totalTokens := int64(getFloat(usageMap, "total_tokens"))
// ━━━ Overview Cards ━━━
cardWidth := 20
if m.width > 0 {
cardWidth = (m.width - 6) / 4
if cardWidth < 16 {
cardWidth = 16
}
}
cardStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240")).
Padding(0, 1).
Width(cardWidth).
Height(3)
// Total Requests
card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
))
// Total Tokens
card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
))
// RPM
rpm := float64(0)
if totalReqs > 0 {
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
}
}
card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
))
// TPM
tpm := float64(0)
if totalTokens > 0 {
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
}
}
card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString("\n\n")
// ━━━ Requests by Hour (ASCII bar chart) ━━━
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
sb.WriteString("\n")
}
// ━━━ Tokens by Hour ━━━
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
sb.WriteString("\n")
}
// ━━━ Requests by Day ━━━
if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
sb.WriteString("\n")
}
// ━━━ API Detail Stats ━━━
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for apiName, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
apiReqs := int64(getFloat(apiMap, "total_requests"))
apiToks := int64(getFloat(apiMap, "total_tokens"))
row := fmt.Sprintf(" %-30s %10d %12s",
truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
sb.WriteString("\n")
// Per-model breakdown
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
mReqs := int64(getFloat(stats, "total_requests"))
mToks := int64(getFloat(stats, "total_tokens"))
mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
truncate(model, 28), mReqs, formatLargeNumber(mToks))
sb.WriteString(tableCellStyle.Render(mRow))
sb.WriteString("\n")
// Token type breakdown from details
sb.WriteString(m.renderTokenBreakdown(stats))
}
}
}
}
}
}
sb.WriteString("\n")
return sb.String()
}
// renderTokenBreakdown aggregates input/output/cached/reasoning tokens from model details.
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
details, ok := modelStats["details"]
if !ok {
return ""
}
detailList, ok := details.([]any)
if !ok || len(detailList) == 0 {
return ""
}
var inputTotal, outputTotal, cachedTotal, reasoningTotal int64
for _, d := range detailList {
dm, ok := d.(map[string]any)
if !ok {
continue
}
tokens, ok := dm["tokens"].(map[string]any)
if !ok {
continue
}
inputTotal += int64(getFloat(tokens, "input_tokens"))
outputTotal += int64(getFloat(tokens, "output_tokens"))
cachedTotal += int64(getFloat(tokens, "cached_tokens"))
reasoningTotal += int64(getFloat(tokens, "reasoning_tokens"))
}
if inputTotal == 0 && outputTotal == 0 && cachedTotal == 0 && reasoningTotal == 0 {
return ""
}
parts := []string{}
if inputTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_input"), formatLargeNumber(inputTotal)))
}
if outputTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_output"), formatLargeNumber(outputTotal)))
}
if cachedTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_cached"), formatLargeNumber(cachedTotal)))
}
if reasoningTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_reasoning"), formatLargeNumber(reasoningTotal)))
}
return fmt.Sprintf(" │ %s\n",
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
}
// renderBarChart renders a simple ASCII horizontal bar chart.
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
if maxBarWidth < 10 {
maxBarWidth = 10
}
// Sort keys
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
sort.Strings(keys)
// Find max value
maxVal := float64(0)
for _, k := range keys {
v := getFloat(data, k)
if v > maxVal {
maxVal = v
}
}
if maxVal == 0 {
return ""
}
barStyle := lipgloss.NewStyle().Foreground(barColor)
var sb strings.Builder
labelWidth := 12
barAvail := maxBarWidth - labelWidth - 12
if barAvail < 5 {
barAvail = 5
}
for _, k := range keys {
v := getFloat(data, k)
barLen := int(v / maxVal * float64(barAvail))
if barLen < 1 && v > 0 {
barLen = 1
}
bar := strings.Repeat("█", barLen)
label := k
if len(label) > labelWidth {
label = label[:labelWidth]
}
sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
labelWidth, label,
barStyle.Render(bar),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%.0f", v)),
))
}
return sb.String()
}

View File

@@ -184,6 +184,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
}
if o.Websockets != n.Websockets {
changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets))
}
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
}

View File

@@ -160,6 +160,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL
}
if ck.Websockets {
attrs["websockets"] = "true"
}
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
attrs["models_hash"] = hash
}

View File

@@ -231,10 +231,11 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
Config: &config.Config{
CodexKey: []config.CodexKey{
{
APIKey: "codex-key-123",
Prefix: "dev",
BaseURL: "https://api.openai.com",
ProxyURL: "http://proxy.local",
APIKey: "codex-key-123",
Prefix: "dev",
BaseURL: "https://api.openai.com",
ProxyURL: "http://proxy.local",
Websockets: true,
},
},
},
@@ -259,6 +260,9 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
if auths[0].ProxyURL != "http://proxy.local" {
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
}
if auths[0].Attributes["websockets"] != "true" {
t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"])
}
}
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {

View File

@@ -52,6 +52,45 @@ const (
defaultStreamingBootstrapRetries = 0
)
type pinnedAuthContextKey struct{}
type selectedAuthCallbackContextKey struct{}
type executionSessionContextKey struct{}
// WithPinnedAuthID returns a child context that requests execution on a specific auth ID.
func WithPinnedAuthID(ctx context.Context, authID string) context.Context {
authID = strings.TrimSpace(authID)
if authID == "" {
return ctx
}
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, pinnedAuthContextKey{}, authID)
}
// WithSelectedAuthIDCallback returns a child context that receives the selected auth ID.
func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context {
if callback == nil {
return ctx
}
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback)
}
// WithExecutionSessionID returns a child context tagged with a long-lived execution session ID.
func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context {
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
return ctx
}
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, executionSessionContextKey{}, sessionID)
}
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
func BuildErrorResponseBody(status int, errText string) []byte {
@@ -152,7 +191,59 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
if key == "" {
key = uuid.NewString()
}
return map[string]any{idempotencyKeyMetadataKey: key}
meta := map[string]any{idempotencyKeyMetadataKey: key}
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
}
if selectedCallback := selectedAuthIDCallbackFromContext(ctx); selectedCallback != nil {
meta[coreexecutor.SelectedAuthCallbackMetadataKey] = selectedCallback
}
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
}
return meta
}
func pinnedAuthIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
raw := ctx.Value(pinnedAuthContextKey{})
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case []byte:
return strings.TrimSpace(string(v))
default:
return ""
}
}
func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) {
if ctx == nil {
return nil
}
raw := ctx.Value(selectedAuthCallbackContextKey{})
if callback, ok := raw.(func(string)); ok && callback != nil {
return callback
}
return nil
}
func executionSessionIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
raw := ctx.Value(executionSessionContextKey{})
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case []byte:
return strings.TrimSpace(string(v))
default:
return ""
}
}
// BaseAPIHandler contains the handlers for API endpoints.

View File

@@ -122,6 +122,82 @@ func (e *payloadThenErrorStreamExecutor) Calls() int {
return e.calls
}
type authAwareStreamExecutor struct {
mu sync.Mutex
calls int
authIDs []string
}
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
_ = ctx
_ = req
_ = opts
ch := make(chan coreexecutor.StreamChunk, 1)
authID := ""
if auth != nil {
authID = auth.ID
}
e.mu.Lock()
e.calls++
e.authIDs = append(e.authIDs, authID)
e.mu.Unlock()
if authID == "auth1" {
ch <- coreexecutor.StreamChunk{
Err: &coreauth.Error{
Code: "unauthorized",
Message: "unauthorized",
Retryable: false,
HTTPStatus: http.StatusUnauthorized,
},
}
close(ch)
return ch, nil
}
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch)
return ch, nil
}
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *authAwareStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *authAwareStreamExecutor) Calls() int {
e.mu.Lock()
defer e.mu.Unlock()
return e.calls
}
func (e *authAwareStreamExecutor) AuthIDs() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.authIDs))
copy(out, e.authIDs)
return out
}
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
@@ -252,3 +328,128 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
}
}
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
executor := &authAwareStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
ctx := WithPinnedAuthID(context.Background(), "auth1")
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
var gotErr error
for msg := range errChan {
if msg != nil && msg.Error != nil {
gotErr = msg.Error
}
}
if len(got) != 0 {
t.Fatalf("expected empty payload, got %q", string(got))
}
if gotErr == nil {
t.Fatalf("expected terminal error, got nil")
}
authIDs := executor.AuthIDs()
if len(authIDs) == 0 {
t.Fatalf("expected at least one upstream attempt")
}
for _, authID := range authIDs {
if authID != "auth1" {
t.Fatalf("expected all attempts on auth1, got sequence %v", authIDs)
}
}
}
func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) {
executor := &authAwareStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 0,
},
}, manager)
selectedAuthID := ""
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
selectedAuthID = authID
})
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
}
}
if string(got) != "ok" {
t.Fatalf("expected payload ok, got %q", string(got))
}
if selectedAuthID != "auth2" {
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
}
}

View File

@@ -332,6 +332,7 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
// Check if this chunk has any meaningful content
hasContent := false
hasUsage := root.Get("usage").Exists()
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
// Check if delta has content or finish_reason
@@ -350,8 +351,8 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
})
}
// If no meaningful content, return nil to indicate this chunk should be skipped
if !hasContent {
// If no meaningful content and no usage, return nil to indicate this chunk should be skipped
if !hasContent && !hasUsage {
return nil
}
@@ -410,6 +411,11 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
}
// Copy usage if present
if usage := root.Get("usage"); usage.Exists() {
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
}
return []byte(out)
}

View File

@@ -0,0 +1,662 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
wsRequestTypeCreate = "response.create"
wsRequestTypeAppend = "response.append"
wsEventTypeError = "error"
wsEventTypeCompleted = "response.completed"
wsEventTypeDone = "response.done"
wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsPayloadLogMaxSize = 2048
)
var responsesWebsocketUpgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// ResponsesWebsocket handles websocket requests for /v1/responses.
// It accepts `response.create` and `response.append` requests and streams
// response events back as JSON websocket text messages.
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
if err != nil {
return
}
passthroughSessionID := uuid.NewString()
clientRemoteAddr := ""
if c != nil && c.Request != nil {
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
}
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
var wsTerminateErr error
var wsBodyLog strings.Builder
defer func() {
if wsTerminateErr != nil {
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
} else {
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
}
if h != nil && h.AuthManager != nil {
h.AuthManager.CloseExecutionSession(passthroughSessionID)
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
}
setWebsocketRequestBody(c, wsBodyLog.String())
if errClose := conn.Close(); errClose != nil {
log.Warnf("responses websocket: close connection error: %v", errClose)
}
}()
var lastRequest []byte
lastResponseOutput := []byte("[]")
pinnedAuthID := ""
for {
msgType, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
wsTerminateErr = errReadMessage
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
} else {
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
}
return
}
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
continue
}
// log.Infof(
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
// passthroughSessionID,
// msgType,
// websocketPayloadEventType(payload),
// websocketPayloadPreview(payload),
// )
appendWebsocketEvent(&wsBodyLog, "request", payload)
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
}
var requestJSON []byte
var updatedLastRequest []byte
var errMsg *interfaces.ErrorMessage
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
payload,
lastRequest,
lastResponseOutput,
allowIncrementalInputWithPreviousResponseID,
)
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
passthroughSessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
passthroughSessionID,
websocketPayloadEventType(errorPayload),
errWrite,
)
return
}
continue
}
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
if pinnedAuthID != "" {
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
} else {
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
pinnedAuthID = strings.TrimSpace(authID)
})
}
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
lastResponseOutput = completedOutput
}
}
func websocketUpgradeHeaders(req *http.Request) http.Header {
headers := http.Header{}
if req == nil {
return headers
}
// Keep the same sticky turn-state across reconnects when provided by the client.
turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader))
if turnState != "" {
headers.Set(wsTurnStateHeader, turnState)
}
return headers
}
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
}
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
switch requestType {
case wsRequestTypeCreate:
// log.Infof("responses websocket: response.create request")
if len(lastRequest) == 0 {
return normalizeResponseCreateRequest(rawJSON)
}
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
case wsRequestTypeAppend:
// log.Infof("responses websocket: response.append request")
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
default:
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
}
}
}
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
normalized = bytes.Clone(rawJSON)
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
if !gjson.GetBytes(normalized, "input").Exists() {
normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]"))
}
modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String())
if modelName == "" {
return nil, nil, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("missing model in response.create request"),
}
}
return normalized, bytes.Clone(normalized), nil
}
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
if len(lastRequest) == 0 {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("websocket request received before response.create"),
}
}
nextInput := gjson.GetBytes(rawJSON, "input")
if !nextInput.Exists() || !nextInput.IsArray() {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("websocket request requires array field: input"),
}
}
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
// Do not expand it into a full input transcript; upstream expects the incremental payload.
if allowIncrementalInputWithPreviousResponseID {
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
normalized = bytes.Clone(rawJSON)
}
if !gjson.GetBytes(normalized, "model").Exists() {
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
if modelName != "" {
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
}
}
if !gjson.GetBytes(normalized, "instructions").Exists() {
instructions := gjson.GetBytes(lastRequest, "instructions")
if instructions.Exists() {
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
}
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
return normalized, bytes.Clone(normalized), nil
}
}
existingInput := gjson.GetBytes(lastRequest, "input")
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
}
}
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid request input: %w", errMerge),
}
}
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
normalized = bytes.Clone(rawJSON)
}
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
var errSet error
normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
if errSet != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("failed to merge websocket input: %w", errSet),
}
}
if !gjson.GetBytes(normalized, "model").Exists() {
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
if modelName != "" {
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
}
}
if !gjson.GetBytes(normalized, "instructions").Exists() {
instructions := gjson.GetBytes(lastRequest, "instructions")
if instructions.Exists() {
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
}
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
return normalized, bytes.Clone(normalized), nil
}
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
if len(attributes) > 0 {
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
parsed, errParse := strconv.ParseBool(raw)
if errParse == nil {
return parsed
}
}
}
if len(metadata) == 0 {
return false
}
raw, ok := metadata["websockets"]
if !ok || raw == nil {
return false
}
switch value := raw.(type) {
case bool:
return value
case string:
parsed, errParse := strconv.ParseBool(strings.TrimSpace(value))
if errParse == nil {
return parsed
}
default:
}
return false
}
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
existingRaw = strings.TrimSpace(existingRaw)
appendRaw = strings.TrimSpace(appendRaw)
if existingRaw == "" {
existingRaw = "[]"
}
if appendRaw == "" {
appendRaw = "[]"
}
var existing []json.RawMessage
if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil {
return "", err
}
var appendItems []json.RawMessage
if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil {
return "", err
}
merged := append(existing, appendItems...)
out, err := json.Marshal(merged)
if err != nil {
return "", err
}
return string(out), nil
}
func normalizeJSONArrayRaw(raw []byte) string {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" {
return "[]"
}
result := gjson.Parse(trimmed)
if result.Type == gjson.JSON && result.IsArray() {
return trimmed
}
return "[]"
}
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
c *gin.Context,
conn *websocket.Conn,
cancel handlers.APIHandlerCancelFunc,
data <-chan []byte,
errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder,
sessionID string,
) ([]byte, error) {
completed := false
completedOutput := []byte("[]")
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
case errMsg, ok := <-errs:
if !ok {
errs = nil
continue
}
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
// log.Warnf(
// "responses websocket: downstream_out write failed id=%s event=%s error=%v",
// sessionID,
// websocketPayloadEventType(errorPayload),
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
}
}
if errMsg != nil {
cancel(errMsg.Error)
} else {
cancel(nil)
}
return completedOutput, nil
case chunk, ok := <-data:
if !ok {
if !completed {
errMsg := &interfaces.ErrorMessage{
StatusCode: http.StatusRequestTimeout,
Error: fmt.Errorf("stream closed before response.completed"),
}
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(errorPayload),
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
}
cancel(errMsg.Error)
return completedOutput, nil
}
cancel(nil)
return completedOutput, nil
}
payloads := websocketJSONPayloadsFromChunk(chunk)
for i := range payloads {
eventType := gjson.GetBytes(payloads[i], "type").String()
if eventType == wsEventTypeCompleted {
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
completed = true
completedOutput = responseCompletedOutputFromPayload(payloads[i])
}
markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID,
// websocket.TextMessage,
// websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]),
// )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(payloads[i]),
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
}
}
}
}
}
func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {
return bytes.Clone([]byte(output.Raw))
}
return []byte("[]")
}
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
payloads := make([][]byte, 0, 2)
lines := bytes.Split(chunk, []byte("\n"))
for i := range lines {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
continue
}
if bytes.HasPrefix(line, []byte("data:")) {
line = bytes.TrimSpace(line[len("data:"):])
}
if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) {
continue
}
if json.Valid(line) {
payloads = append(payloads, bytes.Clone(line))
}
}
if len(payloads) > 0 {
return payloads
}
trimmed := bytes.TrimSpace(chunk)
if bytes.HasPrefix(trimmed, []byte("data:")) {
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
}
if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) {
payloads = append(payloads, bytes.Clone(trimmed))
}
return payloads
}
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
status := http.StatusInternalServerError
errText := http.StatusText(status)
if errMsg != nil {
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
errText = http.StatusText(status)
}
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
errText = errMsg.Error.Error()
}
}
body := handlers.BuildErrorResponseBody(status, errText)
payload := map[string]any{
"type": wsEventTypeError,
"status": status,
}
if errMsg != nil && errMsg.Addon != nil {
headers := map[string]any{}
for key, values := range errMsg.Addon {
if len(values) == 0 {
continue
}
headers[key] = values[0]
}
if len(headers) > 0 {
payload["headers"] = headers
}
}
if len(body) > 0 && json.Valid(body) {
var decoded map[string]any
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
if inner, ok := decoded["error"]; ok {
payload["error"] = inner
} else {
payload["error"] = decoded
}
}
}
if _, ok := payload["error"]; !ok {
payload["error"] = map[string]any{
"type": "server_error",
"message": errText,
}
}
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
return data, conn.WriteMessage(websocket.TextMessage, data)
}
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
if builder == nil {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
builder.WriteString("\n")
}
builder.WriteString("websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
}
func websocketPayloadEventType(payload []byte) string {
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
if eventType == "" {
return "-"
}
return eventType
}
func websocketPayloadPreview(payload []byte) string {
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return "<empty>"
}
preview := trimmedPayload
if len(preview) > wsPayloadLogMaxSize {
preview = preview[:wsPayloadLogMaxSize]
}
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
if len(trimmedPayload) > wsPayloadLogMaxSize {
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
}
return previewText
}
func setWebsocketRequestBody(c *gin.Context, body string) {
if c == nil {
return
}
trimmedBody := strings.TrimSpace(body)
if trimmedBody == "" {
return
}
c.Set(wsRequestBodyKey, []byte(trimmedBody))
}
func markAPIResponseTimestamp(c *gin.Context) {
if c == nil {
return
}
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists {
return
}
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
}

View File

@@ -0,0 +1,249 @@
package openai
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
normalized, last, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "type").Exists() {
t.Fatalf("normalized create request must not include type field")
}
if !gjson.GetBytes(normalized, "stream").Bool() {
t.Fatalf("normalized create request must force stream=true")
}
if gjson.GetBytes(normalized, "model").String() != "test-model" {
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
}
if !bytes.Equal(last, normalized) {
t.Fatalf("last request snapshot should match normalized request")
}
}
func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "type").Exists() {
t.Fatalf("normalized subsequent create request must not include type field")
}
if gjson.GetBytes(normalized, "model").String() != "test-model" {
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 4 {
t.Fatalf("merged input len = %d, want 4", len(input))
}
if input[0].Get("id").String() != "msg-1" ||
input[1].Get("id").String() != "fc-1" ||
input[2].Get("id").String() != "assistant-1" ||
input[3].Get("id").String() != "tool-out-1" {
t.Fatalf("unexpected merged input order")
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match normalized request")
}
}
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "type").Exists() {
t.Fatalf("normalized request must not include type field")
}
if gjson.GetBytes(normalized, "previous_response_id").String() != "resp-1" {
t.Fatalf("previous_response_id must be preserved in incremental mode")
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 1 {
t.Fatalf("incremental input len = %d, want 1", len(input))
}
if input[0].Get("id").String() != "tool-out-1" {
t.Fatalf("unexpected incremental input item id: %s", input[0].Get("id").String())
}
if gjson.GetBytes(normalized, "model").String() != "test-model" {
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
}
if gjson.GetBytes(normalized, "instructions").String() != "be helpful" {
t.Fatalf("unexpected instructions: %s", gjson.GetBytes(normalized, "instructions").String())
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match normalized request")
}
}
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
t.Fatalf("previous_response_id must be removed when incremental mode is disabled")
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 4 {
t.Fatalf("merged input len = %d, want 4", len(input))
}
if input[0].Get("id").String() != "msg-1" ||
input[1].Get("id").String() != "fc-1" ||
input[2].Get("id").String() != "assistant-1" ||
input[3].Get("id").String() != "tool-out-1" {
t.Fatalf("unexpected merged input order")
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match normalized request")
}
}
func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"message","id":"assistant-1"},
{"type":"function_call_output","id":"tool-out-1"}
]`)
raw := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 5 {
t.Fatalf("merged input len = %d, want 5", len(input))
}
if input[0].Get("id").String() != "msg-1" ||
input[1].Get("id").String() != "assistant-1" ||
input[2].Get("id").String() != "tool-out-1" ||
input[3].Get("id").String() != "msg-2" ||
input[4].Get("id").String() != "msg-3" {
t.Fatalf("unexpected merged input order")
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match normalized append request")
}
}
func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) {
raw := []byte(`{"type":"response.append","input":[]}`)
_, _, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
if errMsg == nil {
t.Fatalf("expected error for append without previous request")
}
if errMsg.StatusCode != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusBadRequest)
}
}
func TestWebsocketJSONPayloadsFromChunk(t *testing.T) {
chunk := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n")
payloads := websocketJSONPayloadsFromChunk(chunk)
if len(payloads) != 1 {
t.Fatalf("payloads len = %d, want 1", len(payloads))
}
if gjson.GetBytes(payloads[0], "type").String() != "response.created" {
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
}
}
func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) {
chunk := []byte(`{"type":"response.completed","response":{"id":"resp-1"}}`)
payloads := websocketJSONPayloadsFromChunk(chunk)
if len(payloads) != 1 {
t.Fatalf("payloads len = %d, want 1", len(payloads))
}
if gjson.GetBytes(payloads[0], "type").String() != "response.completed" {
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
}
}
func TestResponseCompletedOutputFromPayload(t *testing.T) {
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`)
output := responseCompletedOutputFromPayload(payload)
items := gjson.ParseBytes(output).Array()
if len(items) != 1 {
t.Fatalf("output len = %d, want 1", len(items))
}
if items[0].Get("id").String() != "out-1" {
t.Fatalf("unexpected output id: %s", items[0].Get("id").String())
}
}
func TestAppendWebsocketEvent(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"))
appendWebsocketEvent(&builder, "response", []byte("{\"type\":\"response.created\"}"))
got := builder.String()
if !strings.Contains(got, "websocket.request\n{\"type\":\"response.create\"}\n") {
t.Fatalf("request event not found in body: %s", got)
}
if !strings.Contains(got, "websocket.response\n{\"type\":\"response.created\"}\n") {
t.Fatalf("response event not found in body: %s", got)
}
}
func TestSetWebsocketRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
setWebsocketRequestBody(c, " \n ")
if _, exists := c.Get(wsRequestBodyKey); exists {
t.Fatalf("request body key should not be set for empty body")
}
setWebsocketRequestBody(c, "event body")
value, exists := c.Get(wsRequestBodyKey)
if !exists {
t.Fatalf("request body key not set")
}
bodyBytes, ok := value.([]byte)
if !ok {
t.Fatalf("request body key type mismatch")
}
if string(bodyBytes) != "event body" {
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
}
}

View File

@@ -41,6 +41,17 @@ type ProviderExecutor interface {
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
}
// ExecutionSessionCloser allows executors to release per-session runtime resources.
type ExecutionSessionCloser interface {
CloseExecutionSession(sessionID string)
}
const (
// CloseAllExecutionSessionsID asks an executor to release all active execution sessions.
// Executors that do not support this marker may ignore it.
CloseAllExecutionSessionsID = "__all_execution_sessions__"
)
// RefreshEvaluator allows runtime state to override refresh decisions.
type RefreshEvaluator interface {
ShouldRefresh(now time.Time, auth *Auth) bool
@@ -389,9 +400,23 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
if executor == nil {
return
}
provider := strings.TrimSpace(executor.Identifier())
if provider == "" {
return
}
var replaced ProviderExecutor
m.mu.Lock()
defer m.mu.Unlock()
m.executors[executor.Identifier()] = executor
replaced = m.executors[provider]
m.executors[provider] = executor
m.mu.Unlock()
if replaced == nil || replaced == executor {
return
}
if closer, ok := replaced.(ExecutionSessionCloser); ok && closer != nil {
closer.CloseExecutionSession(CloseAllExecutionSessionsID)
}
}
// UnregisterExecutor removes the executor associated with the provider key.
@@ -581,6 +606,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -636,6 +662,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -691,6 +718,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -794,6 +822,38 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
}
}
func pinnedAuthIDFromMetadata(meta map[string]any) string {
if len(meta) == 0 {
return ""
}
raw, ok := meta[cliproxyexecutor.PinnedAuthMetadataKey]
if !ok || raw == nil {
return ""
}
switch val := raw.(type) {
case string:
return strings.TrimSpace(val)
case []byte:
return strings.TrimSpace(string(val))
default:
return ""
}
}
func publishSelectedAuthMetadata(meta map[string]any, authID string) {
if len(meta) == 0 {
return
}
authID = strings.TrimSpace(authID)
if authID == "" {
return
}
meta[cliproxyexecutor.SelectedAuthMetadataKey] = authID
if callback, ok := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string)); ok && callback != nil {
callback(authID)
}
}
func rewriteModelForAuth(model string, auth *Auth) string {
if auth == nil || model == "" {
return model
@@ -1550,7 +1610,56 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
return auth.Clone(), true
}
// Executor returns the registered provider executor for a provider key.
func (m *Manager) Executor(provider string) (ProviderExecutor, bool) {
if m == nil {
return nil, false
}
provider = strings.TrimSpace(provider)
if provider == "" {
return nil, false
}
m.mu.RLock()
executor, okExecutor := m.executors[provider]
if !okExecutor {
lowerProvider := strings.ToLower(provider)
if lowerProvider != provider {
executor, okExecutor = m.executors[lowerProvider]
}
}
m.mu.RUnlock()
if !okExecutor || executor == nil {
return nil, false
}
return executor, true
}
// CloseExecutionSession asks all registered executors to release the supplied execution session.
func (m *Manager) CloseExecutionSession(sessionID string) {
sessionID = strings.TrimSpace(sessionID)
if m == nil || sessionID == "" {
return
}
m.mu.RLock()
executors := make([]ProviderExecutor, 0, len(m.executors))
for _, exec := range m.executors {
executors = append(executors, exec)
}
m.mu.RUnlock()
for i := range executors {
if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil {
closer.CloseExecutionSession(sessionID)
}
}
}
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
m.mu.RLock()
executor, okExecutor := m.executors[provider]
if !okExecutor {
@@ -1571,6 +1680,9 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
if candidate.Provider != provider || candidate.Disabled {
continue
}
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
@@ -1606,6 +1718,8 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
providerSet := make(map[string]struct{}, len(providers))
for _, provider := range providers {
p := strings.TrimSpace(strings.ToLower(provider))
@@ -1633,6 +1747,9 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
if candidate == nil || candidate.Disabled {
continue
}
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
if providerKey == "" {
continue

View File

@@ -0,0 +1,100 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type replaceAwareExecutor struct {
id string
mu sync.Mutex
closedSessionIDs []string
}
func (e *replaceAwareExecutor) Identifier() string {
return e.id
}
func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
ch := make(chan cliproxyexecutor.StreamChunk)
close(ch)
return ch, nil
}
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e *replaceAwareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e *replaceAwareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
return nil, nil
}
func (e *replaceAwareExecutor) CloseExecutionSession(sessionID string) {
e.mu.Lock()
defer e.mu.Unlock()
e.closedSessionIDs = append(e.closedSessionIDs, sessionID)
}
func (e *replaceAwareExecutor) ClosedSessionIDs() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.closedSessionIDs))
copy(out, e.closedSessionIDs)
return out
}
func TestManagerRegisterExecutorClosesReplacedExecutionSessions(t *testing.T) {
t.Parallel()
manager := NewManager(nil, nil, nil)
replaced := &replaceAwareExecutor{id: "codex"}
current := &replaceAwareExecutor{id: "codex"}
manager.RegisterExecutor(replaced)
manager.RegisterExecutor(current)
closed := replaced.ClosedSessionIDs()
if len(closed) != 1 {
t.Fatalf("expected replaced executor close calls = 1, got %d", len(closed))
}
if closed[0] != CloseAllExecutionSessionsID {
t.Fatalf("expected close marker %q, got %q", CloseAllExecutionSessionsID, closed[0])
}
if len(current.ClosedSessionIDs()) != 0 {
t.Fatalf("expected current executor to stay open")
}
}
func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
t.Parallel()
manager := NewManager(nil, nil, nil)
current := &replaceAwareExecutor{id: "codex"}
manager.RegisterExecutor(current)
resolved, okResolved := manager.Executor("CODEX")
if !okResolved {
t.Fatal("expected registered executor to be found")
}
if resolved != current {
t.Fatal("expected resolved executor to match registered executor")
}
_, okMissing := manager.Executor("unknown")
if okMissing {
t.Fatal("expected unknown provider lookup to fail")
}
}

View File

@@ -134,6 +134,62 @@ func canonicalModelKey(model string) string {
return modelName
}
func authWebsocketsEnabled(auth *Auth) bool {
if auth == nil {
return false
}
if len(auth.Attributes) > 0 {
if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
parsed, errParse := strconv.ParseBool(raw)
if errParse == nil {
return parsed
}
}
}
if len(auth.Metadata) == 0 {
return false
}
raw, ok := auth.Metadata["websockets"]
if !ok || raw == nil {
return false
}
switch v := raw.(type) {
case bool:
return v
case string:
parsed, errParse := strconv.ParseBool(strings.TrimSpace(v))
if errParse == nil {
return parsed
}
default:
}
return false
}
func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth {
if len(available) == 0 {
return available
}
if !cliproxyexecutor.DownstreamWebsocket(ctx) {
return available
}
if !strings.EqualFold(strings.TrimSpace(provider), "codex") {
return available
}
wsEnabled := make([]*Auth, 0, len(available))
for i := 0; i < len(available); i++ {
candidate := available[i]
if authWebsocketsEnabled(candidate) {
wsEnabled = append(wsEnabled, candidate)
}
}
if len(wsEnabled) > 0 {
return wsEnabled
}
return available
}
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
available = make(map[int][]*Auth)
for i := 0; i < len(auths); i++ {
@@ -193,13 +249,13 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
// Pick selects the next available auth for the provider in a round-robin manner.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx
_ = opts
now := time.Now()
available, err := getAvailableAuths(auths, provider, model, now)
if err != nil {
return nil, err
}
available = preferCodexWebsocketAuths(ctx, provider, available)
key := provider + ":" + canonicalModelKey(model)
s.mu.Lock()
if s.cursors == nil {
@@ -226,13 +282,13 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
// Pick selects the first available auth for the provider in a deterministic manner.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx
_ = opts
now := time.Now()
available, err := getAvailableAuths(auths, provider, model, now)
if err != nil {
return nil, err
}
available = preferCodexWebsocketAuths(ctx, provider, available)
return available[0], nil
}

View File

@@ -213,6 +213,23 @@ func (a *Auth) DisableCoolingOverride() (bool, bool) {
return false, false
}
// ToolPrefixDisabled returns whether the proxy_ tool name prefix should be
// skipped for this auth. When true, tool names are sent to Anthropic unchanged.
// The value is read from metadata key "tool_prefix_disabled" (or "tool-prefix-disabled").
func (a *Auth) ToolPrefixDisabled() bool {
if a == nil || a.Metadata == nil {
return false
}
for _, key := range []string{"tool_prefix_disabled", "tool-prefix-disabled"} {
if val, ok := a.Metadata[key]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed
}
}
}
return false
}
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
// The value is read from metadata key "request_retry" (or legacy "request-retry").
func (a *Auth) RequestRetryOverride() (int, bool) {

View File

@@ -0,0 +1,35 @@
package auth
import "testing"
func TestToolPrefixDisabled(t *testing.T) {
var a *Auth
if a.ToolPrefixDisabled() {
t.Error("nil auth should return false")
}
a = &Auth{}
if a.ToolPrefixDisabled() {
t.Error("empty auth should return false")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": true}}
if !a.ToolPrefixDisabled() {
t.Error("should return true when set to true")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": "true"}}
if !a.ToolPrefixDisabled() {
t.Error("should return true when set to string 'true'")
}
a = &Auth{Metadata: map[string]any{"tool-prefix-disabled": true}}
if !a.ToolPrefixDisabled() {
t.Error("should return true with kebab-case key")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": false}}
if a.ToolPrefixDisabled() {
t.Error("should return false when set to false")
}
}

View File

@@ -0,0 +1,23 @@
package executor
import "context"
type downstreamWebsocketContextKey struct{}
// WithDownstreamWebsocket marks the current request as coming from a downstream websocket connection.
func WithDownstreamWebsocket(ctx context.Context) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, downstreamWebsocketContextKey{}, true)
}
// DownstreamWebsocket reports whether the current request originates from a downstream websocket connection.
func DownstreamWebsocket(ctx context.Context) bool {
if ctx == nil {
return false
}
raw := ctx.Value(downstreamWebsocketContextKey{})
enabled, ok := raw.(bool)
return ok && enabled
}

View File

@@ -10,6 +10,17 @@ import (
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
const RequestedModelMetadataKey = "requested_model"
const (
// PinnedAuthMetadataKey locks execution to a specific auth ID.
PinnedAuthMetadataKey = "pinned_auth_id"
// SelectedAuthMetadataKey stores the auth ID selected by the scheduler.
SelectedAuthMetadataKey = "selected_auth_id"
// SelectedAuthCallbackMetadataKey carries an optional callback invoked with the selected auth ID.
SelectedAuthCallbackMetadataKey = "selected_auth_callback"
// ExecutionSessionMetadataKey identifies a long-lived downstream execution session.
ExecutionSessionMetadataKey = "execution_session_id"
)
// Request encapsulates the translated payload that will be sent to a provider executor.
type Request struct {
// Model is the upstream model identifier after translation.

View File

@@ -325,6 +325,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
if _, err := s.coreManager.Update(ctx, existing); err != nil {
log.Errorf("failed to disable auth %s: %v", id, err)
}
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
s.ensureExecutorsForAuth(existing)
}
}
}
@@ -357,7 +360,24 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
}
func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
if s == nil || a == nil {
s.ensureExecutorsForAuthWithMode(a, false)
}
func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace bool) {
if s == nil || s.coreManager == nil || a == nil {
return
}
if strings.EqualFold(strings.TrimSpace(a.Provider), "codex") {
if !forceReplace {
existingExecutor, hasExecutor := s.coreManager.Executor("codex")
if hasExecutor {
_, isCodexAutoExecutor := existingExecutor.(*executor.CodexAutoExecutor)
if isCodexAutoExecutor {
return
}
}
}
s.coreManager.RegisterExecutor(executor.NewCodexAutoExecutor(s.cfg))
return
}
// Skip disabled auth entries when (re)binding executors.
@@ -392,8 +412,6 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
case "claude":
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "codex":
s.coreManager.RegisterExecutor(executor.NewCodexExecutor(s.cfg))
case "qwen":
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
case "iflow":
@@ -415,8 +433,15 @@ func (s *Service) rebindExecutors() {
return
}
auths := s.coreManager.List()
reboundCodex := false
for _, auth := range auths {
s.ensureExecutorsForAuth(auth)
if auth != nil && strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
if reboundCodex {
continue
}
reboundCodex = true
}
s.ensureExecutorsForAuthWithMode(auth, true)
}
}

View File

@@ -0,0 +1,64 @@
package cliproxy
import (
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode(t *testing.T) {
service := &Service{
cfg: &config.Config{},
coreManager: coreauth.NewManager(nil, nil, nil),
}
auth := &coreauth.Auth{
ID: "codex-auth-1",
Provider: "codex",
Status: coreauth.StatusActive,
}
service.ensureExecutorsForAuth(auth)
firstExecutor, okFirst := service.coreManager.Executor("codex")
if !okFirst || firstExecutor == nil {
t.Fatal("expected codex executor after first bind")
}
service.ensureExecutorsForAuth(auth)
secondExecutor, okSecond := service.coreManager.Executor("codex")
if !okSecond || secondExecutor == nil {
t.Fatal("expected codex executor after second bind")
}
if firstExecutor != secondExecutor {
t.Fatal("expected codex executor to stay unchanged in normal mode")
}
}
func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
service := &Service{
cfg: &config.Config{},
coreManager: coreauth.NewManager(nil, nil, nil),
}
auth := &coreauth.Auth{
ID: "codex-auth-2",
Provider: "codex",
Status: coreauth.StatusActive,
}
service.ensureExecutorsForAuth(auth)
firstExecutor, okFirst := service.coreManager.Executor("codex")
if !okFirst || firstExecutor == nil {
t.Fatal("expected codex executor after first bind")
}
service.ensureExecutorsForAuthWithMode(auth, true)
secondExecutor, okSecond := service.coreManager.Executor("codex")
if !okSecond || secondExecutor == nil {
t.Fatal("expected codex executor after forced rebind")
}
if firstExecutor == secondExecutor {
t.Fatal("expected codex executor replacement in force mode")
}
}