mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 08:55:06 +08:00
Compare commits
61 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c040445af | ||
|
|
fff866424e | ||
|
|
2d12becfd6 | ||
|
|
252f7e0751 | ||
|
|
b2b17528cb | ||
|
|
55f938164b | ||
|
|
76294f0c59 | ||
|
|
2bcee78c6e | ||
|
|
7fe8246a9f | ||
|
|
93fe58e31e | ||
|
|
e5b5dc870f | ||
|
|
a54877c023 | ||
|
|
bb86a0c0c4 | ||
|
|
5fa23c7f41 | ||
|
|
73dc0b10b8 | ||
|
|
2ea95266e3 | ||
|
|
9261b0c20b | ||
|
|
7cc725496e | ||
|
|
709d999f9f | ||
|
|
24c18614f0 | ||
|
|
603f06a762 | ||
|
|
98f0a3e3bd | ||
|
|
453aaf8774 | ||
|
|
1b1ab1fb9b | ||
|
|
a9d0bb72da | ||
|
|
2c8821891c | ||
|
|
0a2555b0f3 | ||
|
|
020df41efe | ||
|
|
f31f7f701a | ||
|
|
b5fe78eb70 | ||
|
|
d1f667cf8d | ||
|
|
54ad7c1b6b | ||
|
|
55789df275 | ||
|
|
46a6782065 | ||
|
|
c359f61859 | ||
|
|
908c8eab5b | ||
|
|
f5f2c69233 | ||
|
|
63d4de5eea | ||
|
|
ae1e8a5191 | ||
|
|
b3ccc55f09 | ||
|
|
1ce56d7413 | ||
|
|
41a78be3a2 | ||
|
|
1ff5de9a31 | ||
|
|
46a6853046 | ||
|
|
4b2d40bd67 | ||
|
|
575881cb59 | ||
|
|
f361b2716d | ||
|
|
58e09f8e5f | ||
|
|
a146c6c0aa | ||
|
|
4c133d3ea9 | ||
|
|
f3ccd85ba1 | ||
|
|
dc279de443 | ||
|
|
bf1634bda0 | ||
|
|
166d2d24d9 | ||
|
|
4cbcc835d1 | ||
|
|
b93026d83a | ||
|
|
5ed2133ff9 | ||
|
|
bb9fe52f1e | ||
|
|
afe4c1bfb7 | ||
|
|
865af9f19e | ||
|
|
2b97cb98b5 |
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '>=1.24.0'
|
||||
go-version: '>=1.26.0'
|
||||
cache: true
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.24-alpine AS builder
|
||||
FROM golang:1.26-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -146,6 +146,10 @@ A Windows tray application implemented using PowerShell scripts, without relying
|
||||
|
||||
霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux systems. Unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with local proxy for multi-account quota tracking and one-click configuration.
|
||||
|
||||
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
|
||||
|
||||
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
@@ -145,6 +145,10 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
|
||||
|
||||
霖君是一款用于管理AI编程助手的跨平台桌面应用,支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具,本地代理实现多账户配额跟踪和一键配置。
|
||||
|
||||
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
|
||||
|
||||
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -68,6 +70,8 @@ func main() {
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
var password string
|
||||
var tuiMode bool
|
||||
var standalone bool
|
||||
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
@@ -84,6 +88,8 @@ func main() {
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
out := flag.CommandLine.Output()
|
||||
@@ -479,8 +485,83 @@ func main() {
|
||||
cmd.WaitForCloudDeploy()
|
||||
return
|
||||
}
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
if tuiMode {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
hook := tui.NewLogHook(2000)
|
||||
hook.SetFormatter(&logging.LogFormatter{})
|
||||
log.AddHook(hook)
|
||||
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
origLogOutput := log.StandardLogger().Out
|
||||
log.SetOutput(io.Discard)
|
||||
|
||||
devNull, errOpenDevNull := os.Open(os.DevNull)
|
||||
if errOpenDevNull == nil {
|
||||
os.Stdout = devNull
|
||||
os.Stderr = devNull
|
||||
}
|
||||
|
||||
restoreIO := func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
log.SetOutput(origLogOutput)
|
||||
if devNull != nil {
|
||||
_ = devNull.Close()
|
||||
}
|
||||
}
|
||||
|
||||
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
|
||||
if password == "" {
|
||||
password = localMgmtPassword
|
||||
}
|
||||
|
||||
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
|
||||
|
||||
client := tui.NewClient(cfg.Port, password)
|
||||
ready := false
|
||||
backoff := 100 * time.Millisecond
|
||||
for i := 0; i < 30; i++ {
|
||||
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
|
||||
ready = true
|
||||
break
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
if backoff < time.Second {
|
||||
backoff = time.Duration(float64(backoff) * 1.5)
|
||||
}
|
||||
}
|
||||
|
||||
if !ready {
|
||||
restoreIO()
|
||||
cancel()
|
||||
<-done
|
||||
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
|
||||
return
|
||||
}
|
||||
|
||||
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
|
||||
restoreIO()
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
} else {
|
||||
restoreIO()
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-done
|
||||
} else {
|
||||
// Default TUI mode: pure management client.
|
||||
// The proxy server must already be running.
|
||||
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,6 +156,14 @@ nonstream-keepalive-interval: 0
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# These are used as fallbacks when the client does not send its own headers.
|
||||
# claude-header-defaults:
|
||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||
# package-version: "0.74.0"
|
||||
# runtime-version: "v24.3.0"
|
||||
# timeout: "600"
|
||||
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
|
||||
23
go.mod
23
go.mod
@@ -1,9 +1,13 @@
|
||||
module github.com/router-for-me/CLIProxyAPI/v6
|
||||
|
||||
go 1.24.0
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/charmbracelet/bubbles v1.0.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
|
||||
@@ -31,8 +35,16 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
@@ -40,6 +52,7 @@ require (
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||
@@ -56,19 +69,27 @@ require (
|
||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sergi/go-diff v1.4.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
|
||||
45
go.sum
45
go.sum
@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
|
||||
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
@@ -99,8 +125,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||
@@ -112,6 +144,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||
@@ -120,6 +158,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
@@ -159,17 +199,22 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
|
||||
@@ -808,6 +808,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
|
||||
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
ProxyURL *string `json:"proxy_url"`
|
||||
Priority *int `json:"priority"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find auth by name or ID
|
||||
var targetAuth *coreauth.Auth
|
||||
if auth, ok := h.authManager.GetByID(name); ok {
|
||||
targetAuth = auth
|
||||
} else {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name {
|
||||
targetAuth = auth
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if targetAuth == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
if req.Prefix != nil {
|
||||
targetAuth.Prefix = *req.Prefix
|
||||
changed = true
|
||||
}
|
||||
if req.ProxyURL != nil {
|
||||
targetAuth.ProxyURL = *req.ProxyURL
|
||||
changed = true
|
||||
}
|
||||
if req.Priority != nil {
|
||||
if targetAuth.Metadata == nil {
|
||||
targetAuth.Metadata = make(map[string]any)
|
||||
}
|
||||
if *req.Priority == 0 {
|
||||
delete(targetAuth.Metadata, "priority")
|
||||
} else {
|
||||
targetAuth.Metadata["priority"] = *req.Priority
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !changed {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
|
||||
return
|
||||
}
|
||||
|
||||
targetAuth.UpdatedAt = time.Now()
|
||||
|
||||
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
@@ -1188,6 +1269,30 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
ts.Checked = true
|
||||
} else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") {
|
||||
ts.Auto = false
|
||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
@@ -2036,7 +2141,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
|
||||
@@ -28,8 +28,7 @@ func (h *Handler) GetConfig(c *gin.Context) {
|
||||
c.JSON(200, gin.H{})
|
||||
return
|
||||
}
|
||||
cfgCopy := *h.cfg
|
||||
c.JSON(200, &cfgCopy)
|
||||
c.JSON(200, new(*h.cfg))
|
||||
}
|
||||
|
||||
type releaseInfo struct {
|
||||
|
||||
@@ -15,10 +15,12 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
||||
// logger, it still captures data so that upstream errors can be persisted.
|
||||
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
||||
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Method == http.MethodGet {
|
||||
if shouldSkipMethodForRequestLogging(c.Request) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
loggerEnabled := logger.IsEnabled()
|
||||
|
||||
// Capture request information
|
||||
requestInfo, err := captureRequestInfo(c)
|
||||
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
||||
if err != nil {
|
||||
// Log error but continue processing
|
||||
// In a real implementation, you might want to use a proper logger here
|
||||
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
if !logger.IsEnabled() {
|
||||
if !loggerEnabled {
|
||||
wrapper.logOnErrorOnly = true
|
||||
}
|
||||
c.Writer = wrapper
|
||||
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
||||
if req == nil {
|
||||
return true
|
||||
}
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
return !isResponsesWebsocketUpgrade(req)
|
||||
}
|
||||
|
||||
func isResponsesWebsocketUpgrade(req *http.Request) bool {
|
||||
if req == nil || req.URL == nil {
|
||||
return false
|
||||
}
|
||||
if req.URL.Path != "/v1/responses" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
|
||||
}
|
||||
|
||||
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
||||
if loggerEnabled {
|
||||
return true
|
||||
}
|
||||
if req == nil || req.Body == nil {
|
||||
return false
|
||||
}
|
||||
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return false
|
||||
}
|
||||
if req.ContentLength <= 0 {
|
||||
return false
|
||||
}
|
||||
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
||||
}
|
||||
|
||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||
// It captures the URL, method, headers, and body. The request body is read and then
|
||||
// restored so that it can be processed by subsequent handlers.
|
||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
||||
// Capture URL with sensitive query parameters masked
|
||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
url := c.Request.URL.Path
|
||||
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
|
||||
// Capture request body
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
if captureBody && c.Request.Body != nil {
|
||||
// Read the body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
|
||||
138
internal/api/middleware/request_logging_test.go
Normal file
138
internal/api/middleware/request_logging_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
skip bool
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
req: nil,
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "post request should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "plain get should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/models"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "responses websocket upgrade should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{"Upgrade": []string{"websocket"}},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "responses get without upgrade should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
||||
if got != tests[i].skip {
|
||||
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldCaptureRequestBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
loggerEnabled bool
|
||||
req *http.Request
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "logger enabled always captures",
|
||||
loggerEnabled: true,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nil request",
|
||||
loggerEnabled: false,
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "small known size json in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: 2,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large known size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unknown size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multipart skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: 1,
|
||||
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
||||
if got != tests[i].want {
|
||||
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
|
||||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
if c != nil {
|
||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return w.requestInfo.Body
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if len(w.requestInfo.Body) > 0 {
|
||||
requestBody = w.requestInfo.Body
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
|
||||
43
internal/api/middleware/response_writer_test.go
Normal file
43
internal/api/middleware/response_writer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{
|
||||
requestInfo: &RequestInfo{Body: []byte("original-body")},
|
||||
}
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "original-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "original-body")
|
||||
}
|
||||
|
||||
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
|
||||
body = wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-as-string" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||
}
|
||||
}
|
||||
@@ -127,8 +127,7 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
m.lastConfig = &settingsCopy
|
||||
m.lastConfig = new(settings)
|
||||
|
||||
// Initialize localhost restriction setting (hot-reloadable)
|
||||
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
||||
|
||||
@@ -284,8 +284,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
optionState.routerConfigurator(engine, s.handlers, cfg)
|
||||
}
|
||||
|
||||
// Register management routes when configuration or environment secrets are available.
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
|
||||
// Register management routes when configuration or environment secrets are available,
|
||||
// or when a local management password is provided (e.g. TUI mode).
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
|
||||
s.managementRoutesEnabled.Store(hasManagementSecret)
|
||||
if hasManagementSecret {
|
||||
s.registerManagementRoutes()
|
||||
@@ -323,6 +324,7 @@ func (s *Server) setupRoutes() {
|
||||
v1.POST("/completions", openaiHandlers.Completions)
|
||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
|
||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||
}
|
||||
@@ -616,6 +618,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
|
||||
@@ -40,8 +40,7 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *claude.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*claude.AuthenticationError](err); ok {
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == claude.ErrPortInUse.Type {
|
||||
os.Exit(claude.ErrPortInUse.Code)
|
||||
|
||||
@@ -32,8 +32,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -100,49 +100,74 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
var activatedProjects []string
|
||||
|
||||
useGoogleOne := false
|
||||
if trimmedProjectID == "" && promptFn != nil {
|
||||
fmt.Println("\nSelect login mode:")
|
||||
fmt.Println(" 1. Code Assist (GCP project, manual selection)")
|
||||
fmt.Println(" 2. Google One (personal account, auto-discover project)")
|
||||
choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ")
|
||||
if errPrompt == nil && strings.TrimSpace(choice) == "2" {
|
||||
useGoogleOne = true
|
||||
}
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
activatedProjects := make([]string, 0, len(projectSelections))
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
var projectErr *projectSelectionRequiredError
|
||||
if errors.As(errSetup, &projectErr) {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
if useGoogleOne {
|
||||
log.Info("Google One mode: auto-discovering project...")
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
autoProject := strings.TrimSpace(storage.ProjectID)
|
||||
if autoProject == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
log.Infof("Auto-discovered project: %s", autoProject)
|
||||
activatedProjects = []string{autoProject}
|
||||
} else {
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
}
|
||||
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
|
||||
storage.Auto = false
|
||||
@@ -235,7 +260,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
@@ -617,7 +683,7 @@ func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStor
|
||||
return
|
||||
}
|
||||
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, false)
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true)
|
||||
|
||||
if record.Metadata == nil {
|
||||
record.Metadata = make(map[string]any)
|
||||
|
||||
@@ -54,8 +54,7 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *codex.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
|
||||
@@ -44,8 +44,7 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
|
||||
}
|
||||
}
|
||||
|
||||
// StartServiceBackground starts the proxy service in a background goroutine
|
||||
// and returns a cancel function for shutdown and a done channel.
|
||||
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
|
||||
builder := cliproxy.NewBuilder().
|
||||
WithConfig(cfg).
|
||||
WithConfigPath(configPath).
|
||||
WithLocalManagementPassword(localPassword)
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
service, err := builder.Build()
|
||||
if err != nil {
|
||||
log.Errorf("failed to build proxy service: %v", err)
|
||||
close(doneCh)
|
||||
return cancelFn, doneCh
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("proxy service exited with error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return cancelFn, doneCh
|
||||
}
|
||||
|
||||
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
|
||||
// when no configuration file is available.
|
||||
func WaitForCloudDeploy() {
|
||||
|
||||
@@ -90,6 +90,10 @@ type Config struct {
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values for Claude API requests.
|
||||
// These are used as fallbacks when the client does not send its own headers.
|
||||
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
|
||||
|
||||
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
|
||||
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
|
||||
|
||||
@@ -117,6 +121,15 @@ type Config struct {
|
||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
|
||||
// when the client does not send them. Update these when Claude Code releases a new version.
|
||||
type ClaudeHeaderDefaults struct {
|
||||
UserAgent string `yaml:"user-agent" json:"user-agent"`
|
||||
PackageVersion string `yaml:"package-version" json:"package-version"`
|
||||
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
|
||||
Timeout string `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
@@ -355,6 +368,9 @@ type CodexKey struct {
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
// Websockets enables the Responses API websocket transport for this credential.
|
||||
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
|
||||
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
// - codex
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kimi
|
||||
// - antigravity (returns static overrides only)
|
||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
key := strings.ToLower(strings.TrimSpace(channel))
|
||||
@@ -39,6 +40,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetQwenModels()
|
||||
case "iflow":
|
||||
return GetIFlowModels()
|
||||
case "kimi":
|
||||
return GetKimiModels()
|
||||
case "antigravity":
|
||||
cfg := GetAntigravityModelConfig()
|
||||
if len(cfg) == 0 {
|
||||
@@ -83,6 +86,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
GetKimiModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
|
||||
@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1771372800, // 2026-02-17
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Object: "model",
|
||||
@@ -742,6 +753,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex-spark",
|
||||
Object: "model",
|
||||
Created: 1770912000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex Spark",
|
||||
Description: "Ultra-fast coding model.",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -774,6 +799,19 @@ func GetQwenModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 2048,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "coder-model",
|
||||
Object: "model",
|
||||
Created: 1771171200,
|
||||
OwnedBy: "qwen",
|
||||
Type: "qwen",
|
||||
Version: "3.5",
|
||||
DisplayName: "Qwen 3.5 Plus",
|
||||
Description: "efficient hybrid model with leading coding performance",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "vision-model",
|
||||
Object: "model",
|
||||
@@ -814,6 +852,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
@@ -828,6 +867,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
@@ -868,6 +908,8 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
|
||||
@@ -596,8 +596,7 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
now := time.Now()
|
||||
registration.QuotaExceededClients[clientID] = &now
|
||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestAntigravityBuildRequest_SanitizesGeminiToolSchema(t *testing.T) {
|
||||
body := buildRequestBodyFromPayload(t, "gemini-2.5-pro")
|
||||
|
||||
decl := extractFirstFunctionDeclaration(t, body)
|
||||
if _, ok := decl["parametersJsonSchema"]; ok {
|
||||
t.Fatalf("parametersJsonSchema should be renamed to parameters")
|
||||
}
|
||||
|
||||
params, ok := decl["parameters"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters missing or invalid type")
|
||||
}
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequest_SanitizesAntigravityToolSchema(t *testing.T) {
|
||||
body := buildRequestBodyFromPayload(t, "claude-opus-4-6")
|
||||
|
||||
decl := extractFirstFunctionDeclaration(t, body)
|
||||
params, ok := decl["parameters"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters missing or invalid type")
|
||||
}
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
executor := &AntigravityExecutor{}
|
||||
auth := &cliproxyauth.Auth{}
|
||||
payload := []byte(`{
|
||||
"request": {
|
||||
"tools": [
|
||||
{
|
||||
"function_declarations": [
|
||||
{
|
||||
"name": "tool_1",
|
||||
"parametersJsonSchema": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"$id": "root-schema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"$id": {"type": "string"},
|
||||
"arg": {
|
||||
"type": "object",
|
||||
"prefill": "hello",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"],
|
||||
"enumTitles": ["A", "B"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"patternProperties": {
|
||||
"^x-": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRequest error: %v", err)
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body error: %v", err)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Fatalf("unmarshal request body error: %v, body=%s", err, string(raw))
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func extractFirstFunctionDeclaration(t *testing.T, body map[string]any) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
request, ok := body["request"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("request missing or invalid type")
|
||||
}
|
||||
tools, ok := request["tools"].([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
t.Fatalf("tools missing or empty")
|
||||
}
|
||||
tool, ok := tools[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first tool invalid type")
|
||||
}
|
||||
decls, ok := tool["function_declarations"].([]any)
|
||||
if !ok || len(decls) == 0 {
|
||||
t.Fatalf("function_declarations missing or empty")
|
||||
}
|
||||
decl, ok := decls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first function declaration invalid type")
|
||||
}
|
||||
return decl
|
||||
}
|
||||
|
||||
func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := params["$id"]; ok {
|
||||
t.Fatalf("root $id should be removed from schema")
|
||||
}
|
||||
if _, ok := params["patternProperties"]; ok {
|
||||
t.Fatalf("patternProperties should be removed from schema")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("properties missing or invalid type")
|
||||
}
|
||||
if _, ok := props["$id"]; !ok {
|
||||
t.Fatalf("property named $id should be preserved")
|
||||
}
|
||||
|
||||
arg, ok := props["arg"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("arg property missing or invalid type")
|
||||
}
|
||||
if _, ok := arg["prefill"]; ok {
|
||||
t.Fatalf("prefill should be removed from nested schema")
|
||||
}
|
||||
|
||||
argProps, ok := arg["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("arg.properties missing or invalid type")
|
||||
}
|
||||
mode, ok := argProps["mode"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("mode property missing or invalid type")
|
||||
}
|
||||
if _, ok := mode["enumTitles"]; ok {
|
||||
t.Fatalf("enumTitles should be removed from nested schema")
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
} else {
|
||||
reporter.publish(ctx, parseClaudeUsage(data))
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||
}
|
||||
var param any
|
||||
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -348,7 +349,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
// Forward the line as-is to preserve SSE format
|
||||
@@ -375,7 +376,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(
|
||||
@@ -423,7 +424,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
@@ -432,7 +433,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -638,7 +639,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
||||
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
|
||||
func mapStainlessOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "MacOS"
|
||||
case "windows":
|
||||
return "Windows"
|
||||
case "linux":
|
||||
return "Linux"
|
||||
case "freebsd":
|
||||
return "FreeBSD"
|
||||
default:
|
||||
return "Other::" + runtime.GOOS
|
||||
}
|
||||
}
|
||||
|
||||
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
|
||||
func mapStainlessArch() string {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return "x64"
|
||||
case "arm64":
|
||||
return "arm64"
|
||||
case "386":
|
||||
return "x86"
|
||||
default:
|
||||
return "other::" + runtime.GOARCH
|
||||
}
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
|
||||
hdrDefault := func(cfgVal, fallback string) string {
|
||||
if cfgVal != "" {
|
||||
return cfgVal
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
var hd config.ClaudeHeaderDefaults
|
||||
if cfg != nil {
|
||||
hd = cfg.ClaudeHeaderDefaults
|
||||
}
|
||||
|
||||
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
|
||||
if isAnthropicBase && useAPIKey {
|
||||
@@ -685,16 +728,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
if stream {
|
||||
@@ -702,6 +746,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
} else {
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
// Keep OS/Arch mapping dynamic (not configurable).
|
||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
@@ -753,11 +799,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// Collect built-in tool names (those with a non-empty "type" field) so we can
|
||||
// skip them consistently in both tools and message history.
|
||||
builtinTools := map[string]bool{}
|
||||
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
|
||||
builtinTools[name] = true
|
||||
}
|
||||
|
||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||
// Skip built-in tools (web_search, code_execution, etc.) which have
|
||||
// a "type" field and require their name to remain unchanged.
|
||||
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
||||
if n := tool.Get("name").String(); n != "" {
|
||||
builtinTools[n] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
name := tool.Get("name").String()
|
||||
@@ -772,7 +828,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
|
||||
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
|
||||
name := gjson.GetBytes(body, "tool_choice.name").String()
|
||||
if name != "" && !strings.HasPrefix(name, prefix) {
|
||||
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
|
||||
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
|
||||
}
|
||||
}
|
||||
@@ -784,15 +840,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
return true
|
||||
}
|
||||
content.ForEach(func(contentIndex, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+toolName)
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
nestedContent := part.Get("content")
|
||||
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||
if nestedPart.Get("type").String() == "tool_reference" {
|
||||
nestedToolName := nestedPart.Get("tool_name").String()
|
||||
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
|
||||
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
@@ -811,15 +890,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
|
||||
return body
|
||||
}
|
||||
content.ForEach(func(index, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if !strings.HasPrefix(toolName, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.tool_name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
nestedContent := part.Get("content")
|
||||
if nestedContent.Exists() && nestedContent.IsArray() {
|
||||
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
|
||||
if nestedPart.Get("type").String() == "tool_reference" {
|
||||
nestedToolName := nestedPart.Get("tool_name").String()
|
||||
if strings.HasPrefix(nestedToolName, prefix) {
|
||||
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
return true
|
||||
})
|
||||
return body
|
||||
@@ -834,15 +936,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
||||
return line
|
||||
}
|
||||
contentBlock := gjson.GetBytes(payload, "content_block")
|
||||
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
|
||||
if !contentBlock.Exists() {
|
||||
return line
|
||||
}
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
|
||||
blockType := contentBlock.Get("type").String()
|
||||
var updated []byte
|
||||
var err error
|
||||
|
||||
switch blockType {
|
||||
case "tool_use":
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
case "tool_reference":
|
||||
toolName := contentBlock.Get("tool_name").String()
|
||||
if !strings.HasPrefix(toolName, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
default:
|
||||
return line
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
|
||||
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
|
||||
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
@@ -37,6 +49,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"name": "Read"}
|
||||
],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [{"name": "Read"}, {"name": "Write"}],
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
|
||||
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"tools": [
|
||||
{"type": "web_search_20250305", "name": "web_search"},
|
||||
{"name": "Read"}
|
||||
],
|
||||
"tool_choice": {"type": "tool", "name": "web_search"}
|
||||
}`)
|
||||
out := applyClaudeToolPrefix(body, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
|
||||
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
@@ -49,6 +152,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
|
||||
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
|
||||
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
@@ -61,3 +176,53 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
|
||||
payload := bytes.TrimSpace(out)
|
||||
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
|
||||
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "proxy_mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
|
||||
if got != "mcp__nia__manage_resource" {
|
||||
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
|
||||
// tool_result.content can be a string - should not be processed
|
||||
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
|
||||
if got != "plain string result" {
|
||||
t.Fatalf("string content should remain unchanged = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
|
||||
if got != "web_search" {
|
||||
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.98.0"
|
||||
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexClientVersion = "0.101.0"
|
||||
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
@@ -643,7 +643,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
|
||||
|
||||
1407
internal/runtime/executor/codex_websockets_executor.go
Normal file
1407
internal/runtime/executor/codex_websockets_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -899,8 +899,7 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
return new(time.Duration(seconds) * time.Second), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,9 +22,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
qwenUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
qwenXGoogAPIClient = "gl-node/22.17.0"
|
||||
qwenClientMetadataValue = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
)
|
||||
|
||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||
@@ -344,8 +342,18 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
r.Header.Set("User-Agent", qwenUserAgent)
|
||||
r.Header.Set("X-Goog-Api-Client", qwenXGoogAPIClient)
|
||||
r.Header.Set("Client-Metadata", qwenClientMetadataValue)
|
||||
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
|
||||
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
||||
r.Header.Set("Sec-Fetch-Mode", "cors")
|
||||
r.Header.Set("X-Stainless-Lang", "js")
|
||||
r.Header.Set("X-Stainless-Arch", "arm64")
|
||||
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
||||
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
||||
r.Header.Set("X-Stainless-Retry-Count", "0")
|
||||
r.Header.Set("X-Stainless-Os", "MacOS")
|
||||
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
||||
r.Header.Set("X-Stainless-Runtime", "node")
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
|
||||
@@ -10,10 +10,53 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// validReasoningEffortLevels contains the standard values accepted by the
|
||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
||||
// auto) are NOT in this set and must be clamped before use.
|
||||
var validReasoningEffortLevels = map[string]struct{}{
|
||||
"none": {},
|
||||
"low": {},
|
||||
"medium": {},
|
||||
"high": {},
|
||||
}
|
||||
|
||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
||||
// mapped to the nearest standard equivalent.
|
||||
//
|
||||
// Mapping rules:
|
||||
// - none / low / medium / high → returned as-is (already valid)
|
||||
// - xhigh → "high" (nearest lower standard level)
|
||||
// - minimal → "low" (nearest higher standard level)
|
||||
// - auto → "medium" (reasonable default)
|
||||
// - anything else → "medium" (safe default)
|
||||
func clampReasoningEffort(level string) string {
|
||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
||||
return level
|
||||
}
|
||||
var clamped string
|
||||
switch level {
|
||||
case string(thinking.LevelXHigh):
|
||||
clamped = string(thinking.LevelHigh)
|
||||
case string(thinking.LevelMinimal):
|
||||
clamped = string(thinking.LevelLow)
|
||||
case string(thinking.LevelAuto):
|
||||
clamped = string(thinking.LevelMedium)
|
||||
default:
|
||||
clamped = string(thinking.LevelMedium)
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"original": level,
|
||||
"clamped": clamped,
|
||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
||||
return clamped
|
||||
}
|
||||
|
||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||
//
|
||||
// OpenAI-specific behavior:
|
||||
@@ -58,7 +101,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
}
|
||||
|
||||
if config.Mode == thinking.ModeLevel {
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -79,7 +122,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -114,7 +157,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -20,10 +20,12 @@ var (
|
||||
|
||||
// ConvertCliToOpenAIParams holds parameters for response conversion.
|
||||
type ConvertCliToOpenAIParams struct {
|
||||
ResponseID string
|
||||
CreatedAt int64
|
||||
Model string
|
||||
FunctionCallIndex int
|
||||
ResponseID string
|
||||
CreatedAt int64
|
||||
Model string
|
||||
FunctionCallIndex int
|
||||
HasReceivedArgumentsDelta bool
|
||||
HasToolCallAnnounced bool
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
|
||||
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
|
||||
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
if *param == nil {
|
||||
*param = &ConvertCliToOpenAIParams{
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
FunctionCallIndex: -1,
|
||||
Model: modelName,
|
||||
CreatedAt: 0,
|
||||
ResponseID: "",
|
||||
FunctionCallIndex: -1,
|
||||
HasReceivedArgumentsDelta: false,
|
||||
HasToolCallAnnounced: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +94,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
@@ -115,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
} else if dataType == "response.output_item.done" {
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
} else if dataType == "response.output_item.added" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if itemResult.Exists() {
|
||||
if itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// set the index
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened
|
||||
name := itemResult.Get("name").String()
|
||||
// Build reverse map on demand from original request tools
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Increment index for this new function call item.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
|
||||
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.delta" {
|
||||
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
|
||||
|
||||
deltaValue := rootResult.Get("delta").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.function_call_arguments.done" {
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
|
||||
// Arguments were already streamed via delta events; nothing to emit.
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Fallback: no delta events were received, emit the full arguments as a single chunk.
|
||||
fullArgs := rootResult.Get("arguments").String()
|
||||
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else if dataType == "response.output_item.done" {
|
||||
itemResult := rootResult.Get("item")
|
||||
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
|
||||
// Tool call was already announced via output_item.added; skip emission.
|
||||
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Fallback path: model skipped output_item.added, so emit complete tool call now.
|
||||
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
|
||||
|
||||
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
|
||||
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
|
||||
|
||||
// Restore original tool name if it was shortened.
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
|
||||
if orig, ok := rev[name]; ok {
|
||||
name = orig
|
||||
}
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
|
||||
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
|
||||
|
||||
} else {
|
||||
return []string{}
|
||||
}
|
||||
@@ -205,6 +270,9 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
|
||||
@@ -27,6 +27,9 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
|
||||
// Delete the user field as it is not supported by the Codex upstream.
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||
|
||||
// Convert role "system" to "developer" in input array to comply with Codex API requirements.
|
||||
rawJSON = convertSystemRoleToDeveloper(rawJSON)
|
||||
|
||||
|
||||
@@ -263,3 +263,20 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
||||
t.Errorf("Expected third role 'assistant', got '%s'", thirdRole.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"user": "test-user",
|
||||
"input": [{"role": "user", "content": "Hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify user field is deleted
|
||||
userField := gjson.Get(outputStr, "user")
|
||||
if userField.Exists() {
|
||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
542
internal/tui/app.go
Normal file
542
internal/tui/app.go
Normal file
@@ -0,0 +1,542 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Tab identifiers
|
||||
const (
|
||||
tabDashboard = iota
|
||||
tabConfig
|
||||
tabAuthFiles
|
||||
tabAPIKeys
|
||||
tabOAuth
|
||||
tabUsage
|
||||
tabLogs
|
||||
)
|
||||
|
||||
// App is the root bubbletea model that contains all tab sub-models.
|
||||
type App struct {
|
||||
activeTab int
|
||||
tabs []string
|
||||
|
||||
standalone bool
|
||||
logsEnabled bool
|
||||
|
||||
authenticated bool
|
||||
authInput textinput.Model
|
||||
authError string
|
||||
authConnecting bool
|
||||
|
||||
dashboard dashboardModel
|
||||
config configTabModel
|
||||
auth authTabModel
|
||||
keys keysTabModel
|
||||
oauth oauthTabModel
|
||||
usage usageTabModel
|
||||
logs logsTabModel
|
||||
|
||||
client *Client
|
||||
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
|
||||
// Track which tabs have been initialized (fetched data)
|
||||
initialized [7]bool
|
||||
}
|
||||
|
||||
type authConnectMsg struct {
|
||||
cfg map[string]any
|
||||
err error
|
||||
}
|
||||
|
||||
// NewApp creates the root TUI application model.
|
||||
func NewApp(port int, secretKey string, hook *LogHook) App {
|
||||
standalone := hook != nil
|
||||
authRequired := !standalone
|
||||
ti := textinput.New()
|
||||
ti.CharLimit = 512
|
||||
ti.EchoMode = textinput.EchoPassword
|
||||
ti.EchoCharacter = '*'
|
||||
ti.SetValue(strings.TrimSpace(secretKey))
|
||||
ti.Focus()
|
||||
|
||||
client := NewClient(port, secretKey)
|
||||
app := App{
|
||||
activeTab: tabDashboard,
|
||||
standalone: standalone,
|
||||
logsEnabled: true,
|
||||
authenticated: !authRequired,
|
||||
authInput: ti,
|
||||
dashboard: newDashboardModel(client),
|
||||
config: newConfigTabModel(client),
|
||||
auth: newAuthTabModel(client),
|
||||
keys: newKeysTabModel(client),
|
||||
oauth: newOAuthTabModel(client),
|
||||
usage: newUsageTabModel(client),
|
||||
logs: newLogsTabModel(client, hook),
|
||||
client: client,
|
||||
initialized: [7]bool{
|
||||
tabDashboard: true,
|
||||
tabLogs: true,
|
||||
},
|
||||
}
|
||||
|
||||
app.refreshTabs()
|
||||
if authRequired {
|
||||
app.initialized = [7]bool{}
|
||||
}
|
||||
app.setAuthInputPrompt()
|
||||
return app
|
||||
}
|
||||
|
||||
func (a App) Init() tea.Cmd {
|
||||
if !a.authenticated {
|
||||
return textinput.Blink
|
||||
}
|
||||
cmds := []tea.Cmd{a.dashboard.Init()}
|
||||
if a.logsEnabled {
|
||||
cmds = append(cmds, a.logs.Init())
|
||||
}
|
||||
return tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
a.width = msg.Width
|
||||
a.height = msg.Height
|
||||
a.ready = true
|
||||
if a.width > 0 {
|
||||
a.authInput.Width = a.width - 6
|
||||
}
|
||||
contentH := a.height - 4 // tab bar + status bar
|
||||
if contentH < 1 {
|
||||
contentH = 1
|
||||
}
|
||||
contentW := a.width
|
||||
a.dashboard.SetSize(contentW, contentH)
|
||||
a.config.SetSize(contentW, contentH)
|
||||
a.auth.SetSize(contentW, contentH)
|
||||
a.keys.SetSize(contentW, contentH)
|
||||
a.oauth.SetSize(contentW, contentH)
|
||||
a.usage.SetSize(contentW, contentH)
|
||||
a.logs.SetSize(contentW, contentH)
|
||||
return a, nil
|
||||
|
||||
case authConnectMsg:
|
||||
a.authConnecting = false
|
||||
if msg.err != nil {
|
||||
a.authError = fmt.Sprintf(T("auth_gate_connect_fail"), msg.err.Error())
|
||||
return a, nil
|
||||
}
|
||||
a.authError = ""
|
||||
a.authenticated = true
|
||||
a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
|
||||
a.refreshTabs()
|
||||
a.initialized = [7]bool{}
|
||||
a.initialized[tabDashboard] = true
|
||||
cmds := []tea.Cmd{a.dashboard.Init()}
|
||||
if a.logsEnabled {
|
||||
a.initialized[tabLogs] = true
|
||||
cmds = append(cmds, a.logs.Init())
|
||||
}
|
||||
return a, tea.Batch(cmds...)
|
||||
|
||||
case configUpdateMsg:
|
||||
var cmdLogs tea.Cmd
|
||||
if !a.standalone && msg.err == nil && msg.path == "logging-to-file" {
|
||||
logsEnabledConfig, okConfig := msg.value.(bool)
|
||||
if okConfig {
|
||||
logsEnabledBefore := a.logsEnabled
|
||||
a.logsEnabled = logsEnabledConfig
|
||||
if logsEnabledBefore != a.logsEnabled {
|
||||
a.refreshTabs()
|
||||
}
|
||||
if !a.logsEnabled {
|
||||
a.initialized[tabLogs] = false
|
||||
}
|
||||
if !logsEnabledBefore && a.logsEnabled {
|
||||
a.initialized[tabLogs] = true
|
||||
cmdLogs = a.logs.Init()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cmdConfig tea.Cmd
|
||||
a.config, cmdConfig = a.config.Update(msg)
|
||||
if cmdConfig != nil && cmdLogs != nil {
|
||||
return a, tea.Batch(cmdConfig, cmdLogs)
|
||||
}
|
||||
if cmdConfig != nil {
|
||||
return a, cmdConfig
|
||||
}
|
||||
return a, cmdLogs
|
||||
|
||||
case tea.KeyMsg:
|
||||
if !a.authenticated {
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q":
|
||||
return a, tea.Quit
|
||||
case "L":
|
||||
ToggleLocale()
|
||||
a.refreshTabs()
|
||||
a.setAuthInputPrompt()
|
||||
return a, nil
|
||||
case "enter":
|
||||
if a.authConnecting {
|
||||
return a, nil
|
||||
}
|
||||
password := strings.TrimSpace(a.authInput.Value())
|
||||
if password == "" {
|
||||
a.authError = T("auth_gate_password_required")
|
||||
return a, nil
|
||||
}
|
||||
a.authError = ""
|
||||
a.authConnecting = true
|
||||
return a, a.connectWithPassword(password)
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
a.authInput, cmd = a.authInput.Update(msg)
|
||||
return a, cmd
|
||||
}
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "ctrl+c":
|
||||
return a, tea.Quit
|
||||
case "q":
|
||||
// Only quit if not in logs tab (where 'q' might be useful)
|
||||
if !a.logsEnabled || a.activeTab != tabLogs {
|
||||
return a, tea.Quit
|
||||
}
|
||||
case "L":
|
||||
ToggleLocale()
|
||||
a.refreshTabs()
|
||||
return a.broadcastToAllTabs(localeChangedMsg{})
|
||||
case "tab":
|
||||
if len(a.tabs) == 0 {
|
||||
return a, nil
|
||||
}
|
||||
prevTab := a.activeTab
|
||||
a.activeTab = (a.activeTab + 1) % len(a.tabs)
|
||||
return a, a.initTabIfNeeded(prevTab)
|
||||
case "shift+tab":
|
||||
if len(a.tabs) == 0 {
|
||||
return a, nil
|
||||
}
|
||||
prevTab := a.activeTab
|
||||
a.activeTab = (a.activeTab - 1 + len(a.tabs)) % len(a.tabs)
|
||||
return a, a.initTabIfNeeded(prevTab)
|
||||
}
|
||||
}
|
||||
|
||||
if !a.authenticated {
|
||||
var cmd tea.Cmd
|
||||
a.authInput, cmd = a.authInput.Update(msg)
|
||||
return a, cmd
|
||||
}
|
||||
|
||||
// Route msg to active tab
|
||||
var cmd tea.Cmd
|
||||
switch a.activeTab {
|
||||
case tabDashboard:
|
||||
a.dashboard, cmd = a.dashboard.Update(msg)
|
||||
case tabConfig:
|
||||
a.config, cmd = a.config.Update(msg)
|
||||
case tabAuthFiles:
|
||||
a.auth, cmd = a.auth.Update(msg)
|
||||
case tabAPIKeys:
|
||||
a.keys, cmd = a.keys.Update(msg)
|
||||
case tabOAuth:
|
||||
a.oauth, cmd = a.oauth.Update(msg)
|
||||
case tabUsage:
|
||||
a.usage, cmd = a.usage.Update(msg)
|
||||
case tabLogs:
|
||||
a.logs, cmd = a.logs.Update(msg)
|
||||
}
|
||||
|
||||
// Keep logs polling alive even when logs tab is not active.
|
||||
if a.logsEnabled && a.activeTab != tabLogs {
|
||||
switch msg.(type) {
|
||||
case logsPollMsg, logsTickMsg, logLineMsg:
|
||||
var logCmd tea.Cmd
|
||||
a.logs, logCmd = a.logs.Update(msg)
|
||||
if logCmd != nil {
|
||||
cmd = logCmd
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return a, cmd
|
||||
}
|
||||
|
||||
// localeChangedMsg is broadcast to all tabs when the user toggles locale.
|
||||
type localeChangedMsg struct{}
|
||||
|
||||
func (a *App) refreshTabs() {
|
||||
names := TabNames()
|
||||
if a.logsEnabled {
|
||||
a.tabs = names
|
||||
} else {
|
||||
filtered := make([]string, 0, len(names)-1)
|
||||
for idx, name := range names {
|
||||
if idx == tabLogs {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
a.tabs = filtered
|
||||
}
|
||||
|
||||
if len(a.tabs) == 0 {
|
||||
a.activeTab = tabDashboard
|
||||
return
|
||||
}
|
||||
if a.activeTab >= len(a.tabs) {
|
||||
a.activeTab = len(a.tabs) - 1
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) initTabIfNeeded(_ int) tea.Cmd {
|
||||
if a.initialized[a.activeTab] {
|
||||
return nil
|
||||
}
|
||||
a.initialized[a.activeTab] = true
|
||||
switch a.activeTab {
|
||||
case tabDashboard:
|
||||
return a.dashboard.Init()
|
||||
case tabConfig:
|
||||
return a.config.Init()
|
||||
case tabAuthFiles:
|
||||
return a.auth.Init()
|
||||
case tabAPIKeys:
|
||||
return a.keys.Init()
|
||||
case tabOAuth:
|
||||
return a.oauth.Init()
|
||||
case tabUsage:
|
||||
return a.usage.Init()
|
||||
case tabLogs:
|
||||
if !a.logsEnabled {
|
||||
return nil
|
||||
}
|
||||
return a.logs.Init()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a App) View() string {
|
||||
if !a.authenticated {
|
||||
return a.renderAuthView()
|
||||
}
|
||||
|
||||
if !a.ready {
|
||||
return T("initializing_tui")
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Tab bar
|
||||
sb.WriteString(a.renderTabBar())
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Content
|
||||
switch a.activeTab {
|
||||
case tabDashboard:
|
||||
sb.WriteString(a.dashboard.View())
|
||||
case tabConfig:
|
||||
sb.WriteString(a.config.View())
|
||||
case tabAuthFiles:
|
||||
sb.WriteString(a.auth.View())
|
||||
case tabAPIKeys:
|
||||
sb.WriteString(a.keys.View())
|
||||
case tabOAuth:
|
||||
sb.WriteString(a.oauth.View())
|
||||
case tabUsage:
|
||||
sb.WriteString(a.usage.View())
|
||||
case tabLogs:
|
||||
if a.logsEnabled {
|
||||
sb.WriteString(a.logs.View())
|
||||
}
|
||||
}
|
||||
|
||||
// Status bar
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(a.renderStatusBar())
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (a App) renderAuthView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("auth_gate_title")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("auth_gate_help")))
|
||||
sb.WriteString("\n\n")
|
||||
if a.authConnecting {
|
||||
sb.WriteString(warningStyle.Render(T("auth_gate_connecting")))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
if strings.TrimSpace(a.authError) != "" {
|
||||
sb.WriteString(errorStyle.Render(a.authError))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
sb.WriteString(a.authInput.View())
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("auth_gate_enter")))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (a App) renderTabBar() string {
|
||||
var tabs []string
|
||||
for i, name := range a.tabs {
|
||||
if i == a.activeTab {
|
||||
tabs = append(tabs, tabActiveStyle.Render(name))
|
||||
} else {
|
||||
tabs = append(tabs, tabInactiveStyle.Render(name))
|
||||
}
|
||||
}
|
||||
tabBar := lipgloss.JoinHorizontal(lipgloss.Top, tabs...)
|
||||
return tabBarStyle.Width(a.width).Render(tabBar)
|
||||
}
|
||||
|
||||
func (a App) renderStatusBar() string {
|
||||
left := strings.TrimRight(T("status_left"), " ")
|
||||
right := strings.TrimRight(T("status_right"), " ")
|
||||
|
||||
width := a.width
|
||||
if width < 1 {
|
||||
width = 1
|
||||
}
|
||||
|
||||
// statusBarStyle has left/right padding(1), so content area is width-2.
|
||||
contentWidth := width - 2
|
||||
if contentWidth < 0 {
|
||||
contentWidth = 0
|
||||
}
|
||||
|
||||
if lipgloss.Width(left) > contentWidth {
|
||||
left = fitStringWidth(left, contentWidth)
|
||||
right = ""
|
||||
}
|
||||
|
||||
remaining := contentWidth - lipgloss.Width(left)
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
if lipgloss.Width(right) > remaining {
|
||||
right = fitStringWidth(right, remaining)
|
||||
}
|
||||
|
||||
gap := contentWidth - lipgloss.Width(left) - lipgloss.Width(right)
|
||||
if gap < 0 {
|
||||
gap = 0
|
||||
}
|
||||
return statusBarStyle.Width(width).Render(left + strings.Repeat(" ", gap) + right)
|
||||
}
|
||||
|
||||
func fitStringWidth(text string, maxWidth int) string {
|
||||
if maxWidth <= 0 {
|
||||
return ""
|
||||
}
|
||||
if lipgloss.Width(text) <= maxWidth {
|
||||
return text
|
||||
}
|
||||
|
||||
out := ""
|
||||
for _, r := range text {
|
||||
next := out + string(r)
|
||||
if lipgloss.Width(next) > maxWidth {
|
||||
break
|
||||
}
|
||||
out = next
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isLogsEnabledFromConfig(cfg map[string]any) bool {
|
||||
if cfg == nil {
|
||||
return true
|
||||
}
|
||||
value, ok := cfg["logging-to-file"]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
enabled, ok := value.(bool)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (a *App) setAuthInputPrompt() {
|
||||
if a == nil {
|
||||
return
|
||||
}
|
||||
a.authInput.Prompt = fmt.Sprintf(" %s: ", T("auth_gate_password"))
|
||||
}
|
||||
|
||||
func (a App) connectWithPassword(password string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
a.client.SetSecretKey(password)
|
||||
cfg, errGetConfig := a.client.GetConfig()
|
||||
return authConnectMsg{cfg: cfg, err: errGetConfig}
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the TUI application.
|
||||
// output specifies where bubbletea renders. If nil, defaults to os.Stdout.
|
||||
func Run(port int, secretKey string, hook *LogHook, output io.Writer) error {
|
||||
if output == nil {
|
||||
output = os.Stdout
|
||||
}
|
||||
app := NewApp(port, secretKey, hook)
|
||||
p := tea.NewProgram(app, tea.WithAltScreen(), tea.WithOutput(output))
|
||||
_, err := p.Run()
|
||||
return err
|
||||
}
|
||||
|
||||
func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
var cmd tea.Cmd
|
||||
|
||||
a.dashboard, cmd = a.dashboard.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
a.config, cmd = a.config.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
a.auth, cmd = a.auth.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
a.keys, cmd = a.keys.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
a.oauth, cmd = a.oauth.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
a.usage, cmd = a.usage.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
a.logs, cmd = a.logs.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
return a, tea.Batch(cmds...)
|
||||
}
|
||||
456
internal/tui/auth_tab.go
Normal file
456
internal/tui/auth_tab.go
Normal file
@@ -0,0 +1,456 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// editableField represents an editable field on an auth file.
|
||||
type editableField struct {
|
||||
label string
|
||||
key string // API field key: "prefix", "proxy_url", "priority"
|
||||
}
|
||||
|
||||
var authEditableFields = []editableField{
|
||||
{label: "Prefix", key: "prefix"},
|
||||
{label: "Proxy URL", key: "proxy_url"},
|
||||
{label: "Priority", key: "priority"},
|
||||
}
|
||||
|
||||
// authTabModel displays auth credential files with interactive management.
|
||||
type authTabModel struct {
|
||||
client *Client
|
||||
viewport viewport.Model
|
||||
files []map[string]any
|
||||
err error
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
cursor int
|
||||
expanded int // -1 = none expanded, >=0 = expanded index
|
||||
confirm int // -1 = no confirmation, >=0 = confirm delete for index
|
||||
status string
|
||||
|
||||
// Editing state
|
||||
editing bool // true when editing a field
|
||||
editField int // index into authEditableFields
|
||||
editInput textinput.Model // text input for editing
|
||||
editFileName string // name of file being edited
|
||||
}
|
||||
|
||||
type authFilesMsg struct {
|
||||
files []map[string]any
|
||||
err error
|
||||
}
|
||||
|
||||
type authActionMsg struct {
|
||||
action string // "deleted", "toggled", "updated"
|
||||
err error
|
||||
}
|
||||
|
||||
func newAuthTabModel(client *Client) authTabModel {
|
||||
ti := textinput.New()
|
||||
ti.CharLimit = 256
|
||||
return authTabModel{
|
||||
client: client,
|
||||
expanded: -1,
|
||||
confirm: -1,
|
||||
editInput: ti,
|
||||
}
|
||||
}
|
||||
|
||||
func (m authTabModel) Init() tea.Cmd {
|
||||
return m.fetchFiles
|
||||
}
|
||||
|
||||
func (m authTabModel) fetchFiles() tea.Msg {
|
||||
files, err := m.client.GetAuthFiles()
|
||||
return authFilesMsg{files: files, err: err}
|
||||
}
|
||||
|
||||
func (m authTabModel) Update(msg tea.Msg) (authTabModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
case authFilesMsg:
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
} else {
|
||||
m.err = nil
|
||||
m.files = msg.files
|
||||
if m.cursor >= len(m.files) {
|
||||
m.cursor = max(0, len(m.files)-1)
|
||||
}
|
||||
m.status = ""
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
|
||||
case authActionMsg:
|
||||
if msg.err != nil {
|
||||
m.status = errorStyle.Render("✗ " + msg.err.Error())
|
||||
} else {
|
||||
m.status = successStyle.Render("✓ " + msg.action)
|
||||
}
|
||||
m.confirm = -1
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, m.fetchFiles
|
||||
|
||||
case tea.KeyMsg:
|
||||
// ---- Editing mode ----
|
||||
if m.editing {
|
||||
return m.handleEditInput(msg)
|
||||
}
|
||||
|
||||
// ---- Delete confirmation mode ----
|
||||
if m.confirm >= 0 {
|
||||
return m.handleConfirmInput(msg)
|
||||
}
|
||||
|
||||
// ---- Normal mode ----
|
||||
return m.handleNormalInput(msg)
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
// startEdit activates inline editing for a field on the currently selected auth file.
|
||||
func (m *authTabModel) startEdit(fieldIdx int) tea.Cmd {
|
||||
if m.cursor >= len(m.files) {
|
||||
return nil
|
||||
}
|
||||
f := m.files[m.cursor]
|
||||
m.editFileName = getString(f, "name")
|
||||
m.editField = fieldIdx
|
||||
m.editing = true
|
||||
|
||||
// Pre-populate with current value
|
||||
key := authEditableFields[fieldIdx].key
|
||||
currentVal := getAnyString(f, key)
|
||||
m.editInput.SetValue(currentVal)
|
||||
m.editInput.Focus()
|
||||
m.editInput.Prompt = fmt.Sprintf(" %s: ", authEditableFields[fieldIdx].label)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return textinput.Blink
|
||||
}
|
||||
|
||||
func (m *authTabModel) SetSize(w, h int) {
|
||||
m.width = w
|
||||
m.height = h
|
||||
m.editInput.Width = w - 20
|
||||
if !m.ready {
|
||||
m.viewport = viewport.New(w, h)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
m.ready = true
|
||||
} else {
|
||||
m.viewport.Width = w
|
||||
m.viewport.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
func (m authTabModel) View() string {
|
||||
if !m.ready {
|
||||
return T("loading")
|
||||
}
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m authTabModel) renderContent() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("auth_title")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("auth_help1")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("auth_help2")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", m.width))
|
||||
sb.WriteString("\n")
|
||||
|
||||
if m.err != nil {
|
||||
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
if len(m.files) == 0 {
|
||||
sb.WriteString(subtitleStyle.Render(T("no_auth_files")))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
for i, f := range m.files {
|
||||
name := getString(f, "name")
|
||||
channel := getString(f, "channel")
|
||||
email := getString(f, "email")
|
||||
disabled := getBool(f, "disabled")
|
||||
|
||||
statusIcon := successStyle.Render("●")
|
||||
statusText := T("status_active")
|
||||
if disabled {
|
||||
statusIcon = lipgloss.NewStyle().Foreground(colorMuted).Render("○")
|
||||
statusText = T("status_disabled")
|
||||
}
|
||||
|
||||
cursor := " "
|
||||
rowStyle := lipgloss.NewStyle()
|
||||
if i == m.cursor {
|
||||
cursor = "▸ "
|
||||
rowStyle = lipgloss.NewStyle().Bold(true)
|
||||
}
|
||||
|
||||
displayName := name
|
||||
if len(displayName) > 24 {
|
||||
displayName = displayName[:21] + "..."
|
||||
}
|
||||
displayEmail := email
|
||||
if len(displayEmail) > 28 {
|
||||
displayEmail = displayEmail[:25] + "..."
|
||||
}
|
||||
|
||||
row := fmt.Sprintf("%s%s %-24s %-12s %-28s %s",
|
||||
cursor, statusIcon, displayName, channel, displayEmail, statusText)
|
||||
sb.WriteString(rowStyle.Render(row))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Delete confirmation
|
||||
if m.confirm == i {
|
||||
sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete"), name)))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Inline edit input
|
||||
if m.editing && i == m.cursor {
|
||||
sb.WriteString(m.editInput.View())
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(" " + T("enter_save") + " • " + T("esc_cancel")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Expanded detail view
|
||||
if m.expanded == i {
|
||||
sb.WriteString(m.renderDetail(f))
|
||||
}
|
||||
}
|
||||
|
||||
if m.status != "" {
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(m.status)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m authTabModel) renderDetail(f map[string]any) string {
|
||||
var sb strings.Builder
|
||||
|
||||
labelStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("111")).
|
||||
Bold(true)
|
||||
valueStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("252"))
|
||||
editableMarker := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("214")).
|
||||
Render(" ✎")
|
||||
|
||||
sb.WriteString(" ┌─────────────────────────────────────────────\n")
|
||||
|
||||
fields := []struct {
|
||||
label string
|
||||
key string
|
||||
editable bool
|
||||
}{
|
||||
{"Name", "name", false},
|
||||
{"Channel", "channel", false},
|
||||
{"Email", "email", false},
|
||||
{"Status", "status", false},
|
||||
{"Status Msg", "status_message", false},
|
||||
{"File Name", "file_name", false},
|
||||
{"Auth Type", "auth_type", false},
|
||||
{"Prefix", "prefix", true},
|
||||
{"Proxy URL", "proxy_url", true},
|
||||
{"Priority", "priority", true},
|
||||
{"Project ID", "project_id", false},
|
||||
{"Disabled", "disabled", false},
|
||||
{"Created", "created_at", false},
|
||||
{"Updated", "updated_at", false},
|
||||
}
|
||||
|
||||
for _, field := range fields {
|
||||
val := getAnyString(f, field.key)
|
||||
if val == "" || val == "<nil>" {
|
||||
if field.editable {
|
||||
val = T("not_set")
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
editMark := ""
|
||||
if field.editable {
|
||||
editMark = editableMarker
|
||||
}
|
||||
line := fmt.Sprintf(" │ %s %s%s",
|
||||
labelStyle.Render(fmt.Sprintf("%-12s:", field.label)),
|
||||
valueStyle.Render(val),
|
||||
editMark)
|
||||
sb.WriteString(line)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString(" └─────────────────────────────────────────────\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// getAnyString converts any value to its string representation.
|
||||
func getAnyString(m map[string]any, key string) string {
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (m authTabModel) handleEditInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "enter":
|
||||
value := m.editInput.Value()
|
||||
fieldKey := authEditableFields[m.editField].key
|
||||
fileName := m.editFileName
|
||||
m.editing = false
|
||||
m.editInput.Blur()
|
||||
fields := map[string]any{}
|
||||
if fieldKey == "priority" {
|
||||
p, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return m, func() tea.Msg {
|
||||
return authActionMsg{err: fmt.Errorf("%s: %s", T("invalid_int"), value)}
|
||||
}
|
||||
}
|
||||
fields[fieldKey] = p
|
||||
} else {
|
||||
fields[fieldKey] = value
|
||||
}
|
||||
return m, func() tea.Msg {
|
||||
err := m.client.PatchAuthFileFields(fileName, fields)
|
||||
if err != nil {
|
||||
return authActionMsg{err: err}
|
||||
}
|
||||
return authActionMsg{action: fmt.Sprintf(T("updated_field"), fieldKey, fileName)}
|
||||
}
|
||||
case "esc":
|
||||
m.editing = false
|
||||
m.editInput.Blur()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
m.editInput, cmd = m.editInput.Update(msg)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
func (m authTabModel) handleConfirmInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
idx := m.confirm
|
||||
m.confirm = -1
|
||||
if idx < len(m.files) {
|
||||
name := getString(m.files[idx], "name")
|
||||
return m, func() tea.Msg {
|
||||
err := m.client.DeleteAuthFile(name)
|
||||
if err != nil {
|
||||
return authActionMsg{err: err}
|
||||
}
|
||||
return authActionMsg{action: fmt.Sprintf(T("deleted"), name)}
|
||||
}
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
case "n", "N", "esc":
|
||||
m.confirm = -1
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m authTabModel) handleNormalInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "j", "down":
|
||||
if len(m.files) > 0 {
|
||||
m.cursor = (m.cursor + 1) % len(m.files)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "k", "up":
|
||||
if len(m.files) > 0 {
|
||||
m.cursor = (m.cursor - 1 + len(m.files)) % len(m.files)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "enter", " ":
|
||||
if m.expanded == m.cursor {
|
||||
m.expanded = -1
|
||||
} else {
|
||||
m.expanded = m.cursor
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
case "d", "D":
|
||||
if m.cursor < len(m.files) {
|
||||
m.confirm = m.cursor
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "e", "E":
|
||||
if m.cursor < len(m.files) {
|
||||
f := m.files[m.cursor]
|
||||
name := getString(f, "name")
|
||||
disabled := getBool(f, "disabled")
|
||||
newDisabled := !disabled
|
||||
return m, func() tea.Msg {
|
||||
err := m.client.ToggleAuthFile(name, newDisabled)
|
||||
if err != nil {
|
||||
return authActionMsg{err: err}
|
||||
}
|
||||
action := T("enabled")
|
||||
if newDisabled {
|
||||
action = T("disabled")
|
||||
}
|
||||
return authActionMsg{action: fmt.Sprintf("%s %s", action, name)}
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
case "1":
|
||||
return m, m.startEdit(0) // prefix
|
||||
case "2":
|
||||
return m, m.startEdit(1) // proxy_url
|
||||
case "3":
|
||||
return m, m.startEdit(2) // priority
|
||||
case "r":
|
||||
m.status = ""
|
||||
return m, m.fetchFiles
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
20
internal/tui/browser.go
Normal file
20
internal/tui/browser.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// openBrowser opens the specified URL in the user's default browser.
|
||||
func openBrowser(url string) error {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return exec.Command("open", url).Start()
|
||||
case "linux":
|
||||
return exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
default:
|
||||
return exec.Command("xdg-open", url).Start()
|
||||
}
|
||||
}
|
||||
400
internal/tui/client.go
Normal file
400
internal/tui/client.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client wraps HTTP calls to the management API.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
secretKey string
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a new management API client.
|
||||
func NewClient(port int, secretKey string) *Client {
|
||||
return &Client{
|
||||
baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||
secretKey: strings.TrimSpace(secretKey),
|
||||
http: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetSecretKey updates management API bearer token used by this client.
|
||||
func (c *Client) SetSecretKey(secretKey string) {
|
||||
c.secretKey = strings.TrimSpace(secretKey)
|
||||
}
|
||||
|
||||
func (c *Client) doRequest(method, path string, body io.Reader) ([]byte, int, error) {
|
||||
url := c.baseURL + path
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if c.secretKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.secretKey)
|
||||
}
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
return data, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (c *Client) get(path string) ([]byte, error) {
|
||||
data, code, err := c.doRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if code >= 400 {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *Client) put(path string, body io.Reader) ([]byte, error) {
|
||||
data, code, err := c.doRequest("PUT", path, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if code >= 400 {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *Client) patch(path string, body io.Reader) ([]byte, error) {
|
||||
data, code, err := c.doRequest("PATCH", path, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if code >= 400 {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// getJSON fetches a path and unmarshals JSON into a generic map.
|
||||
func (c *Client) getJSON(path string) (map[string]any, error) {
|
||||
data, err := c.get(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// postJSON sends a JSON body via POST and checks for errors.
|
||||
func (c *Client) postJSON(path string, body any) error {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, code, err := c.doRequest("POST", path, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if code >= 400 {
|
||||
return fmt.Errorf("HTTP %d", code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfig fetches the parsed config.
|
||||
func (c *Client) GetConfig() (map[string]any, error) {
|
||||
return c.getJSON("/v0/management/config")
|
||||
}
|
||||
|
||||
// GetConfigYAML fetches the raw config.yaml content.
|
||||
func (c *Client) GetConfigYAML() (string, error) {
|
||||
data, err := c.get("/v0/management/config.yaml")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// PutConfigYAML uploads new config.yaml content.
|
||||
func (c *Client) PutConfigYAML(yamlContent string) error {
|
||||
_, err := c.put("/v0/management/config.yaml", strings.NewReader(yamlContent))
|
||||
return err
|
||||
}
|
||||
|
||||
// GetUsage fetches usage statistics.
|
||||
func (c *Client) GetUsage() (map[string]any, error) {
|
||||
return c.getJSON("/v0/management/usage")
|
||||
}
|
||||
|
||||
// GetAuthFiles lists auth credential files.
|
||||
// API returns {"files": [...]}.
|
||||
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
|
||||
wrapper, err := c.getJSON("/v0/management/auth-files")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return extractList(wrapper, "files")
|
||||
}
|
||||
|
||||
// DeleteAuthFile deletes a single auth file by name.
|
||||
func (c *Client) DeleteAuthFile(name string) error {
|
||||
query := url.Values{}
|
||||
query.Set("name", name)
|
||||
path := "/v0/management/auth-files?" + query.Encode()
|
||||
_, code, err := c.doRequest("DELETE", path, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if code >= 400 {
|
||||
return fmt.Errorf("delete failed (HTTP %d)", code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToggleAuthFile enables or disables an auth file.
|
||||
func (c *Client) ToggleAuthFile(name string, disabled bool) error {
|
||||
body, _ := json.Marshal(map[string]any{"name": name, "disabled": disabled})
|
||||
_, err := c.patch("/v0/management/auth-files/status", strings.NewReader(string(body)))
|
||||
return err
|
||||
}
|
||||
|
||||
// PatchAuthFileFields updates editable fields on an auth file.
|
||||
func (c *Client) PatchAuthFileFields(name string, fields map[string]any) error {
|
||||
fields["name"] = name
|
||||
body, _ := json.Marshal(fields)
|
||||
_, err := c.patch("/v0/management/auth-files/fields", strings.NewReader(string(body)))
|
||||
return err
|
||||
}
|
||||
|
||||
// GetLogs fetches log lines from the server.
|
||||
func (c *Client) GetLogs(after int64, limit int) ([]string, int64, error) {
|
||||
query := url.Values{}
|
||||
if limit > 0 {
|
||||
query.Set("limit", strconv.Itoa(limit))
|
||||
}
|
||||
if after > 0 {
|
||||
query.Set("after", strconv.FormatInt(after, 10))
|
||||
}
|
||||
|
||||
path := "/v0/management/logs"
|
||||
encodedQuery := query.Encode()
|
||||
if encodedQuery != "" {
|
||||
path += "?" + encodedQuery
|
||||
}
|
||||
|
||||
wrapper, err := c.getJSON(path)
|
||||
if err != nil {
|
||||
return nil, after, err
|
||||
}
|
||||
|
||||
lines := []string{}
|
||||
if rawLines, ok := wrapper["lines"]; ok && rawLines != nil {
|
||||
rawJSON, errMarshal := json.Marshal(rawLines)
|
||||
if errMarshal != nil {
|
||||
return nil, after, errMarshal
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(rawJSON, &lines); errUnmarshal != nil {
|
||||
return nil, after, errUnmarshal
|
||||
}
|
||||
}
|
||||
|
||||
latest := after
|
||||
if rawLatest, ok := wrapper["latest-timestamp"]; ok {
|
||||
switch value := rawLatest.(type) {
|
||||
case float64:
|
||||
latest = int64(value)
|
||||
case json.Number:
|
||||
if parsed, errParse := value.Int64(); errParse == nil {
|
||||
latest = parsed
|
||||
}
|
||||
case int64:
|
||||
latest = value
|
||||
case int:
|
||||
latest = int64(value)
|
||||
}
|
||||
}
|
||||
if latest < after {
|
||||
latest = after
|
||||
}
|
||||
|
||||
return lines, latest, nil
|
||||
}
|
||||
|
||||
// GetAPIKeys fetches the list of API keys.
|
||||
// API returns {"api-keys": [...]}.
|
||||
func (c *Client) GetAPIKeys() ([]string, error) {
|
||||
wrapper, err := c.getJSON("/v0/management/api-keys")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr, ok := wrapper["api-keys"]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
raw, err := json.Marshal(arr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []string
|
||||
if err := json.Unmarshal(raw, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AddAPIKey adds a new API key by sending old=nil, new=key which appends.
|
||||
func (c *Client) AddAPIKey(key string) error {
|
||||
body := map[string]any{"old": nil, "new": key}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
|
||||
return err
|
||||
}
|
||||
|
||||
// EditAPIKey replaces an API key at the given index.
|
||||
func (c *Client) EditAPIKey(index int, newValue string) error {
|
||||
body := map[string]any{"index": index, "value": newValue}
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteAPIKey deletes an API key by index.
|
||||
func (c *Client) DeleteAPIKey(index int) error {
|
||||
_, code, err := c.doRequest("DELETE", fmt.Sprintf("/v0/management/api-keys?index=%d", index), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if code >= 400 {
|
||||
return fmt.Errorf("delete failed (HTTP %d)", code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGeminiKeys fetches Gemini API keys.
|
||||
// API returns {"gemini-api-key": [...]}.
|
||||
func (c *Client) GetGeminiKeys() ([]map[string]any, error) {
|
||||
return c.getWrappedKeyList("/v0/management/gemini-api-key", "gemini-api-key")
|
||||
}
|
||||
|
||||
// GetClaudeKeys fetches Claude API keys.
|
||||
func (c *Client) GetClaudeKeys() ([]map[string]any, error) {
|
||||
return c.getWrappedKeyList("/v0/management/claude-api-key", "claude-api-key")
|
||||
}
|
||||
|
||||
// GetCodexKeys fetches Codex API keys.
|
||||
func (c *Client) GetCodexKeys() ([]map[string]any, error) {
|
||||
return c.getWrappedKeyList("/v0/management/codex-api-key", "codex-api-key")
|
||||
}
|
||||
|
||||
// GetVertexKeys fetches Vertex API keys.
|
||||
func (c *Client) GetVertexKeys() ([]map[string]any, error) {
|
||||
return c.getWrappedKeyList("/v0/management/vertex-api-key", "vertex-api-key")
|
||||
}
|
||||
|
||||
// GetOpenAICompat fetches OpenAI compatibility entries.
|
||||
func (c *Client) GetOpenAICompat() ([]map[string]any, error) {
|
||||
return c.getWrappedKeyList("/v0/management/openai-compatibility", "openai-compatibility")
|
||||
}
|
||||
|
||||
// getWrappedKeyList fetches a wrapped list from the API.
|
||||
func (c *Client) getWrappedKeyList(path, key string) ([]map[string]any, error) {
|
||||
wrapper, err := c.getJSON(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return extractList(wrapper, key)
|
||||
}
|
||||
|
||||
// extractList pulls an array of maps from a wrapper object by key.
|
||||
func extractList(wrapper map[string]any, key string) ([]map[string]any, error) {
|
||||
arr, ok := wrapper[key]
|
||||
if !ok || arr == nil {
|
||||
return nil, nil
|
||||
}
|
||||
raw, err := json.Marshal(arr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []map[string]any
|
||||
if err := json.Unmarshal(raw, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetDebug fetches the current debug setting.
|
||||
func (c *Client) GetDebug() (bool, error) {
|
||||
wrapper, err := c.getJSON("/v0/management/debug")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if v, ok := wrapper["debug"]; ok {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetAuthStatus polls the OAuth session status.
|
||||
// Returns status ("wait", "ok", "error") and optional error message.
|
||||
func (c *Client) GetAuthStatus(state string) (string, string, error) {
|
||||
query := url.Values{}
|
||||
query.Set("state", state)
|
||||
path := "/v0/management/get-auth-status?" + query.Encode()
|
||||
wrapper, err := c.getJSON(path)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
status := getString(wrapper, "status")
|
||||
errMsg := getString(wrapper, "error")
|
||||
return status, errMsg, nil
|
||||
}
|
||||
|
||||
// ----- Config field update methods -----
|
||||
|
||||
// PutBoolField updates a boolean config field.
|
||||
func (c *Client) PutBoolField(path string, value bool) error {
|
||||
body, _ := json.Marshal(map[string]any{"value": value})
|
||||
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
|
||||
return err
|
||||
}
|
||||
|
||||
// PutIntField updates an integer config field.
|
||||
func (c *Client) PutIntField(path string, value int) error {
|
||||
body, _ := json.Marshal(map[string]any{"value": value})
|
||||
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
|
||||
return err
|
||||
}
|
||||
|
||||
// PutStringField updates a string config field.
|
||||
func (c *Client) PutStringField(path string, value string) error {
|
||||
body, _ := json.Marshal(map[string]any{"value": value})
|
||||
_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteField sends a DELETE request for a config field.
|
||||
func (c *Client) DeleteField(path string) error {
|
||||
_, _, err := c.doRequest("DELETE", "/v0/management/"+path, nil)
|
||||
return err
|
||||
}
|
||||
413
internal/tui/config_tab.go
Normal file
413
internal/tui/config_tab.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// configField represents a single editable config field.
|
||||
type configField struct {
|
||||
label string
|
||||
apiPath string // management API path (e.g. "debug", "proxy-url")
|
||||
kind string // "bool", "int", "string", "readonly"
|
||||
value string // current display value
|
||||
rawValue any // raw value from API
|
||||
}
|
||||
|
||||
// configTabModel displays parsed config with interactive editing.
|
||||
type configTabModel struct {
|
||||
client *Client
|
||||
viewport viewport.Model
|
||||
fields []configField
|
||||
cursor int
|
||||
editing bool
|
||||
textInput textinput.Model
|
||||
err error
|
||||
message string // status message (success/error)
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
}
|
||||
|
||||
type configDataMsg struct {
|
||||
config map[string]any
|
||||
err error
|
||||
}
|
||||
|
||||
type configUpdateMsg struct {
|
||||
path string
|
||||
value any
|
||||
err error
|
||||
}
|
||||
|
||||
func newConfigTabModel(client *Client) configTabModel {
|
||||
ti := textinput.New()
|
||||
ti.CharLimit = 256
|
||||
return configTabModel{
|
||||
client: client,
|
||||
textInput: ti,
|
||||
}
|
||||
}
|
||||
|
||||
func (m configTabModel) Init() tea.Cmd {
|
||||
return m.fetchConfig
|
||||
}
|
||||
|
||||
func (m configTabModel) fetchConfig() tea.Msg {
|
||||
cfg, err := m.client.GetConfig()
|
||||
return configDataMsg{config: cfg, err: err}
|
||||
}
|
||||
|
||||
func (m configTabModel) Update(msg tea.Msg) (configTabModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
case configDataMsg:
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
m.fields = nil
|
||||
} else {
|
||||
m.err = nil
|
||||
m.fields = m.parseConfig(msg.config)
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
|
||||
case configUpdateMsg:
|
||||
if msg.err != nil {
|
||||
m.message = errorStyle.Render("✗ " + msg.err.Error())
|
||||
} else {
|
||||
m.message = successStyle.Render(T("updated_ok"))
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
// Refresh config from server
|
||||
return m, m.fetchConfig
|
||||
|
||||
case tea.KeyMsg:
|
||||
if m.editing {
|
||||
return m.handleEditingKey(msg)
|
||||
}
|
||||
return m.handleNormalKey(msg)
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m configTabModel) handleNormalKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "r":
|
||||
m.message = ""
|
||||
return m, m.fetchConfig
|
||||
case "up", "k":
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
// Ensure cursor is visible
|
||||
m.ensureCursorVisible()
|
||||
}
|
||||
return m, nil
|
||||
case "down", "j":
|
||||
if m.cursor < len(m.fields)-1 {
|
||||
m.cursor++
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
m.ensureCursorVisible()
|
||||
}
|
||||
return m, nil
|
||||
case "enter", " ":
|
||||
if m.cursor >= 0 && m.cursor < len(m.fields) {
|
||||
f := m.fields[m.cursor]
|
||||
if f.kind == "readonly" {
|
||||
return m, nil
|
||||
}
|
||||
if f.kind == "bool" {
|
||||
// Toggle directly
|
||||
return m, m.toggleBool(m.cursor)
|
||||
}
|
||||
// Start editing for int/string
|
||||
m.editing = true
|
||||
m.textInput.SetValue(configFieldEditValue(f))
|
||||
m.textInput.Focus()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, textinput.Blink
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m configTabModel) handleEditingKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "enter":
|
||||
m.editing = false
|
||||
m.textInput.Blur()
|
||||
return m, m.submitEdit(m.cursor, m.textInput.Value())
|
||||
case "esc":
|
||||
m.editing = false
|
||||
m.textInput.Blur()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
m.textInput, cmd = m.textInput.Update(msg)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
func (m configTabModel) toggleBool(idx int) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
f := m.fields[idx]
|
||||
current := f.value == "true"
|
||||
newValue := !current
|
||||
errPutBool := m.client.PutBoolField(f.apiPath, newValue)
|
||||
return configUpdateMsg{
|
||||
path: f.apiPath,
|
||||
value: newValue,
|
||||
err: errPutBool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m configTabModel) submitEdit(idx int, newValue string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
f := m.fields[idx]
|
||||
var err error
|
||||
var value any
|
||||
switch f.kind {
|
||||
case "int":
|
||||
valueInt, errAtoi := strconv.Atoi(newValue)
|
||||
if errAtoi != nil {
|
||||
return configUpdateMsg{
|
||||
path: f.apiPath,
|
||||
err: fmt.Errorf("%s: %s", T("invalid_int"), newValue),
|
||||
}
|
||||
}
|
||||
value = valueInt
|
||||
err = m.client.PutIntField(f.apiPath, valueInt)
|
||||
case "string":
|
||||
value = newValue
|
||||
err = m.client.PutStringField(f.apiPath, newValue)
|
||||
}
|
||||
return configUpdateMsg{
|
||||
path: f.apiPath,
|
||||
value: value,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func configFieldEditValue(f configField) string {
|
||||
if rawString, ok := f.rawValue.(string); ok {
|
||||
return rawString
|
||||
}
|
||||
return f.value
|
||||
}
|
||||
|
||||
func (m *configTabModel) SetSize(w, h int) {
|
||||
m.width = w
|
||||
m.height = h
|
||||
if !m.ready {
|
||||
m.viewport = viewport.New(w, h)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
m.ready = true
|
||||
} else {
|
||||
m.viewport.Width = w
|
||||
m.viewport.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
func (m *configTabModel) ensureCursorVisible() {
|
||||
// Each field takes ~1 line, header takes ~4 lines
|
||||
targetLine := m.cursor + 5
|
||||
if targetLine < m.viewport.YOffset {
|
||||
m.viewport.SetYOffset(targetLine)
|
||||
}
|
||||
if targetLine >= m.viewport.YOffset+m.viewport.Height {
|
||||
m.viewport.SetYOffset(targetLine - m.viewport.Height + 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (m configTabModel) View() string {
|
||||
if !m.ready {
|
||||
return T("loading")
|
||||
}
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m configTabModel) renderContent() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("config_title")))
|
||||
sb.WriteString("\n")
|
||||
|
||||
if m.message != "" {
|
||||
sb.WriteString(" " + m.message)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString(helpStyle.Render(T("config_help1")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("config_help2")))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if m.err != nil {
|
||||
sb.WriteString(errorStyle.Render(" ⚠ Error: " + m.err.Error()))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
if len(m.fields) == 0 {
|
||||
sb.WriteString(subtitleStyle.Render(T("no_config")))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
currentSection := ""
|
||||
for i, f := range m.fields {
|
||||
// Section headers
|
||||
section := fieldSection(f.apiPath)
|
||||
if section != currentSection {
|
||||
currentSection = section
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(" ── " + section + " "))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
isSelected := i == m.cursor
|
||||
prefix := " "
|
||||
if isSelected {
|
||||
prefix = "▸ "
|
||||
}
|
||||
|
||||
labelStr := lipgloss.NewStyle().
|
||||
Foreground(colorInfo).
|
||||
Bold(isSelected).
|
||||
Width(32).
|
||||
Render(f.label)
|
||||
|
||||
var valueStr string
|
||||
if m.editing && isSelected {
|
||||
valueStr = m.textInput.View()
|
||||
} else {
|
||||
switch f.kind {
|
||||
case "bool":
|
||||
if f.value == "true" {
|
||||
valueStr = successStyle.Render("● ON")
|
||||
} else {
|
||||
valueStr = lipgloss.NewStyle().Foreground(colorMuted).Render("○ OFF")
|
||||
}
|
||||
case "readonly":
|
||||
valueStr = lipgloss.NewStyle().Foreground(colorSubtext).Render(f.value)
|
||||
default:
|
||||
valueStr = valueStyle.Render(f.value)
|
||||
}
|
||||
}
|
||||
|
||||
line := prefix + labelStr + " " + valueStr
|
||||
if isSelected && !m.editing {
|
||||
line = lipgloss.NewStyle().Background(colorSurface).Render(line)
|
||||
}
|
||||
sb.WriteString(line + "\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m configTabModel) parseConfig(cfg map[string]any) []configField {
|
||||
var fields []configField
|
||||
|
||||
// Server settings
|
||||
fields = append(fields, configField{"Port", "port", "readonly", fmt.Sprintf("%.0f", getFloat(cfg, "port")), nil})
|
||||
fields = append(fields, configField{"Host", "host", "readonly", getString(cfg, "host"), nil})
|
||||
fields = append(fields, configField{"Debug", "debug", "bool", fmt.Sprintf("%v", getBool(cfg, "debug")), nil})
|
||||
fields = append(fields, configField{"Proxy URL", "proxy-url", "string", getString(cfg, "proxy-url"), nil})
|
||||
fields = append(fields, configField{"Request Retry", "request-retry", "int", fmt.Sprintf("%.0f", getFloat(cfg, "request-retry")), nil})
|
||||
fields = append(fields, configField{"Max Retry Interval (s)", "max-retry-interval", "int", fmt.Sprintf("%.0f", getFloat(cfg, "max-retry-interval")), nil})
|
||||
fields = append(fields, configField{"Force Model Prefix", "force-model-prefix", "string", getString(cfg, "force-model-prefix"), nil})
|
||||
|
||||
// Logging
|
||||
fields = append(fields, configField{"Logging to File", "logging-to-file", "bool", fmt.Sprintf("%v", getBool(cfg, "logging-to-file")), nil})
|
||||
fields = append(fields, configField{"Logs Max Total Size (MB)", "logs-max-total-size-mb", "int", fmt.Sprintf("%.0f", getFloat(cfg, "logs-max-total-size-mb")), nil})
|
||||
fields = append(fields, configField{"Error Logs Max Files", "error-logs-max-files", "int", fmt.Sprintf("%.0f", getFloat(cfg, "error-logs-max-files")), nil})
|
||||
fields = append(fields, configField{"Usage Stats Enabled", "usage-statistics-enabled", "bool", fmt.Sprintf("%v", getBool(cfg, "usage-statistics-enabled")), nil})
|
||||
fields = append(fields, configField{"Request Log", "request-log", "bool", fmt.Sprintf("%v", getBool(cfg, "request-log")), nil})
|
||||
|
||||
// Quota exceeded
|
||||
fields = append(fields, configField{"Switch Project on Quota", "quota-exceeded/switch-project", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-project")), nil})
|
||||
fields = append(fields, configField{"Switch Preview Model", "quota-exceeded/switch-preview-model", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-preview-model")), nil})
|
||||
|
||||
// Routing
|
||||
if routing, ok := cfg["routing"].(map[string]any); ok {
|
||||
fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", getString(routing, "strategy"), nil})
|
||||
} else {
|
||||
fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", "", nil})
|
||||
}
|
||||
|
||||
// WebSocket auth
|
||||
fields = append(fields, configField{"WebSocket Auth", "ws-auth", "bool", fmt.Sprintf("%v", getBool(cfg, "ws-auth")), nil})
|
||||
|
||||
// AMP settings
|
||||
if amp, ok := cfg["ampcode"].(map[string]any); ok {
|
||||
upstreamURL := getString(amp, "upstream-url")
|
||||
upstreamAPIKey := getString(amp, "upstream-api-key")
|
||||
fields = append(fields, configField{"AMP Upstream URL", "ampcode/upstream-url", "string", upstreamURL, upstreamURL})
|
||||
fields = append(fields, configField{"AMP Upstream API Key", "ampcode/upstream-api-key", "string", maskIfNotEmpty(upstreamAPIKey), upstreamAPIKey})
|
||||
fields = append(fields, configField{"AMP Restrict Mgmt Localhost", "ampcode/restrict-management-to-localhost", "bool", fmt.Sprintf("%v", getBool(amp, "restrict-management-to-localhost")), nil})
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func fieldSection(apiPath string) string {
|
||||
if strings.HasPrefix(apiPath, "ampcode/") {
|
||||
return T("section_ampcode")
|
||||
}
|
||||
if strings.HasPrefix(apiPath, "quota-exceeded/") {
|
||||
return T("section_quota")
|
||||
}
|
||||
if strings.HasPrefix(apiPath, "routing/") {
|
||||
return T("section_routing")
|
||||
}
|
||||
switch apiPath {
|
||||
case "port", "host", "debug", "proxy-url", "request-retry", "max-retry-interval", "force-model-prefix":
|
||||
return T("section_server")
|
||||
case "logging-to-file", "logs-max-total-size-mb", "error-logs-max-files", "usage-statistics-enabled", "request-log":
|
||||
return T("section_logging")
|
||||
case "ws-auth":
|
||||
return T("section_websocket")
|
||||
default:
|
||||
return T("section_other")
|
||||
}
|
||||
}
|
||||
|
||||
func getBoolNested(m map[string]any, keys ...string) bool {
|
||||
current := m
|
||||
for i, key := range keys {
|
||||
if i == len(keys)-1 {
|
||||
return getBool(current, key)
|
||||
}
|
||||
if nested, ok := current[key].(map[string]any); ok {
|
||||
current = nested
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func maskIfNotEmpty(s string) string {
|
||||
if s == "" {
|
||||
return T("not_set")
|
||||
}
|
||||
return maskKey(s)
|
||||
}
|
||||
360
internal/tui/dashboard.go
Normal file
360
internal/tui/dashboard.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// dashboardModel displays server info, stats cards, and config overview.
|
||||
type dashboardModel struct {
|
||||
client *Client
|
||||
viewport viewport.Model
|
||||
content string
|
||||
err error
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
|
||||
// Cached data for re-rendering on locale change
|
||||
lastConfig map[string]any
|
||||
lastUsage map[string]any
|
||||
lastAuthFiles []map[string]any
|
||||
lastAPIKeys []string
|
||||
}
|
||||
|
||||
type dashboardDataMsg struct {
|
||||
config map[string]any
|
||||
usage map[string]any
|
||||
authFiles []map[string]any
|
||||
apiKeys []string
|
||||
err error
|
||||
}
|
||||
|
||||
func newDashboardModel(client *Client) dashboardModel {
|
||||
return dashboardModel{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (m dashboardModel) Init() tea.Cmd {
|
||||
return m.fetchData
|
||||
}
|
||||
|
||||
func (m dashboardModel) fetchData() tea.Msg {
|
||||
cfg, cfgErr := m.client.GetConfig()
|
||||
usage, usageErr := m.client.GetUsage()
|
||||
authFiles, authErr := m.client.GetAuthFiles()
|
||||
apiKeys, keysErr := m.client.GetAPIKeys()
|
||||
|
||||
var err error
|
||||
for _, e := range []error{cfgErr, usageErr, authErr, keysErr} {
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
}
|
||||
return dashboardDataMsg{config: cfg, usage: usage, authFiles: authFiles, apiKeys: apiKeys, err: err}
|
||||
}
|
||||
|
||||
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
// Re-render immediately with cached data using new locale
|
||||
m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
|
||||
m.viewport.SetContent(m.content)
|
||||
// Also fetch fresh data in background
|
||||
return m, m.fetchData
|
||||
|
||||
case dashboardDataMsg:
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
m.content = errorStyle.Render("⚠ Error: " + msg.err.Error())
|
||||
} else {
|
||||
m.err = nil
|
||||
// Cache data for locale switching
|
||||
m.lastConfig = msg.config
|
||||
m.lastUsage = msg.usage
|
||||
m.lastAuthFiles = msg.authFiles
|
||||
m.lastAPIKeys = msg.apiKeys
|
||||
|
||||
m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
|
||||
}
|
||||
m.viewport.SetContent(m.content)
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
if msg.String() == "r" {
|
||||
return m, m.fetchData
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *dashboardModel) SetSize(w, h int) {
|
||||
m.width = w
|
||||
m.height = h
|
||||
if !m.ready {
|
||||
m.viewport = viewport.New(w, h)
|
||||
m.viewport.SetContent(m.content)
|
||||
m.ready = true
|
||||
} else {
|
||||
m.viewport.Width = w
|
||||
m.viewport.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
func (m dashboardModel) View() string {
|
||||
if !m.ready {
|
||||
return T("loading")
|
||||
}
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("dashboard_title")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("dashboard_help")))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// ━━━ Connection Status ━━━
|
||||
connStyle := lipgloss.NewStyle().Bold(true).Foreground(colorSuccess)
|
||||
sb.WriteString(connStyle.Render(T("connected")))
|
||||
sb.WriteString(fmt.Sprintf(" %s", m.client.baseURL))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// ━━━ Stats Cards ━━━
|
||||
cardWidth := 25
|
||||
if m.width > 0 {
|
||||
cardWidth = (m.width - 6) / 4
|
||||
if cardWidth < 18 {
|
||||
cardWidth = 18
|
||||
}
|
||||
}
|
||||
|
||||
cardStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("240")).
|
||||
Padding(0, 1).
|
||||
Width(cardWidth).
|
||||
Height(2)
|
||||
|
||||
// Card 1: API Keys
|
||||
keyCount := len(apiKeys)
|
||||
card1 := cardStyle.Render(fmt.Sprintf(
|
||||
"%s\n%s",
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("🔑 %d", keyCount)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("mgmt_keys")),
|
||||
))
|
||||
|
||||
// Card 2: Auth Files
|
||||
authCount := len(authFiles)
|
||||
activeAuth := 0
|
||||
for _, f := range authFiles {
|
||||
if !getBool(f, "disabled") {
|
||||
activeAuth++
|
||||
}
|
||||
}
|
||||
card2 := cardStyle.Render(fmt.Sprintf(
|
||||
"%s\n%s",
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("📄 %d", authCount)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
|
||||
))
|
||||
|
||||
// Card 3: Total Requests
|
||||
totalReqs := int64(0)
|
||||
successReqs := int64(0)
|
||||
failedReqs := int64(0)
|
||||
totalTokens := int64(0)
|
||||
if usage != nil {
|
||||
if usageMap, ok := usage["usage"].(map[string]any); ok {
|
||||
totalReqs = int64(getFloat(usageMap, "total_requests"))
|
||||
successReqs = int64(getFloat(usageMap, "success_count"))
|
||||
failedReqs = int64(getFloat(usageMap, "failure_count"))
|
||||
totalTokens = int64(getFloat(usageMap, "total_tokens"))
|
||||
}
|
||||
}
|
||||
card3 := cardStyle.Render(fmt.Sprintf(
|
||||
"%s\n%s",
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
|
||||
))
|
||||
|
||||
// Card 4: Total Tokens
|
||||
tokenStr := formatLargeNumber(totalTokens)
|
||||
card4 := cardStyle.Render(fmt.Sprintf(
|
||||
"%s\n%s",
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
|
||||
))
|
||||
|
||||
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// ━━━ Current Config ━━━
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("current_config")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
|
||||
if cfg != nil {
|
||||
debug := getBool(cfg, "debug")
|
||||
retry := getFloat(cfg, "request-retry")
|
||||
proxyURL := getString(cfg, "proxy-url")
|
||||
loggingToFile := getBool(cfg, "logging-to-file")
|
||||
usageEnabled := true
|
||||
if v, ok := cfg["usage-statistics-enabled"]; ok {
|
||||
if b, ok2 := v.(bool); ok2 {
|
||||
usageEnabled = b
|
||||
}
|
||||
}
|
||||
|
||||
configItems := []struct {
|
||||
label string
|
||||
value string
|
||||
}{
|
||||
{T("debug_mode"), boolEmoji(debug)},
|
||||
{T("usage_stats"), boolEmoji(usageEnabled)},
|
||||
{T("log_to_file"), boolEmoji(loggingToFile)},
|
||||
{T("retry_count"), fmt.Sprintf("%.0f", retry)},
|
||||
}
|
||||
if proxyURL != "" {
|
||||
configItems = append(configItems, struct {
|
||||
label string
|
||||
value string
|
||||
}{T("proxy_url"), proxyURL})
|
||||
}
|
||||
|
||||
// Render config items as a compact row
|
||||
for _, item := range configItems {
|
||||
sb.WriteString(fmt.Sprintf(" %s %s\n",
|
||||
labelStyle.Render(item.label+":"),
|
||||
valueStyle.Render(item.value)))
|
||||
}
|
||||
|
||||
// Routing strategy
|
||||
strategy := "round-robin"
|
||||
if routing, ok := cfg["routing"].(map[string]any); ok {
|
||||
if s := getString(routing, "strategy"); s != "" {
|
||||
strategy = s
|
||||
}
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %s %s\n",
|
||||
labelStyle.Render(T("routing_strategy")+":"),
|
||||
valueStyle.Render(strategy)))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// ━━━ Per-Model Usage ━━━
|
||||
if usage != nil {
|
||||
if usageMap, ok := usage["usage"].(map[string]any); ok {
|
||||
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
|
||||
header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
|
||||
sb.WriteString(tableHeaderStyle.Render(header))
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, apiSnap := range apis {
|
||||
if apiMap, ok := apiSnap.(map[string]any); ok {
|
||||
if models, ok := apiMap["models"].(map[string]any); ok {
|
||||
for model, v := range models {
|
||||
if stats, ok := v.(map[string]any); ok {
|
||||
reqs := int64(getFloat(stats, "total_requests"))
|
||||
toks := int64(getFloat(stats, "total_tokens"))
|
||||
row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
|
||||
sb.WriteString(tableCellStyle.Render(row))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func formatKV(key, value string) string {
|
||||
return fmt.Sprintf(" %s %s\n", labelStyle.Render(key+":"), valueStyle.Render(value))
|
||||
}
|
||||
|
||||
func getString(m map[string]any, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getFloat(m map[string]any, key string) float64 {
|
||||
if v, ok := m[key]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return n
|
||||
case json.Number:
|
||||
f, _ := n.Float64()
|
||||
return f
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getBool(m map[string]any, key string) bool {
|
||||
if v, ok := m[key]; ok {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func boolEmoji(b bool) string {
|
||||
if b {
|
||||
return T("bool_yes")
|
||||
}
|
||||
return T("bool_no")
|
||||
}
|
||||
|
||||
func formatLargeNumber(n int64) string {
|
||||
if n >= 1_000_000 {
|
||||
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
|
||||
}
|
||||
if n >= 1_000 {
|
||||
return fmt.Sprintf("%.1fK", float64(n)/1_000)
|
||||
}
|
||||
return fmt.Sprintf("%d", n)
|
||||
}
|
||||
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) > maxLen {
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
364
internal/tui/i18n.go
Normal file
364
internal/tui/i18n.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package tui
|
||||
|
||||
// i18n provides a simple internationalization system for the TUI.
|
||||
// Supported locales: "zh" (Chinese, default), "en" (English).
|
||||
|
||||
var currentLocale = "en"
|
||||
|
||||
// SetLocale changes the active locale.
|
||||
func SetLocale(locale string) {
|
||||
if _, ok := locales[locale]; ok {
|
||||
currentLocale = locale
|
||||
}
|
||||
}
|
||||
|
||||
// CurrentLocale returns the active locale code.
|
||||
func CurrentLocale() string {
|
||||
return currentLocale
|
||||
}
|
||||
|
||||
// ToggleLocale switches between zh and en.
|
||||
func ToggleLocale() {
|
||||
if currentLocale == "zh" {
|
||||
currentLocale = "en"
|
||||
} else {
|
||||
currentLocale = "zh"
|
||||
}
|
||||
}
|
||||
|
||||
// T returns the translated string for the given key.
|
||||
func T(key string) string {
|
||||
if m, ok := locales[currentLocale]; ok {
|
||||
if v, ok := m[key]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
// Fallback to English
|
||||
if m, ok := locales["en"]; ok {
|
||||
if v, ok := m[key]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
var locales = map[string]map[string]string{
|
||||
"zh": zhStrings,
|
||||
"en": enStrings,
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────
|
||||
// Tab names
|
||||
// ──────────────────────────────────────────
|
||||
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
|
||||
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
|
||||
|
||||
// TabNames returns tab names in the current locale.
|
||||
func TabNames() []string {
|
||||
if currentLocale == "zh" {
|
||||
return zhTabNames
|
||||
}
|
||||
return enTabNames
|
||||
}
|
||||
|
||||
var zhStrings = map[string]string{
|
||||
// ── Common ──
|
||||
"loading": "加载中...",
|
||||
"refresh": "刷新",
|
||||
"save": "保存",
|
||||
"cancel": "取消",
|
||||
"confirm": "确认",
|
||||
"yes": "是",
|
||||
"no": "否",
|
||||
"error": "错误",
|
||||
"success": "成功",
|
||||
"navigate": "导航",
|
||||
"scroll": "滚动",
|
||||
"enter_save": "Enter: 保存",
|
||||
"esc_cancel": "Esc: 取消",
|
||||
"enter_submit": "Enter: 提交",
|
||||
"press_r": "[r] 刷新",
|
||||
"press_scroll": "[↑↓] 滚动",
|
||||
"not_set": "(未设置)",
|
||||
"error_prefix": "⚠ 错误: ",
|
||||
|
||||
// ── Status bar ──
|
||||
"status_left": " CLIProxyAPI 管理终端",
|
||||
"status_right": "Tab/Shift+Tab: 切换 • L: 语言 • q/Ctrl+C: 退出 ",
|
||||
"initializing_tui": "正在初始化...",
|
||||
"auth_gate_title": "🔐 连接管理 API",
|
||||
"auth_gate_help": " 请输入管理密码并按 Enter 连接",
|
||||
"auth_gate_password": "密码",
|
||||
"auth_gate_enter": " Enter: 连接 • q/Ctrl+C: 退出 • L: 语言",
|
||||
"auth_gate_connecting": "正在连接...",
|
||||
"auth_gate_connect_fail": "连接失败:%s",
|
||||
"auth_gate_password_required": "请输入密码",
|
||||
|
||||
// ── Dashboard ──
|
||||
"dashboard_title": "📊 仪表盘",
|
||||
"dashboard_help": " [r] 刷新 • [↑↓] 滚动",
|
||||
"connected": "● 已连接",
|
||||
"mgmt_keys": "管理密钥",
|
||||
"auth_files_label": "认证文件",
|
||||
"active_suffix": "活跃",
|
||||
"total_requests": "请求",
|
||||
"success_label": "成功",
|
||||
"failure_label": "失败",
|
||||
"total_tokens": "总 Tokens",
|
||||
"current_config": "当前配置",
|
||||
"debug_mode": "启用调试模式",
|
||||
"usage_stats": "启用使用统计",
|
||||
"log_to_file": "启用日志记录到文件",
|
||||
"retry_count": "重试次数",
|
||||
"proxy_url": "代理 URL",
|
||||
"routing_strategy": "路由策略",
|
||||
"model_stats": "模型统计",
|
||||
"model": "模型",
|
||||
"requests": "请求数",
|
||||
"tokens": "Tokens",
|
||||
"bool_yes": "是 ✓",
|
||||
"bool_no": "否",
|
||||
|
||||
// ── Config ──
|
||||
"config_title": "⚙ 配置",
|
||||
"config_help1": " [↑↓/jk] 导航 • [Enter/Space] 编辑 • [r] 刷新",
|
||||
"config_help2": " 布尔: Enter 切换 • 文本/数字: Enter 输入, Enter 确认, Esc 取消",
|
||||
"updated_ok": "✓ 更新成功",
|
||||
"no_config": " 未加载配置",
|
||||
"invalid_int": "无效整数",
|
||||
"section_server": "服务器",
|
||||
"section_logging": "日志与统计",
|
||||
"section_quota": "配额超限处理",
|
||||
"section_routing": "路由",
|
||||
"section_websocket": "WebSocket",
|
||||
"section_ampcode": "AMP Code",
|
||||
"section_other": "其他",
|
||||
|
||||
// ── Auth Files ──
|
||||
"auth_title": "🔑 认证文件",
|
||||
"auth_help1": " [↑↓/jk] 导航 • [Enter] 展开 • [e] 启用/停用 • [d] 删除 • [r] 刷新",
|
||||
"auth_help2": " [1] 编辑 prefix • [2] 编辑 proxy_url • [3] 编辑 priority",
|
||||
"no_auth_files": " 无认证文件",
|
||||
"confirm_delete": "⚠ 删除 %s? [y/n]",
|
||||
"deleted": "已删除 %s",
|
||||
"enabled": "已启用",
|
||||
"disabled": "已停用",
|
||||
"updated_field": "已更新 %s 的 %s",
|
||||
"status_active": "活跃",
|
||||
"status_disabled": "已停用",
|
||||
|
||||
// ── API Keys ──
|
||||
"keys_title": "🔐 API 密钥",
|
||||
"keys_help": " [↑↓/jk] 导航 • [a] 添加 • [e] 编辑 • [d] 删除 • [c] 复制 • [r] 刷新",
|
||||
"no_keys": " 无 API Key,按 [a] 添加",
|
||||
"access_keys": "Access API Keys",
|
||||
"confirm_delete_key": "⚠ 确认删除 %s? [y/n]",
|
||||
"key_added": "已添加 API Key",
|
||||
"key_updated": "已更新 API Key",
|
||||
"key_deleted": "已删除 API Key",
|
||||
"copied": "✓ 已复制到剪贴板",
|
||||
"copy_failed": "✗ 复制失败",
|
||||
"new_key_prompt": " New Key: ",
|
||||
"edit_key_prompt": " Edit Key: ",
|
||||
"enter_add": " Enter: 添加 • Esc: 取消",
|
||||
"enter_save_esc": " Enter: 保存 • Esc: 取消",
|
||||
|
||||
// ── OAuth ──
|
||||
"oauth_title": "🔐 OAuth 登录",
|
||||
"oauth_select": " 选择提供商并按 [Enter] 开始 OAuth 登录:",
|
||||
"oauth_help": " [↑↓/jk] 导航 • [Enter] 登录 • [Esc] 清除状态",
|
||||
"oauth_initiating": "⏳ 正在初始化 %s 登录...",
|
||||
"oauth_success": "认证成功! 请刷新 Auth Files 标签查看新凭证。",
|
||||
"oauth_completed": "认证流程已完成。",
|
||||
"oauth_failed": "认证失败",
|
||||
"oauth_timeout": "OAuth 流程超时 (5 分钟)",
|
||||
"oauth_press_esc": " 按 [Esc] 取消",
|
||||
"oauth_auth_url": " 授权链接:",
|
||||
"oauth_remote_hint": " 远程浏览器模式:在浏览器中打开上述链接完成授权后,将回调 URL 粘贴到下方。",
|
||||
"oauth_callback_url": " 回调 URL:",
|
||||
"oauth_press_c": " 按 [c] 输入回调 URL • [Esc] 返回",
|
||||
"oauth_submitting": "⏳ 提交回调中...",
|
||||
"oauth_submit_ok": "✓ 回调已提交,等待处理...",
|
||||
"oauth_submit_fail": "✗ 提交回调失败",
|
||||
"oauth_waiting": " 等待认证中...",
|
||||
|
||||
// ── Usage ──
|
||||
"usage_title": "📈 使用统计",
|
||||
"usage_help": " [r] 刷新 • [↑↓] 滚动",
|
||||
"usage_no_data": " 使用数据不可用",
|
||||
"usage_total_reqs": "总请求数",
|
||||
"usage_total_tokens": "总 Token 数",
|
||||
"usage_success": "成功",
|
||||
"usage_failure": "失败",
|
||||
"usage_total_token_l": "总Token",
|
||||
"usage_rpm": "RPM",
|
||||
"usage_tpm": "TPM",
|
||||
"usage_req_by_hour": "请求趋势 (按小时)",
|
||||
"usage_tok_by_hour": "Token 使用趋势 (按小时)",
|
||||
"usage_req_by_day": "请求趋势 (按天)",
|
||||
"usage_api_detail": "API 详细统计",
|
||||
"usage_input": "输入",
|
||||
"usage_output": "输出",
|
||||
"usage_cached": "缓存",
|
||||
"usage_reasoning": "思考",
|
||||
|
||||
// ── Logs ──
|
||||
"logs_title": "📋 日志",
|
||||
"logs_auto_scroll": "● 自动滚动",
|
||||
"logs_paused": "○ 已暂停",
|
||||
"logs_filter": "过滤",
|
||||
"logs_lines": "行数",
|
||||
"logs_help": " [a] 自动滚动 • [c] 清除 • [1] 全部 [2] info+ [3] warn+ [4] error • [↑↓] 滚动",
|
||||
"logs_waiting": " 等待日志输出...",
|
||||
}
|
||||
|
||||
var enStrings = map[string]string{
|
||||
// ── Common ──
|
||||
"loading": "Loading...",
|
||||
"refresh": "Refresh",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel",
|
||||
"confirm": "Confirm",
|
||||
"yes": "Yes",
|
||||
"no": "No",
|
||||
"error": "Error",
|
||||
"success": "Success",
|
||||
"navigate": "Navigate",
|
||||
"scroll": "Scroll",
|
||||
"enter_save": "Enter: Save",
|
||||
"esc_cancel": "Esc: Cancel",
|
||||
"enter_submit": "Enter: Submit",
|
||||
"press_r": "[r] Refresh",
|
||||
"press_scroll": "[↑↓] Scroll",
|
||||
"not_set": "(not set)",
|
||||
"error_prefix": "⚠ Error: ",
|
||||
|
||||
// ── Status bar ──
|
||||
"status_left": " CLIProxyAPI Management TUI",
|
||||
"status_right": "Tab/Shift+Tab: switch • L: lang • q/Ctrl+C: quit ",
|
||||
"initializing_tui": "Initializing...",
|
||||
"auth_gate_title": "🔐 Connect Management API",
|
||||
"auth_gate_help": " Enter management password and press Enter to connect",
|
||||
"auth_gate_password": "Password",
|
||||
"auth_gate_enter": " Enter: connect • q/Ctrl+C: quit • L: lang",
|
||||
"auth_gate_connecting": "Connecting...",
|
||||
"auth_gate_connect_fail": "Connection failed: %s",
|
||||
"auth_gate_password_required": "password is required",
|
||||
|
||||
// ── Dashboard ──
|
||||
"dashboard_title": "📊 Dashboard",
|
||||
"dashboard_help": " [r] Refresh • [↑↓] Scroll",
|
||||
"connected": "● Connected",
|
||||
"mgmt_keys": "Mgmt Keys",
|
||||
"auth_files_label": "Auth Files",
|
||||
"active_suffix": "active",
|
||||
"total_requests": "Requests",
|
||||
"success_label": "Success",
|
||||
"failure_label": "Failed",
|
||||
"total_tokens": "Total Tokens",
|
||||
"current_config": "Current Config",
|
||||
"debug_mode": "Debug Mode",
|
||||
"usage_stats": "Usage Statistics",
|
||||
"log_to_file": "Log to File",
|
||||
"retry_count": "Retry Count",
|
||||
"proxy_url": "Proxy URL",
|
||||
"routing_strategy": "Routing Strategy",
|
||||
"model_stats": "Model Stats",
|
||||
"model": "Model",
|
||||
"requests": "Requests",
|
||||
"tokens": "Tokens",
|
||||
"bool_yes": "Yes ✓",
|
||||
"bool_no": "No",
|
||||
|
||||
// ── Config ──
|
||||
"config_title": "⚙ Configuration",
|
||||
"config_help1": " [↑↓/jk] Navigate • [Enter/Space] Edit • [r] Refresh",
|
||||
"config_help2": " Bool: Enter to toggle • String/Int: Enter to type, Enter to confirm, Esc to cancel",
|
||||
"updated_ok": "✓ Updated successfully",
|
||||
"no_config": " No configuration loaded",
|
||||
"invalid_int": "invalid integer",
|
||||
"section_server": "Server",
|
||||
"section_logging": "Logging & Stats",
|
||||
"section_quota": "Quota Exceeded Handling",
|
||||
"section_routing": "Routing",
|
||||
"section_websocket": "WebSocket",
|
||||
"section_ampcode": "AMP Code",
|
||||
"section_other": "Other",
|
||||
|
||||
// ── Auth Files ──
|
||||
"auth_title": "🔑 Auth Files",
|
||||
"auth_help1": " [↑↓/jk] Navigate • [Enter] Expand • [e] Enable/Disable • [d] Delete • [r] Refresh",
|
||||
"auth_help2": " [1] Edit prefix • [2] Edit proxy_url • [3] Edit priority",
|
||||
"no_auth_files": " No auth files found",
|
||||
"confirm_delete": "⚠ Delete %s? [y/n]",
|
||||
"deleted": "Deleted %s",
|
||||
"enabled": "Enabled",
|
||||
"disabled": "Disabled",
|
||||
"updated_field": "Updated %s on %s",
|
||||
"status_active": "active",
|
||||
"status_disabled": "disabled",
|
||||
|
||||
// ── API Keys ──
|
||||
"keys_title": "🔐 API Keys",
|
||||
"keys_help": " [↑↓/jk] Navigate • [a] Add • [e] Edit • [d] Delete • [c] Copy • [r] Refresh",
|
||||
"no_keys": " No API Keys. Press [a] to add",
|
||||
"access_keys": "Access API Keys",
|
||||
"confirm_delete_key": "⚠ Delete %s? [y/n]",
|
||||
"key_added": "API Key added",
|
||||
"key_updated": "API Key updated",
|
||||
"key_deleted": "API Key deleted",
|
||||
"copied": "✓ Copied to clipboard",
|
||||
"copy_failed": "✗ Copy failed",
|
||||
"new_key_prompt": " New Key: ",
|
||||
"edit_key_prompt": " Edit Key: ",
|
||||
"enter_add": " Enter: Add • Esc: Cancel",
|
||||
"enter_save_esc": " Enter: Save • Esc: Cancel",
|
||||
|
||||
// ── OAuth ──
|
||||
"oauth_title": "🔐 OAuth Login",
|
||||
"oauth_select": " Select a provider and press [Enter] to start OAuth login:",
|
||||
"oauth_help": " [↑↓/jk] Navigate • [Enter] Login • [Esc] Clear status",
|
||||
"oauth_initiating": "⏳ Initiating %s login...",
|
||||
"oauth_success": "Authentication successful! Refresh Auth Files tab to see the new credential.",
|
||||
"oauth_completed": "Authentication flow completed.",
|
||||
"oauth_failed": "Authentication failed",
|
||||
"oauth_timeout": "OAuth flow timed out (5 minutes)",
|
||||
"oauth_press_esc": " Press [Esc] to cancel",
|
||||
"oauth_auth_url": " Authorization URL:",
|
||||
"oauth_remote_hint": " Remote browser mode: Open the URL above in browser, paste the callback URL below after authorization.",
|
||||
"oauth_callback_url": " Callback URL:",
|
||||
"oauth_press_c": " Press [c] to enter callback URL • [Esc] to go back",
|
||||
"oauth_submitting": "⏳ Submitting callback...",
|
||||
"oauth_submit_ok": "✓ Callback submitted, waiting...",
|
||||
"oauth_submit_fail": "✗ Callback submission failed",
|
||||
"oauth_waiting": " Waiting for authentication...",
|
||||
|
||||
// ── Usage ──
|
||||
"usage_title": "📈 Usage Statistics",
|
||||
"usage_help": " [r] Refresh • [↑↓] Scroll",
|
||||
"usage_no_data": " Usage data not available",
|
||||
"usage_total_reqs": "Total Requests",
|
||||
"usage_total_tokens": "Total Tokens",
|
||||
"usage_success": "Success",
|
||||
"usage_failure": "Failed",
|
||||
"usage_total_token_l": "Total Tokens",
|
||||
"usage_rpm": "RPM",
|
||||
"usage_tpm": "TPM",
|
||||
"usage_req_by_hour": "Requests by Hour",
|
||||
"usage_tok_by_hour": "Token Usage by Hour",
|
||||
"usage_req_by_day": "Requests by Day",
|
||||
"usage_api_detail": "API Detail Statistics",
|
||||
"usage_input": "Input",
|
||||
"usage_output": "Output",
|
||||
"usage_cached": "Cached",
|
||||
"usage_reasoning": "Reasoning",
|
||||
|
||||
// ── Logs ──
|
||||
"logs_title": "📋 Logs",
|
||||
"logs_auto_scroll": "● AUTO-SCROLL",
|
||||
"logs_paused": "○ PAUSED",
|
||||
"logs_filter": "Filter",
|
||||
"logs_lines": "Lines",
|
||||
"logs_help": " [a] Auto-scroll • [c] Clear • [1] All [2] info+ [3] warn+ [4] error • [↑↓] Scroll",
|
||||
"logs_waiting": " Waiting for log output...",
|
||||
}
|
||||
405
internal/tui/keys_tab.go
Normal file
405
internal/tui/keys_tab.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/atotto/clipboard"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// keysTabModel displays and manages API keys.
|
||||
type keysTabModel struct {
|
||||
client *Client
|
||||
viewport viewport.Model
|
||||
keys []string
|
||||
gemini []map[string]any
|
||||
claude []map[string]any
|
||||
codex []map[string]any
|
||||
vertex []map[string]any
|
||||
openai []map[string]any
|
||||
err error
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
cursor int
|
||||
confirm int // -1 = no deletion pending
|
||||
status string
|
||||
|
||||
// Editing / Adding
|
||||
editing bool
|
||||
adding bool
|
||||
editIdx int
|
||||
editInput textinput.Model
|
||||
}
|
||||
|
||||
type keysDataMsg struct {
|
||||
apiKeys []string
|
||||
gemini []map[string]any
|
||||
claude []map[string]any
|
||||
codex []map[string]any
|
||||
vertex []map[string]any
|
||||
openai []map[string]any
|
||||
err error
|
||||
}
|
||||
|
||||
type keyActionMsg struct {
|
||||
action string
|
||||
err error
|
||||
}
|
||||
|
||||
func newKeysTabModel(client *Client) keysTabModel {
|
||||
ti := textinput.New()
|
||||
ti.CharLimit = 512
|
||||
ti.Prompt = " Key: "
|
||||
return keysTabModel{
|
||||
client: client,
|
||||
confirm: -1,
|
||||
editInput: ti,
|
||||
}
|
||||
}
|
||||
|
||||
func (m keysTabModel) Init() tea.Cmd {
|
||||
return m.fetchKeys
|
||||
}
|
||||
|
||||
func (m keysTabModel) fetchKeys() tea.Msg {
|
||||
result := keysDataMsg{}
|
||||
apiKeys, err := m.client.GetAPIKeys()
|
||||
if err != nil {
|
||||
result.err = err
|
||||
return result
|
||||
}
|
||||
result.apiKeys = apiKeys
|
||||
result.gemini, _ = m.client.GetGeminiKeys()
|
||||
result.claude, _ = m.client.GetClaudeKeys()
|
||||
result.codex, _ = m.client.GetCodexKeys()
|
||||
result.vertex, _ = m.client.GetVertexKeys()
|
||||
result.openai, _ = m.client.GetOpenAICompat()
|
||||
return result
|
||||
}
|
||||
|
||||
func (m keysTabModel) Update(msg tea.Msg) (keysTabModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
case keysDataMsg:
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
} else {
|
||||
m.err = nil
|
||||
m.keys = msg.apiKeys
|
||||
m.gemini = msg.gemini
|
||||
m.claude = msg.claude
|
||||
m.codex = msg.codex
|
||||
m.vertex = msg.vertex
|
||||
m.openai = msg.openai
|
||||
if m.cursor >= len(m.keys) {
|
||||
m.cursor = max(0, len(m.keys)-1)
|
||||
}
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
|
||||
case keyActionMsg:
|
||||
if msg.err != nil {
|
||||
m.status = errorStyle.Render("✗ " + msg.err.Error())
|
||||
} else {
|
||||
m.status = successStyle.Render("✓ " + msg.action)
|
||||
}
|
||||
m.confirm = -1
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, m.fetchKeys
|
||||
|
||||
case tea.KeyMsg:
|
||||
// ---- Editing / Adding mode ----
|
||||
if m.editing || m.adding {
|
||||
switch msg.String() {
|
||||
case "enter":
|
||||
value := strings.TrimSpace(m.editInput.Value())
|
||||
if value == "" {
|
||||
m.editing = false
|
||||
m.adding = false
|
||||
m.editInput.Blur()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
}
|
||||
isAdding := m.adding
|
||||
editIdx := m.editIdx
|
||||
m.editing = false
|
||||
m.adding = false
|
||||
m.editInput.Blur()
|
||||
if isAdding {
|
||||
return m, func() tea.Msg {
|
||||
err := m.client.AddAPIKey(value)
|
||||
if err != nil {
|
||||
return keyActionMsg{err: err}
|
||||
}
|
||||
return keyActionMsg{action: T("key_added")}
|
||||
}
|
||||
}
|
||||
return m, func() tea.Msg {
|
||||
err := m.client.EditAPIKey(editIdx, value)
|
||||
if err != nil {
|
||||
return keyActionMsg{err: err}
|
||||
}
|
||||
return keyActionMsg{action: T("key_updated")}
|
||||
}
|
||||
case "esc":
|
||||
m.editing = false
|
||||
m.adding = false
|
||||
m.editInput.Blur()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
m.editInput, cmd = m.editInput.Update(msg)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Delete confirmation ----
|
||||
if m.confirm >= 0 {
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
idx := m.confirm
|
||||
m.confirm = -1
|
||||
return m, func() tea.Msg {
|
||||
err := m.client.DeleteAPIKey(idx)
|
||||
if err != nil {
|
||||
return keyActionMsg{err: err}
|
||||
}
|
||||
return keyActionMsg{action: T("key_deleted")}
|
||||
}
|
||||
case "n", "N", "esc":
|
||||
m.confirm = -1
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ---- Normal mode ----
|
||||
switch msg.String() {
|
||||
case "j", "down":
|
||||
if len(m.keys) > 0 {
|
||||
m.cursor = (m.cursor + 1) % len(m.keys)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "k", "up":
|
||||
if len(m.keys) > 0 {
|
||||
m.cursor = (m.cursor - 1 + len(m.keys)) % len(m.keys)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "a":
|
||||
// Add new key
|
||||
m.adding = true
|
||||
m.editing = false
|
||||
m.editInput.SetValue("")
|
||||
m.editInput.Prompt = T("new_key_prompt")
|
||||
m.editInput.Focus()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, textinput.Blink
|
||||
case "e":
|
||||
// Edit selected key
|
||||
if m.cursor < len(m.keys) {
|
||||
m.editing = true
|
||||
m.adding = false
|
||||
m.editIdx = m.cursor
|
||||
m.editInput.SetValue(m.keys[m.cursor])
|
||||
m.editInput.Prompt = T("edit_key_prompt")
|
||||
m.editInput.Focus()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, textinput.Blink
|
||||
}
|
||||
return m, nil
|
||||
case "d":
|
||||
// Delete selected key
|
||||
if m.cursor < len(m.keys) {
|
||||
m.confirm = m.cursor
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "c":
|
||||
// Copy selected key to clipboard
|
||||
if m.cursor < len(m.keys) {
|
||||
key := m.keys[m.cursor]
|
||||
if err := clipboard.WriteAll(key); err != nil {
|
||||
m.status = errorStyle.Render(T("copy_failed") + ": " + err.Error())
|
||||
} else {
|
||||
m.status = successStyle.Render(T("copied"))
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "r":
|
||||
m.status = ""
|
||||
return m, m.fetchKeys
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *keysTabModel) SetSize(w, h int) {
|
||||
m.width = w
|
||||
m.height = h
|
||||
m.editInput.Width = w - 16
|
||||
if !m.ready {
|
||||
m.viewport = viewport.New(w, h)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
m.ready = true
|
||||
} else {
|
||||
m.viewport.Width = w
|
||||
m.viewport.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
func (m keysTabModel) View() string {
|
||||
if !m.ready {
|
||||
return T("loading")
|
||||
}
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m keysTabModel) renderContent() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("keys_title")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("keys_help")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", m.width))
|
||||
sb.WriteString("\n")
|
||||
|
||||
if m.err != nil {
|
||||
sb.WriteString(errorStyle.Render(T("error_prefix") + m.err.Error()))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ━━━ Access API Keys (interactive) ━━━
|
||||
sb.WriteString(tableHeaderStyle.Render(fmt.Sprintf(" %s (%d)", T("access_keys"), len(m.keys))))
|
||||
sb.WriteString("\n")
|
||||
|
||||
if len(m.keys) == 0 {
|
||||
sb.WriteString(subtitleStyle.Render(T("no_keys")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
for i, key := range m.keys {
|
||||
cursor := " "
|
||||
rowStyle := lipgloss.NewStyle()
|
||||
if i == m.cursor {
|
||||
cursor = "▸ "
|
||||
rowStyle = lipgloss.NewStyle().Bold(true)
|
||||
}
|
||||
|
||||
row := fmt.Sprintf("%s%d. %s", cursor, i+1, maskKey(key))
|
||||
sb.WriteString(rowStyle.Render(row))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Delete confirmation
|
||||
if m.confirm == i {
|
||||
sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete_key"), maskKey(key))))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Edit input
|
||||
if m.editing && m.editIdx == i {
|
||||
sb.WriteString(m.editInput.View())
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("enter_save_esc")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Add input
|
||||
if m.adding {
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(m.editInput.View())
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("enter_add")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// ━━━ Provider Keys (read-only display) ━━━
|
||||
renderProviderKeys(&sb, "Gemini API Keys", m.gemini)
|
||||
renderProviderKeys(&sb, "Claude API Keys", m.claude)
|
||||
renderProviderKeys(&sb, "Codex API Keys", m.codex)
|
||||
renderProviderKeys(&sb, "Vertex API Keys", m.vertex)
|
||||
|
||||
if len(m.openai) > 0 {
|
||||
renderSection(&sb, "OpenAI Compatibility", len(m.openai))
|
||||
for i, entry := range m.openai {
|
||||
name := getString(entry, "name")
|
||||
baseURL := getString(entry, "base-url")
|
||||
prefix := getString(entry, "prefix")
|
||||
info := name
|
||||
if prefix != "" {
|
||||
info += " (prefix: " + prefix + ")"
|
||||
}
|
||||
if baseURL != "" {
|
||||
info += " → " + baseURL
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, info))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
if m.status != "" {
|
||||
sb.WriteString(m.status)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func renderSection(sb *strings.Builder, title string, count int) {
|
||||
header := fmt.Sprintf("%s (%d)", title, count)
|
||||
sb.WriteString(tableHeaderStyle.Render(" " + header))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
func renderProviderKeys(sb *strings.Builder, title string, keys []map[string]any) {
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
renderSection(sb, title, len(keys))
|
||||
for i, key := range keys {
|
||||
apiKey := getString(key, "api-key")
|
||||
prefix := getString(key, "prefix")
|
||||
baseURL := getString(key, "base-url")
|
||||
info := maskKey(apiKey)
|
||||
if prefix != "" {
|
||||
info += " (prefix: " + prefix + ")"
|
||||
}
|
||||
if baseURL != "" {
|
||||
info += " → " + baseURL
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, info))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
func maskKey(key string) string {
|
||||
if len(key) <= 8 {
|
||||
return strings.Repeat("*", len(key))
|
||||
}
|
||||
return key[:4] + strings.Repeat("*", len(key)-8) + key[len(key)-4:]
|
||||
}
|
||||
78
internal/tui/loghook.go
Normal file
78
internal/tui/loghook.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// LogHook is a logrus hook that captures log entries and sends them to a channel.
|
||||
type LogHook struct {
|
||||
ch chan string
|
||||
formatter log.Formatter
|
||||
mu sync.Mutex
|
||||
levels []log.Level
|
||||
}
|
||||
|
||||
// NewLogHook creates a new LogHook with a buffered channel of the given size.
|
||||
func NewLogHook(bufSize int) *LogHook {
|
||||
return &LogHook{
|
||||
ch: make(chan string, bufSize),
|
||||
formatter: &log.TextFormatter{DisableColors: true, FullTimestamp: true},
|
||||
levels: log.AllLevels,
|
||||
}
|
||||
}
|
||||
|
||||
// SetFormatter sets a custom formatter for the hook.
|
||||
func (h *LogHook) SetFormatter(f log.Formatter) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.formatter = f
|
||||
}
|
||||
|
||||
// Levels returns the log levels this hook should fire on.
|
||||
func (h *LogHook) Levels() []log.Level {
|
||||
return h.levels
|
||||
}
|
||||
|
||||
// Fire is called by logrus when a log entry is fired.
|
||||
func (h *LogHook) Fire(entry *log.Entry) error {
|
||||
h.mu.Lock()
|
||||
f := h.formatter
|
||||
h.mu.Unlock()
|
||||
|
||||
var line string
|
||||
if f != nil {
|
||||
b, err := f.Format(entry)
|
||||
if err == nil {
|
||||
line = strings.TrimRight(string(b), "\n\r")
|
||||
} else {
|
||||
line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
|
||||
}
|
||||
} else {
|
||||
line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
|
||||
}
|
||||
|
||||
// Non-blocking send
|
||||
select {
|
||||
case h.ch <- line:
|
||||
default:
|
||||
// Drop oldest if full
|
||||
select {
|
||||
case <-h.ch:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case h.ch <- line:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chan returns the channel to read log lines from.
|
||||
func (h *LogHook) Chan() <-chan string {
|
||||
return h.ch
|
||||
}
|
||||
261
internal/tui/logs_tab.go
Normal file
261
internal/tui/logs_tab.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
// logsTabModel displays real-time log lines from hook/API source.
|
||||
type logsTabModel struct {
|
||||
client *Client
|
||||
hook *LogHook
|
||||
viewport viewport.Model
|
||||
lines []string
|
||||
maxLines int
|
||||
autoScroll bool
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
filter string // "", "debug", "info", "warn", "error"
|
||||
after int64
|
||||
lastErr error
|
||||
}
|
||||
|
||||
type logsPollMsg struct {
|
||||
lines []string
|
||||
latest int64
|
||||
err error
|
||||
}
|
||||
|
||||
type logsTickMsg struct{}
|
||||
type logLineMsg string
|
||||
|
||||
func newLogsTabModel(client *Client, hook *LogHook) logsTabModel {
|
||||
return logsTabModel{
|
||||
client: client,
|
||||
hook: hook,
|
||||
maxLines: 5000,
|
||||
autoScroll: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (m logsTabModel) Init() tea.Cmd {
|
||||
if m.hook != nil {
|
||||
return m.waitForLog
|
||||
}
|
||||
return m.fetchLogs
|
||||
}
|
||||
|
||||
func (m logsTabModel) fetchLogs() tea.Msg {
|
||||
lines, latest, err := m.client.GetLogs(m.after, 200)
|
||||
return logsPollMsg{
|
||||
lines: lines,
|
||||
latest: latest,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (m logsTabModel) waitForNextPoll() tea.Cmd {
|
||||
return tea.Tick(2*time.Second, func(_ time.Time) tea.Msg {
|
||||
return logsTickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func (m logsTabModel) waitForLog() tea.Msg {
|
||||
if m.hook == nil {
|
||||
return nil
|
||||
}
|
||||
line, ok := <-m.hook.Chan()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return logLineMsg(line)
|
||||
}
|
||||
|
||||
func (m logsTabModel) Update(msg tea.Msg) (logsTabModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
return m, nil
|
||||
case logsTickMsg:
|
||||
if m.hook != nil {
|
||||
return m, nil
|
||||
}
|
||||
return m, m.fetchLogs
|
||||
case logsPollMsg:
|
||||
if m.hook != nil {
|
||||
return m, nil
|
||||
}
|
||||
if msg.err != nil {
|
||||
m.lastErr = msg.err
|
||||
} else {
|
||||
m.lastErr = nil
|
||||
m.after = msg.latest
|
||||
if len(msg.lines) > 0 {
|
||||
m.lines = append(m.lines, msg.lines...)
|
||||
if len(m.lines) > m.maxLines {
|
||||
m.lines = m.lines[len(m.lines)-m.maxLines:]
|
||||
}
|
||||
}
|
||||
}
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
if m.autoScroll {
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
return m, m.waitForNextPoll()
|
||||
case logLineMsg:
|
||||
m.lines = append(m.lines, string(msg))
|
||||
if len(m.lines) > m.maxLines {
|
||||
m.lines = m.lines[len(m.lines)-m.maxLines:]
|
||||
}
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
if m.autoScroll {
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
return m, m.waitForLog
|
||||
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "a":
|
||||
m.autoScroll = !m.autoScroll
|
||||
if m.autoScroll {
|
||||
m.viewport.GotoBottom()
|
||||
}
|
||||
return m, nil
|
||||
case "c":
|
||||
m.lines = nil
|
||||
m.lastErr = nil
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
return m, nil
|
||||
case "1":
|
||||
m.filter = ""
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
return m, nil
|
||||
case "2":
|
||||
m.filter = "info"
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
return m, nil
|
||||
case "3":
|
||||
m.filter = "warn"
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
return m, nil
|
||||
case "4":
|
||||
m.filter = "error"
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
return m, nil
|
||||
default:
|
||||
wasAtBottom := m.viewport.AtBottom()
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
// If user scrolls up, disable auto-scroll
|
||||
if !m.viewport.AtBottom() && wasAtBottom {
|
||||
m.autoScroll = false
|
||||
}
|
||||
// If user scrolls to bottom, re-enable auto-scroll
|
||||
if m.viewport.AtBottom() {
|
||||
m.autoScroll = true
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *logsTabModel) SetSize(w, h int) {
|
||||
m.width = w
|
||||
m.height = h
|
||||
if !m.ready {
|
||||
m.viewport = viewport.New(w, h)
|
||||
m.viewport.SetContent(m.renderLogs())
|
||||
m.ready = true
|
||||
} else {
|
||||
m.viewport.Width = w
|
||||
m.viewport.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
func (m logsTabModel) View() string {
|
||||
if !m.ready {
|
||||
return T("loading")
|
||||
}
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m logsTabModel) renderLogs() string {
|
||||
var sb strings.Builder
|
||||
|
||||
scrollStatus := successStyle.Render(T("logs_auto_scroll"))
|
||||
if !m.autoScroll {
|
||||
scrollStatus = warningStyle.Render(T("logs_paused"))
|
||||
}
|
||||
filterLabel := "ALL"
|
||||
if m.filter != "" {
|
||||
filterLabel = strings.ToUpper(m.filter) + "+"
|
||||
}
|
||||
|
||||
header := fmt.Sprintf(" %s %s %s: %s %s: %d",
|
||||
T("logs_title"), scrollStatus, T("logs_filter"), filterLabel, T("logs_lines"), len(m.lines))
|
||||
sb.WriteString(titleStyle.Render(header))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("logs_help")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", m.width))
|
||||
sb.WriteString("\n")
|
||||
|
||||
if m.lastErr != nil {
|
||||
sb.WriteString(errorStyle.Render("⚠ Error: " + m.lastErr.Error()))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
if len(m.lines) == 0 {
|
||||
sb.WriteString(subtitleStyle.Render(T("logs_waiting")))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
for _, line := range m.lines {
|
||||
if m.filter != "" && !m.matchLevel(line) {
|
||||
continue
|
||||
}
|
||||
styled := m.styleLine(line)
|
||||
sb.WriteString(styled)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m logsTabModel) matchLevel(line string) bool {
|
||||
switch m.filter {
|
||||
case "error":
|
||||
return strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]") || strings.Contains(line, "[panic]")
|
||||
case "warn":
|
||||
return strings.Contains(line, "[warn") || strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]")
|
||||
case "info":
|
||||
return !strings.Contains(line, "[debug]")
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (m logsTabModel) styleLine(line string) string {
|
||||
if strings.Contains(line, "[error]") || strings.Contains(line, "[fatal]") {
|
||||
return logErrorStyle.Render(line)
|
||||
}
|
||||
if strings.Contains(line, "[warn") {
|
||||
return logWarnStyle.Render(line)
|
||||
}
|
||||
if strings.Contains(line, "[info") {
|
||||
return logInfoStyle.Render(line)
|
||||
}
|
||||
if strings.Contains(line, "[debug]") {
|
||||
return logDebugStyle.Render(line)
|
||||
}
|
||||
return line
|
||||
}
|
||||
473
internal/tui/oauth_tab.go
Normal file
473
internal/tui/oauth_tab.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// oauthProvider represents an OAuth provider option.
|
||||
type oauthProvider struct {
|
||||
name string
|
||||
apiPath string // management API path
|
||||
emoji string
|
||||
}
|
||||
|
||||
var oauthProviders = []oauthProvider{
|
||||
{"Gemini CLI", "gemini-cli-auth-url", "🟦"},
|
||||
{"Claude (Anthropic)", "anthropic-auth-url", "🟧"},
|
||||
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
|
||||
{"Antigravity", "antigravity-auth-url", "🟪"},
|
||||
{"Qwen", "qwen-auth-url", "🟨"},
|
||||
{"Kimi", "kimi-auth-url", "🟫"},
|
||||
{"IFlow", "iflow-auth-url", "⬜"},
|
||||
}
|
||||
|
||||
// oauthTabModel handles OAuth login flows.
|
||||
type oauthTabModel struct {
|
||||
client *Client
|
||||
viewport viewport.Model
|
||||
cursor int
|
||||
state oauthState
|
||||
message string
|
||||
err error
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
|
||||
// Remote browser mode
|
||||
authURL string // auth URL to display
|
||||
authState string // OAuth state parameter
|
||||
providerName string // current provider name
|
||||
callbackInput textinput.Model
|
||||
inputActive bool // true when user is typing callback URL
|
||||
}
|
||||
|
||||
type oauthState int
|
||||
|
||||
const (
|
||||
oauthIdle oauthState = iota
|
||||
oauthPending
|
||||
oauthRemote // remote browser mode: waiting for manual callback
|
||||
oauthSuccess
|
||||
oauthError
|
||||
)
|
||||
|
||||
// Messages
|
||||
type oauthStartMsg struct {
|
||||
url string
|
||||
state string
|
||||
providerName string
|
||||
err error
|
||||
}
|
||||
|
||||
type oauthPollMsg struct {
|
||||
done bool
|
||||
message string
|
||||
err error
|
||||
}
|
||||
|
||||
type oauthCallbackSubmitMsg struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func newOAuthTabModel(client *Client) oauthTabModel {
|
||||
ti := textinput.New()
|
||||
ti.Placeholder = "http://localhost:.../auth/callback?code=...&state=..."
|
||||
ti.CharLimit = 2048
|
||||
ti.Prompt = " 回调 URL: "
|
||||
return oauthTabModel{
|
||||
client: client,
|
||||
callbackInput: ti,
|
||||
}
|
||||
}
|
||||
|
||||
func (m oauthTabModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m oauthTabModel) Update(msg tea.Msg) (oauthTabModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
case oauthStartMsg:
|
||||
if msg.err != nil {
|
||||
m.state = oauthError
|
||||
m.err = msg.err
|
||||
m.message = errorStyle.Render("✗ " + msg.err.Error())
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
}
|
||||
m.authURL = msg.url
|
||||
m.authState = msg.state
|
||||
m.providerName = msg.providerName
|
||||
m.state = oauthRemote
|
||||
m.callbackInput.SetValue("")
|
||||
m.callbackInput.Focus()
|
||||
m.inputActive = true
|
||||
m.message = ""
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
// Also start polling in the background
|
||||
return m, tea.Batch(textinput.Blink, m.pollOAuthStatus(msg.state))
|
||||
|
||||
case oauthPollMsg:
|
||||
if msg.err != nil {
|
||||
m.state = oauthError
|
||||
m.err = msg.err
|
||||
m.message = errorStyle.Render("✗ " + msg.err.Error())
|
||||
m.inputActive = false
|
||||
m.callbackInput.Blur()
|
||||
} else if msg.done {
|
||||
m.state = oauthSuccess
|
||||
m.message = successStyle.Render("✓ " + msg.message)
|
||||
m.inputActive = false
|
||||
m.callbackInput.Blur()
|
||||
} else {
|
||||
m.message = warningStyle.Render("⏳ " + msg.message)
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
|
||||
case oauthCallbackSubmitMsg:
|
||||
if msg.err != nil {
|
||||
m.message = errorStyle.Render(T("oauth_submit_fail") + ": " + msg.err.Error())
|
||||
} else {
|
||||
m.message = successStyle.Render(T("oauth_submit_ok"))
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
// ---- Input active: typing callback URL ----
|
||||
if m.inputActive {
|
||||
switch msg.String() {
|
||||
case "enter":
|
||||
callbackURL := m.callbackInput.Value()
|
||||
if callbackURL == "" {
|
||||
return m, nil
|
||||
}
|
||||
m.inputActive = false
|
||||
m.callbackInput.Blur()
|
||||
m.message = warningStyle.Render(T("oauth_submitting"))
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, m.submitCallback(callbackURL)
|
||||
case "esc":
|
||||
m.inputActive = false
|
||||
m.callbackInput.Blur()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
m.callbackInput, cmd = m.callbackInput.Update(msg)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Remote mode but not typing ----
|
||||
if m.state == oauthRemote {
|
||||
switch msg.String() {
|
||||
case "c", "C":
|
||||
// Re-activate input
|
||||
m.inputActive = true
|
||||
m.callbackInput.Focus()
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, textinput.Blink
|
||||
case "esc":
|
||||
m.state = oauthIdle
|
||||
m.message = ""
|
||||
m.authURL = ""
|
||||
m.authState = ""
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
// ---- Pending (auto polling) ----
|
||||
if m.state == oauthPending {
|
||||
if msg.String() == "esc" {
|
||||
m.state = oauthIdle
|
||||
m.message = ""
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ---- Idle ----
|
||||
switch msg.String() {
|
||||
case "up", "k":
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "down", "j":
|
||||
if m.cursor < len(oauthProviders)-1 {
|
||||
m.cursor++
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
}
|
||||
return m, nil
|
||||
case "enter":
|
||||
if m.cursor >= 0 && m.cursor < len(oauthProviders) {
|
||||
provider := oauthProviders[m.cursor]
|
||||
m.state = oauthPending
|
||||
m.message = warningStyle.Render(fmt.Sprintf(T("oauth_initiating"), provider.name))
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, m.startOAuth(provider)
|
||||
}
|
||||
return m, nil
|
||||
case "esc":
|
||||
m.state = oauthIdle
|
||||
m.message = ""
|
||||
m.err = nil
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m oauthTabModel) startOAuth(provider oauthProvider) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Call the auth URL endpoint with is_webui=true
|
||||
data, err := m.client.getJSON("/v0/management/" + provider.apiPath + "?is_webui=true")
|
||||
if err != nil {
|
||||
return oauthStartMsg{err: fmt.Errorf("failed to start %s login: %w", provider.name, err)}
|
||||
}
|
||||
|
||||
authURL := getString(data, "url")
|
||||
state := getString(data, "state")
|
||||
if authURL == "" {
|
||||
return oauthStartMsg{err: fmt.Errorf("no auth URL returned for %s", provider.name)}
|
||||
}
|
||||
|
||||
// Try to open browser (best effort)
|
||||
_ = openBrowser(authURL)
|
||||
|
||||
return oauthStartMsg{url: authURL, state: state, providerName: provider.name}
|
||||
}
|
||||
}
|
||||
|
||||
func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Determine provider from current context
|
||||
providerKey := ""
|
||||
for _, p := range oauthProviders {
|
||||
if p.name == m.providerName {
|
||||
// Map provider name to the canonical key the API expects
|
||||
switch p.apiPath {
|
||||
case "gemini-cli-auth-url":
|
||||
providerKey = "gemini"
|
||||
case "anthropic-auth-url":
|
||||
providerKey = "anthropic"
|
||||
case "codex-auth-url":
|
||||
providerKey = "codex"
|
||||
case "antigravity-auth-url":
|
||||
providerKey = "antigravity"
|
||||
case "qwen-auth-url":
|
||||
providerKey = "qwen"
|
||||
case "kimi-auth-url":
|
||||
providerKey = "kimi"
|
||||
case "iflow-auth-url":
|
||||
providerKey = "iflow"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
body := map[string]string{
|
||||
"provider": providerKey,
|
||||
"redirect_url": callbackURL,
|
||||
"state": m.authState,
|
||||
}
|
||||
err := m.client.postJSON("/v0/management/oauth-callback", body)
|
||||
if err != nil {
|
||||
return oauthCallbackSubmitMsg{err: err}
|
||||
}
|
||||
return oauthCallbackSubmitMsg{}
|
||||
}
|
||||
}
|
||||
|
||||
func (m oauthTabModel) pollOAuthStatus(state string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Poll session status for up to 5 minutes
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
return oauthPollMsg{done: false, err: fmt.Errorf("%s", T("oauth_timeout"))}
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
status, errMsg, err := m.client.GetAuthStatus(state)
|
||||
if err != nil {
|
||||
continue // Ignore transient errors
|
||||
}
|
||||
|
||||
switch status {
|
||||
case "ok":
|
||||
return oauthPollMsg{
|
||||
done: true,
|
||||
message: T("oauth_success"),
|
||||
}
|
||||
case "error":
|
||||
return oauthPollMsg{
|
||||
done: false,
|
||||
err: fmt.Errorf("%s: %s", T("oauth_failed"), errMsg),
|
||||
}
|
||||
case "wait":
|
||||
continue
|
||||
default:
|
||||
return oauthPollMsg{
|
||||
done: true,
|
||||
message: T("oauth_completed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *oauthTabModel) SetSize(w, h int) {
|
||||
m.width = w
|
||||
m.height = h
|
||||
m.callbackInput.Width = w - 16
|
||||
if !m.ready {
|
||||
m.viewport = viewport.New(w, h)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
m.ready = true
|
||||
} else {
|
||||
m.viewport.Width = w
|
||||
m.viewport.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
func (m oauthTabModel) View() string {
|
||||
if !m.ready {
|
||||
return T("loading")
|
||||
}
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m oauthTabModel) renderContent() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("oauth_title")))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if m.message != "" {
|
||||
sb.WriteString(" " + m.message)
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
// ---- Remote browser mode ----
|
||||
if m.state == oauthRemote {
|
||||
sb.WriteString(m.renderRemoteMode())
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
if m.state == oauthPending {
|
||||
sb.WriteString(helpStyle.Render(T("oauth_press_esc")))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
sb.WriteString(helpStyle.Render(T("oauth_select")))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
for i, p := range oauthProviders {
|
||||
isSelected := i == m.cursor
|
||||
prefix := " "
|
||||
if isSelected {
|
||||
prefix = "▸ "
|
||||
}
|
||||
|
||||
label := fmt.Sprintf("%s %s", p.emoji, p.name)
|
||||
if isSelected {
|
||||
label = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#FFFFFF")).Background(colorPrimary).Padding(0, 1).Render(label)
|
||||
} else {
|
||||
label = lipgloss.NewStyle().Foreground(colorText).Padding(0, 1).Render(label)
|
||||
}
|
||||
|
||||
sb.WriteString(prefix + label + "\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("oauth_help")))
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m oauthTabModel) renderRemoteMode() string {
|
||||
var sb strings.Builder
|
||||
|
||||
providerStyle := lipgloss.NewStyle().Bold(true).Foreground(colorHighlight)
|
||||
sb.WriteString(providerStyle.Render(fmt.Sprintf(" ✦ %s OAuth", m.providerName)))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// Auth URL section
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorInfo).Render(T("oauth_auth_url")))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Wrap URL to fit terminal width
|
||||
urlStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
|
||||
maxURLWidth := m.width - 6
|
||||
if maxURLWidth < 40 {
|
||||
maxURLWidth = 40
|
||||
}
|
||||
wrappedURL := wrapText(m.authURL, maxURLWidth)
|
||||
for _, line := range wrappedURL {
|
||||
sb.WriteString(" " + urlStyle.Render(line) + "\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
sb.WriteString(helpStyle.Render(T("oauth_remote_hint")))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// Callback URL input
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorInfo).Render(T("oauth_callback_url")))
|
||||
sb.WriteString("\n")
|
||||
|
||||
if m.inputActive {
|
||||
sb.WriteString(m.callbackInput.View())
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(" " + T("enter_submit") + " • " + T("esc_cancel")))
|
||||
} else {
|
||||
sb.WriteString(helpStyle.Render(T("oauth_press_c")))
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(warningStyle.Render(T("oauth_waiting")))
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// wrapText splits a long string into lines of at most maxWidth characters.
|
||||
func wrapText(s string, maxWidth int) []string {
|
||||
if maxWidth <= 0 {
|
||||
return []string{s}
|
||||
}
|
||||
var lines []string
|
||||
for len(s) > maxWidth {
|
||||
lines = append(lines, s[:maxWidth])
|
||||
s = s[maxWidth:]
|
||||
}
|
||||
if len(s) > 0 {
|
||||
lines = append(lines, s)
|
||||
}
|
||||
return lines
|
||||
}
|
||||
126
internal/tui/styles.go
Normal file
126
internal/tui/styles.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Package tui provides a terminal-based management interface for CLIProxyAPI.
|
||||
package tui
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
// Color palette
|
||||
var (
|
||||
colorPrimary = lipgloss.Color("#7C3AED") // violet
|
||||
colorSecondary = lipgloss.Color("#6366F1") // indigo
|
||||
colorSuccess = lipgloss.Color("#22C55E") // green
|
||||
colorWarning = lipgloss.Color("#EAB308") // yellow
|
||||
colorError = lipgloss.Color("#EF4444") // red
|
||||
colorInfo = lipgloss.Color("#3B82F6") // blue
|
||||
colorMuted = lipgloss.Color("#6B7280") // gray
|
||||
colorBg = lipgloss.Color("#1E1E2E") // dark bg
|
||||
colorSurface = lipgloss.Color("#313244") // slightly lighter
|
||||
colorText = lipgloss.Color("#CDD6F4") // light text
|
||||
colorSubtext = lipgloss.Color("#A6ADC8") // dimmer text
|
||||
colorBorder = lipgloss.Color("#45475A") // border
|
||||
colorHighlight = lipgloss.Color("#F5C2E7") // pink highlight
|
||||
)
|
||||
|
||||
// Tab bar styles
|
||||
var (
|
||||
tabActiveStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#FFFFFF")).
|
||||
Background(colorPrimary).
|
||||
Padding(0, 2)
|
||||
|
||||
tabInactiveStyle = lipgloss.NewStyle().
|
||||
Foreground(colorSubtext).
|
||||
Background(colorSurface).
|
||||
Padding(0, 2)
|
||||
|
||||
tabBarStyle = lipgloss.NewStyle().
|
||||
Background(colorSurface).
|
||||
PaddingLeft(1).
|
||||
PaddingBottom(0)
|
||||
)
|
||||
|
||||
// Content styles
|
||||
var (
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(colorHighlight).
|
||||
MarginBottom(1)
|
||||
|
||||
subtitleStyle = lipgloss.NewStyle().
|
||||
Foreground(colorSubtext).
|
||||
Italic(true)
|
||||
|
||||
labelStyle = lipgloss.NewStyle().
|
||||
Foreground(colorInfo).
|
||||
Bold(true).
|
||||
Width(24)
|
||||
|
||||
valueStyle = lipgloss.NewStyle().
|
||||
Foreground(colorText)
|
||||
|
||||
sectionStyle = lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(colorBorder).
|
||||
Padding(1, 2)
|
||||
|
||||
errorStyle = lipgloss.NewStyle().
|
||||
Foreground(colorError).
|
||||
Bold(true)
|
||||
|
||||
successStyle = lipgloss.NewStyle().
|
||||
Foreground(colorSuccess)
|
||||
|
||||
warningStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWarning)
|
||||
|
||||
statusBarStyle = lipgloss.NewStyle().
|
||||
Foreground(colorSubtext).
|
||||
Background(colorSurface).
|
||||
PaddingLeft(1).
|
||||
PaddingRight(1)
|
||||
|
||||
helpStyle = lipgloss.NewStyle().
|
||||
Foreground(colorMuted)
|
||||
)
|
||||
|
||||
// Log level styles
|
||||
var (
|
||||
logDebugStyle = lipgloss.NewStyle().Foreground(colorMuted)
|
||||
logInfoStyle = lipgloss.NewStyle().Foreground(colorInfo)
|
||||
logWarnStyle = lipgloss.NewStyle().Foreground(colorWarning)
|
||||
logErrorStyle = lipgloss.NewStyle().Foreground(colorError)
|
||||
)
|
||||
|
||||
// Table styles
|
||||
var (
|
||||
tableHeaderStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(colorHighlight).
|
||||
BorderBottom(true).
|
||||
BorderStyle(lipgloss.NormalBorder()).
|
||||
BorderForeground(colorBorder)
|
||||
|
||||
tableCellStyle = lipgloss.NewStyle().
|
||||
Foreground(colorText).
|
||||
PaddingRight(2)
|
||||
|
||||
tableSelectedStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FFFFFF")).
|
||||
Background(colorPrimary).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
func logLevelStyle(level string) lipgloss.Style {
|
||||
switch level {
|
||||
case "debug":
|
||||
return logDebugStyle
|
||||
case "info":
|
||||
return logInfoStyle
|
||||
case "warn", "warning":
|
||||
return logWarnStyle
|
||||
case "error", "fatal", "panic":
|
||||
return logErrorStyle
|
||||
default:
|
||||
return logInfoStyle
|
||||
}
|
||||
}
|
||||
364
internal/tui/usage_tab.go
Normal file
364
internal/tui/usage_tab.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// usageTabModel displays usage statistics with charts and breakdowns.
|
||||
type usageTabModel struct {
|
||||
client *Client
|
||||
viewport viewport.Model
|
||||
usage map[string]any
|
||||
err error
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
}
|
||||
|
||||
type usageDataMsg struct {
|
||||
usage map[string]any
|
||||
err error
|
||||
}
|
||||
|
||||
func newUsageTabModel(client *Client) usageTabModel {
|
||||
return usageTabModel{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (m usageTabModel) Init() tea.Cmd {
|
||||
return m.fetchData
|
||||
}
|
||||
|
||||
func (m usageTabModel) fetchData() tea.Msg {
|
||||
usage, err := m.client.GetUsage()
|
||||
return usageDataMsg{usage: usage, err: err}
|
||||
}
|
||||
|
||||
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
case usageDataMsg:
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
} else {
|
||||
m.err = nil
|
||||
m.usage = msg.usage
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
if msg.String() == "r" {
|
||||
return m, m.fetchData
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *usageTabModel) SetSize(w, h int) {
|
||||
m.width = w
|
||||
m.height = h
|
||||
if !m.ready {
|
||||
m.viewport = viewport.New(w, h)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
m.ready = true
|
||||
} else {
|
||||
m.viewport.Width = w
|
||||
m.viewport.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
func (m usageTabModel) View() string {
|
||||
if !m.ready {
|
||||
return T("loading")
|
||||
}
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m usageTabModel) renderContent() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("usage_title")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("usage_help")))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if m.err != nil {
|
||||
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
if m.usage == nil {
|
||||
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
usageMap, _ := m.usage["usage"].(map[string]any)
|
||||
if usageMap == nil {
|
||||
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
totalReqs := int64(getFloat(usageMap, "total_requests"))
|
||||
successCnt := int64(getFloat(usageMap, "success_count"))
|
||||
failureCnt := int64(getFloat(usageMap, "failure_count"))
|
||||
totalTokens := int64(getFloat(usageMap, "total_tokens"))
|
||||
|
||||
// ━━━ Overview Cards ━━━
|
||||
cardWidth := 20
|
||||
if m.width > 0 {
|
||||
cardWidth = (m.width - 6) / 4
|
||||
if cardWidth < 16 {
|
||||
cardWidth = 16
|
||||
}
|
||||
}
|
||||
cardStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("240")).
|
||||
Padding(0, 1).
|
||||
Width(cardWidth).
|
||||
Height(3)
|
||||
|
||||
// Total Requests
|
||||
card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
|
||||
"%s\n%s\n%s",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
|
||||
))
|
||||
|
||||
// Total Tokens
|
||||
card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
|
||||
"%s\n%s\n%s",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
|
||||
))
|
||||
|
||||
// RPM
|
||||
rpm := float64(0)
|
||||
if totalReqs > 0 {
|
||||
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
|
||||
rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
|
||||
}
|
||||
}
|
||||
card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
|
||||
"%s\n%s\n%s",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
|
||||
))
|
||||
|
||||
// TPM
|
||||
tpm := float64(0)
|
||||
if totalTokens > 0 {
|
||||
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
|
||||
tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
|
||||
}
|
||||
}
|
||||
card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
|
||||
"%s\n%s\n%s",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
|
||||
))
|
||||
|
||||
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// ━━━ Requests by Hour (ASCII bar chart) ━━━
|
||||
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// ━━━ Tokens by Hour ━━━
|
||||
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// ━━━ Requests by Day ━━━
|
||||
if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// ━━━ API Detail Stats ━━━
|
||||
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
|
||||
sb.WriteString("\n")
|
||||
|
||||
header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
|
||||
sb.WriteString(tableHeaderStyle.Render(header))
|
||||
sb.WriteString("\n")
|
||||
|
||||
for apiName, apiSnap := range apis {
|
||||
if apiMap, ok := apiSnap.(map[string]any); ok {
|
||||
apiReqs := int64(getFloat(apiMap, "total_requests"))
|
||||
apiToks := int64(getFloat(apiMap, "total_tokens"))
|
||||
|
||||
row := fmt.Sprintf(" %-30s %10d %12s",
|
||||
truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Per-model breakdown
|
||||
if models, ok := apiMap["models"].(map[string]any); ok {
|
||||
for model, v := range models {
|
||||
if stats, ok := v.(map[string]any); ok {
|
||||
mReqs := int64(getFloat(stats, "total_requests"))
|
||||
mToks := int64(getFloat(stats, "total_tokens"))
|
||||
mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
|
||||
truncate(model, 28), mReqs, formatLargeNumber(mToks))
|
||||
sb.WriteString(tableCellStyle.Render(mRow))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Token type breakdown from details
|
||||
sb.WriteString(m.renderTokenBreakdown(stats))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// renderTokenBreakdown aggregates input/output/cached/reasoning tokens from model details.
|
||||
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
|
||||
details, ok := modelStats["details"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
detailList, ok := details.([]any)
|
||||
if !ok || len(detailList) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var inputTotal, outputTotal, cachedTotal, reasoningTotal int64
|
||||
for _, d := range detailList {
|
||||
dm, ok := d.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
tokens, ok := dm["tokens"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inputTotal += int64(getFloat(tokens, "input_tokens"))
|
||||
outputTotal += int64(getFloat(tokens, "output_tokens"))
|
||||
cachedTotal += int64(getFloat(tokens, "cached_tokens"))
|
||||
reasoningTotal += int64(getFloat(tokens, "reasoning_tokens"))
|
||||
}
|
||||
|
||||
if inputTotal == 0 && outputTotal == 0 && cachedTotal == 0 && reasoningTotal == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := []string{}
|
||||
if inputTotal > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_input"), formatLargeNumber(inputTotal)))
|
||||
}
|
||||
if outputTotal > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_output"), formatLargeNumber(outputTotal)))
|
||||
}
|
||||
if cachedTotal > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_cached"), formatLargeNumber(cachedTotal)))
|
||||
}
|
||||
if reasoningTotal > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_reasoning"), formatLargeNumber(reasoningTotal)))
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" │ %s\n",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
|
||||
}
|
||||
|
||||
// renderBarChart renders a simple ASCII horizontal bar chart.
|
||||
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
|
||||
if maxBarWidth < 10 {
|
||||
maxBarWidth = 10
|
||||
}
|
||||
|
||||
// Sort keys
|
||||
keys := make([]string, 0, len(data))
|
||||
for k := range data {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
// Find max value
|
||||
maxVal := float64(0)
|
||||
for _, k := range keys {
|
||||
v := getFloat(data, k)
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
}
|
||||
if maxVal == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
barStyle := lipgloss.NewStyle().Foreground(barColor)
|
||||
var sb strings.Builder
|
||||
|
||||
labelWidth := 12
|
||||
barAvail := maxBarWidth - labelWidth - 12
|
||||
if barAvail < 5 {
|
||||
barAvail = 5
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
v := getFloat(data, k)
|
||||
barLen := int(v / maxVal * float64(barAvail))
|
||||
if barLen < 1 && v > 0 {
|
||||
barLen = 1
|
||||
}
|
||||
bar := strings.Repeat("█", barLen)
|
||||
label := k
|
||||
if len(label) > labelWidth {
|
||||
label = label[:labelWidth]
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
|
||||
labelWidth, label,
|
||||
barStyle.Render(bar),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%.0f", v)),
|
||||
))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -428,8 +428,9 @@ func flattenTypeArrays(jsonStr string) string {
|
||||
|
||||
func removeUnsupportedKeywords(jsonStr string) string {
|
||||
keywords := append(unsupportedConstraints,
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||
"propertyNames", // Gemini doesn't support property name validation
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "$id", "additionalProperties",
|
||||
"propertyNames", "patternProperties", // Gemini doesn't support these schema keywords
|
||||
"enumTitles", "prefill", // Claude/OpenCode schema metadata fields unsupported by Gemini
|
||||
)
|
||||
|
||||
deletePaths := make([]string, 0)
|
||||
|
||||
@@ -870,6 +870,57 @@ func TestCleanJSONSchemaForAntigravity_BooleanEnumToString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RemovesGeminiUnsupportedMetadataFields(t *testing.T) {
|
||||
input := `{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"$id": "root-schema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payload": {
|
||||
"type": "object",
|
||||
"prefill": "hello",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"],
|
||||
"enumTitles": ["A", "B"]
|
||||
}
|
||||
},
|
||||
"patternProperties": {
|
||||
"^x-": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"$id": {
|
||||
"type": "string",
|
||||
"description": "property name should not be removed"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"],
|
||||
"description": "Allowed: a, b"
|
||||
}
|
||||
}
|
||||
},
|
||||
"$id": {
|
||||
"type": "string",
|
||||
"description": "property name should not be removed"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestRemoveExtensionFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -184,6 +184,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if o.Websockets != n.Websockets {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||
}
|
||||
|
||||
@@ -160,6 +160,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
||||
if ck.BaseURL != "" {
|
||||
attrs["base_url"] = ck.BaseURL
|
||||
}
|
||||
if ck.Websockets {
|
||||
attrs["websockets"] = "true"
|
||||
}
|
||||
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
|
||||
@@ -231,10 +231,11 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
||||
Config: &config.Config{
|
||||
CodexKey: []config.CodexKey{
|
||||
{
|
||||
APIKey: "codex-key-123",
|
||||
Prefix: "dev",
|
||||
BaseURL: "https://api.openai.com",
|
||||
ProxyURL: "http://proxy.local",
|
||||
APIKey: "codex-key-123",
|
||||
Prefix: "dev",
|
||||
BaseURL: "https://api.openai.com",
|
||||
ProxyURL: "http://proxy.local",
|
||||
Websockets: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -259,6 +260,9 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
if auths[0].Attributes["websockets"] != "true" {
|
||||
t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -92,6 +93,9 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
||||
status = coreauth.StatusDisabled
|
||||
}
|
||||
|
||||
// Read per-account excluded models from the OAuth JSON file
|
||||
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
@@ -108,11 +112,23 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||
// Read priority from auth file
|
||||
if rawPriority, ok := metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
a.Attributes["priority"] = strconv.Itoa(int(v))
|
||||
case string:
|
||||
priority := strings.TrimSpace(v)
|
||||
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
|
||||
a.Attributes["priority"] = priority
|
||||
}
|
||||
}
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, nil, "oauth")
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
@@ -167,6 +183,10 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
||||
if authPath != "" {
|
||||
attrs["path"] = authPath
|
||||
}
|
||||
// Propagate priority from primary auth to virtual auths
|
||||
if priorityVal, hasPriority := primary.Attributes["priority"]; hasPriority && priorityVal != "" {
|
||||
attrs["priority"] = priorityVal
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
@@ -239,3 +259,40 @@ func buildGeminiVirtualID(baseID, projectID string) string {
|
||||
replacer := strings.NewReplacer("/", "_", "\\", "_", " ", "_")
|
||||
return fmt.Sprintf("%s::%s", baseID, replacer.Replace(project))
|
||||
}
|
||||
|
||||
// extractExcludedModelsFromMetadata reads per-account excluded models from the OAuth JSON metadata.
|
||||
// Supports both "excluded_models" and "excluded-models" keys, and accepts both []string and []interface{}.
|
||||
func extractExcludedModelsFromMetadata(metadata map[string]any) []string {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
// Try both key formats
|
||||
raw, ok := metadata["excluded_models"]
|
||||
if !ok {
|
||||
raw, ok = metadata["excluded-models"]
|
||||
}
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
var stringSlice []string
|
||||
switch v := raw.(type) {
|
||||
case []string:
|
||||
stringSlice = v
|
||||
case []interface{}:
|
||||
stringSlice = make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
stringSlice = append(stringSlice, s)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(stringSlice))
|
||||
for _, s := range stringSlice {
|
||||
if trimmed := strings.TrimSpace(s); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -297,6 +297,117 @@ func TestFileSynthesizer_Synthesize_PrefixValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_PriorityParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
priority any
|
||||
want string
|
||||
hasValue bool
|
||||
}{
|
||||
{
|
||||
name: "string with spaces",
|
||||
priority: " 10 ",
|
||||
want: "10",
|
||||
hasValue: true,
|
||||
},
|
||||
{
|
||||
name: "number",
|
||||
priority: 8,
|
||||
want: "8",
|
||||
hasValue: true,
|
||||
},
|
||||
{
|
||||
name: "invalid string",
|
||||
priority: "1x",
|
||||
hasValue: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"priority": tt.priority,
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
if errWriteFile != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, errSynthesize := synth.Synthesize(ctx)
|
||||
if errSynthesize != nil {
|
||||
t.Fatalf("unexpected error: %v", errSynthesize)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
value, ok := auths[0].Attributes["priority"]
|
||||
if tt.hasValue {
|
||||
if !ok {
|
||||
t.Fatal("expected priority attribute to be set")
|
||||
}
|
||||
if value != tt.want {
|
||||
t.Fatalf("expected priority %q, got %q", tt.want, value)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("expected priority attribute to be absent, got %q", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_OAuthExcludedModelsMerged(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"excluded_models": []string{"custom-model", "MODEL-B"},
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
if errWriteFile != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"claude": {"shared", "model-b"},
|
||||
},
|
||||
},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, errSynthesize := synth.Synthesize(ctx)
|
||||
if errSynthesize != nil {
|
||||
t.Fatalf("unexpected error: %v", errSynthesize)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
got := auths[0].Attributes["excluded_models"]
|
||||
want := "custom-model,model-b,shared"
|
||||
if got != want {
|
||||
t.Fatalf("expected excluded_models %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NilInputs(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
@@ -533,6 +644,7 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
"type": "gemini",
|
||||
"email": "multi@example.com",
|
||||
"project_id": "project-a, project-b, project-c",
|
||||
"priority": " 10 ",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644)
|
||||
@@ -565,6 +677,9 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
if primary.Status != coreauth.StatusDisabled {
|
||||
t.Errorf("expected primary status disabled, got %s", primary.Status)
|
||||
}
|
||||
if gotPriority := primary.Attributes["priority"]; gotPriority != "10" {
|
||||
t.Errorf("expected primary priority 10, got %q", gotPriority)
|
||||
}
|
||||
|
||||
// Remaining auths should be virtuals
|
||||
for i := 1; i < 4; i++ {
|
||||
@@ -575,6 +690,9 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
if v.Attributes["gemini_virtual_parent"] != primary.ID {
|
||||
t.Errorf("expected virtual %d parent to be %s, got %s", i, primary.ID, v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
if gotPriority := v.Attributes["priority"]; gotPriority != "10" {
|
||||
t.Errorf("expected virtual %d priority 10, got %q", i, gotPriority)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,6 +53,8 @@ func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string)
|
||||
|
||||
// ApplyAuthExcludedModelsMeta applies excluded models metadata to an auth entry.
|
||||
// It computes a hash of excluded models and sets the auth_kind attribute.
|
||||
// For OAuth entries, perKey (from the JSON file's excluded-models field) is merged
|
||||
// with the global oauth-excluded-models config for the provider.
|
||||
func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||
if auth == nil || cfg == nil {
|
||||
return
|
||||
@@ -72,9 +74,13 @@ func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey
|
||||
}
|
||||
if authKindKey == "apikey" {
|
||||
add(perKey)
|
||||
} else if cfg.OAuthExcludedModels != nil {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
add(cfg.OAuthExcludedModels[providerKey])
|
||||
} else {
|
||||
// For OAuth: merge per-account excluded models with global provider-level exclusions
|
||||
add(perKey)
|
||||
if cfg.OAuthExcludedModels != nil {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
add(cfg.OAuthExcludedModels[providerKey])
|
||||
}
|
||||
}
|
||||
combined := make([]string, 0, len(seen))
|
||||
for k := range seen {
|
||||
@@ -88,6 +94,10 @@ func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey
|
||||
if hash != "" {
|
||||
auth.Attributes["excluded_models_hash"] = hash
|
||||
}
|
||||
// Store the combined excluded models list so that routing can read it at runtime
|
||||
if len(combined) > 0 {
|
||||
auth.Attributes["excluded_models"] = strings.Join(combined, ",")
|
||||
}
|
||||
if authKind != "" {
|
||||
auth.Attributes["auth_kind"] = authKind
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
@@ -200,6 +201,30 @@ func TestApplyAuthExcludedModelsMeta(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta_OAuthMergeWritesCombinedModels(t *testing.T) {
|
||||
auth := &coreauth.Auth{
|
||||
Provider: "claude",
|
||||
Attributes: make(map[string]string),
|
||||
}
|
||||
cfg := &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"claude": {"global-a", "shared"},
|
||||
},
|
||||
}
|
||||
|
||||
ApplyAuthExcludedModelsMeta(auth, cfg, []string{"per", "SHARED"}, "oauth")
|
||||
|
||||
const wantCombined = "global-a,per,shared"
|
||||
if gotCombined := auth.Attributes["excluded_models"]; gotCombined != wantCombined {
|
||||
t.Fatalf("expected excluded_models=%q, got %q", wantCombined, gotCombined)
|
||||
}
|
||||
|
||||
expectedHash := diff.ComputeExcludedModelsHash([]string{"global-a", "per", "shared"})
|
||||
if gotHash := auth.Attributes["excluded_models_hash"]; gotHash != expectedHash {
|
||||
t.Fatalf("expected excluded_models_hash=%q, got %q", expectedHash, gotHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConfigHeadersToAttrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -185,8 +185,7 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
||||
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
var keepAliveInterval *time.Duration
|
||||
if alt != "" {
|
||||
disabled := time.Duration(0)
|
||||
keepAliveInterval = &disabled
|
||||
keepAliveInterval = new(time.Duration(0))
|
||||
}
|
||||
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
|
||||
@@ -300,8 +300,7 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
||||
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
var keepAliveInterval *time.Duration
|
||||
if alt != "" {
|
||||
disabled := time.Duration(0)
|
||||
keepAliveInterval = &disabled
|
||||
keepAliveInterval = new(time.Duration(0))
|
||||
}
|
||||
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
|
||||
@@ -52,6 +52,45 @@ const (
|
||||
defaultStreamingBootstrapRetries = 0
|
||||
)
|
||||
|
||||
type pinnedAuthContextKey struct{}
|
||||
type selectedAuthCallbackContextKey struct{}
|
||||
type executionSessionContextKey struct{}
|
||||
|
||||
// WithPinnedAuthID returns a child context that requests execution on a specific auth ID.
|
||||
func WithPinnedAuthID(ctx context.Context, authID string) context.Context {
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return ctx
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, pinnedAuthContextKey{}, authID)
|
||||
}
|
||||
|
||||
// WithSelectedAuthIDCallback returns a child context that receives the selected auth ID.
|
||||
func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context {
|
||||
if callback == nil {
|
||||
return ctx
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback)
|
||||
}
|
||||
|
||||
// WithExecutionSessionID returns a child context tagged with a long-lived execution session ID.
|
||||
func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context {
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if sessionID == "" {
|
||||
return ctx
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, executionSessionContextKey{}, sessionID)
|
||||
}
|
||||
|
||||
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
||||
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
||||
func BuildErrorResponseBody(status int, errText string) []byte {
|
||||
@@ -152,7 +191,59 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
if key == "" {
|
||||
key = uuid.NewString()
|
||||
}
|
||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||
|
||||
meta := map[string]any{idempotencyKeyMetadataKey: key}
|
||||
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
|
||||
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
|
||||
}
|
||||
if selectedCallback := selectedAuthIDCallbackFromContext(ctx); selectedCallback != nil {
|
||||
meta[coreexecutor.SelectedAuthCallbackMetadataKey] = selectedCallback
|
||||
}
|
||||
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
|
||||
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func pinnedAuthIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
raw := ctx.Value(pinnedAuthContextKey{})
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
raw := ctx.Value(selectedAuthCallbackContextKey{})
|
||||
if callback, ok := raw.(func(string)); ok && callback != nil {
|
||||
return callback
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func executionSessionIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
raw := ctx.Value(executionSessionContextKey{})
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// BaseAPIHandler contains the handlers for API endpoints.
|
||||
|
||||
@@ -122,6 +122,82 @@ func (e *payloadThenErrorStreamExecutor) Calls() int {
|
||||
return e.calls
|
||||
}
|
||||
|
||||
type authAwareStreamExecutor struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
_ = opts
|
||||
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||
|
||||
authID := ""
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
e.authIDs = append(e.authIDs, authID)
|
||||
e.mu.Unlock()
|
||||
|
||||
if authID == "auth1" {
|
||||
ch <- coreexecutor.StreamChunk{
|
||||
Err: &coreauth.Error{
|
||||
Code: "unauthorized",
|
||||
Message: "unauthorized",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Calls() int {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.calls
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) AuthIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.authIDs))
|
||||
copy(out, e.authIDs)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
executor := &failOnceStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
@@ -252,3 +328,128 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
|
||||
executor := &authAwareStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
ctx := WithPinnedAuthID(context.Background(), "auth1")
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
|
||||
var gotErr error
|
||||
for msg := range errChan {
|
||||
if msg != nil && msg.Error != nil {
|
||||
gotErr = msg.Error
|
||||
}
|
||||
}
|
||||
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty payload, got %q", string(got))
|
||||
}
|
||||
if gotErr == nil {
|
||||
t.Fatalf("expected terminal error, got nil")
|
||||
}
|
||||
authIDs := executor.AuthIDs()
|
||||
if len(authIDs) == 0 {
|
||||
t.Fatalf("expected at least one upstream attempt")
|
||||
}
|
||||
for _, authID := range authIDs {
|
||||
if authID != "auth1" {
|
||||
t.Fatalf("expected all attempts on auth1, got sequence %v", authIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) {
|
||||
executor := &authAwareStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 0,
|
||||
},
|
||||
}, manager)
|
||||
|
||||
selectedAuthID := ""
|
||||
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
|
||||
selectedAuthID = authID
|
||||
})
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
t.Fatalf("unexpected error: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got) != "ok" {
|
||||
t.Fatalf("expected payload ok, got %q", string(got))
|
||||
}
|
||||
if selectedAuthID != "auth2" {
|
||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -332,6 +332,7 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
|
||||
// Check if this chunk has any meaningful content
|
||||
hasContent := false
|
||||
hasUsage := root.Get("usage").Exists()
|
||||
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
|
||||
chatChoices.ForEach(func(_, choice gjson.Result) bool {
|
||||
// Check if delta has content or finish_reason
|
||||
@@ -350,8 +351,8 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
})
|
||||
}
|
||||
|
||||
// If no meaningful content, return nil to indicate this chunk should be skipped
|
||||
if !hasContent {
|
||||
// If no meaningful content and no usage, return nil to indicate this chunk should be skipped
|
||||
if !hasContent && !hasUsage {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -410,6 +411,11 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
|
||||
}
|
||||
|
||||
// Copy usage if present
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
|
||||
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
@@ -0,0 +1,662 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
wsRequestTypeCreate = "response.create"
|
||||
wsRequestTypeAppend = "response.append"
|
||||
wsEventTypeError = "error"
|
||||
wsEventTypeCompleted = "response.completed"
|
||||
wsEventTypeDone = "response.done"
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
wsPayloadLogMaxSize = 2048
|
||||
)
|
||||
|
||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
// ResponsesWebsocket handles websocket requests for /v1/responses.
|
||||
// It accepts `response.create` and `response.append` requests and streams
|
||||
// response events back as JSON websocket text messages.
|
||||
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
passthroughSessionID := uuid.NewString()
|
||||
clientRemoteAddr := ""
|
||||
if c != nil && c.Request != nil {
|
||||
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||
}
|
||||
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
|
||||
var wsTerminateErr error
|
||||
var wsBodyLog strings.Builder
|
||||
defer func() {
|
||||
if wsTerminateErr != nil {
|
||||
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
||||
} else {
|
||||
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
|
||||
}
|
||||
if h != nil && h.AuthManager != nil {
|
||||
h.AuthManager.CloseExecutionSession(passthroughSessionID)
|
||||
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
|
||||
}
|
||||
setWebsocketRequestBody(c, wsBodyLog.String())
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Warnf("responses websocket: close connection error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
var lastRequest []byte
|
||||
lastResponseOutput := []byte("[]")
|
||||
pinnedAuthID := ""
|
||||
|
||||
for {
|
||||
msgType, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
wsTerminateErr = errReadMessage
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
|
||||
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||
} else {
|
||||
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||
continue
|
||||
}
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
|
||||
// passthroughSessionID,
|
||||
// msgType,
|
||||
// websocketPayloadEventType(payload),
|
||||
// websocketPayloadPreview(payload),
|
||||
// )
|
||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||
|
||||
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
var requestJSON []byte
|
||||
var updatedLastRequest []byte
|
||||
var errMsg *interfaces.ErrorMessage
|
||||
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
|
||||
payload,
|
||||
lastRequest,
|
||||
lastResponseOutput,
|
||||
allowIncrementalInputWithPreviousResponseID,
|
||||
)
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
passthroughSessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
passthroughSessionID,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
errWrite,
|
||||
)
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
|
||||
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
|
||||
if pinnedAuthID != "" {
|
||||
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
||||
} else {
|
||||
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
||||
pinnedAuthID = strings.TrimSpace(authID)
|
||||
})
|
||||
}
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||
if errForward != nil {
|
||||
wsTerminateErr = errForward
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
|
||||
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
|
||||
return
|
||||
}
|
||||
lastResponseOutput = completedOutput
|
||||
}
|
||||
}
|
||||
|
||||
func websocketUpgradeHeaders(req *http.Request) http.Header {
|
||||
headers := http.Header{}
|
||||
if req == nil {
|
||||
return headers
|
||||
}
|
||||
|
||||
// Keep the same sticky turn-state across reconnects when provided by the client.
|
||||
turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader))
|
||||
if turnState != "" {
|
||||
headers.Set(wsTurnStateHeader, turnState)
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
|
||||
}
|
||||
|
||||
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||
switch requestType {
|
||||
case wsRequestTypeCreate:
|
||||
// log.Infof("responses websocket: response.create request")
|
||||
if len(lastRequest) == 0 {
|
||||
return normalizeResponseCreateRequest(rawJSON)
|
||||
}
|
||||
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||
case wsRequestTypeAppend:
|
||||
// log.Infof("responses websocket: response.append request")
|
||||
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||
default:
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
if !gjson.GetBytes(normalized, "input").Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]"))
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String())
|
||||
if modelName == "" {
|
||||
return nil, nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("missing model in response.create request"),
|
||||
}
|
||||
}
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
|
||||
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
if len(lastRequest) == 0 {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("websocket request received before response.create"),
|
||||
}
|
||||
}
|
||||
|
||||
nextInput := gjson.GetBytes(rawJSON, "input")
|
||||
if !nextInput.Exists() || !nextInput.IsArray() {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("websocket request requires array field: input"),
|
||||
}
|
||||
}
|
||||
|
||||
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
||||
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
||||
if allowIncrementalInputWithPreviousResponseID {
|
||||
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
if modelName != "" {
|
||||
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||
if instructions.Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||
}
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
}
|
||||
|
||||
existingInput := gjson.GetBytes(lastRequest, "input")
|
||||
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
|
||||
if errMerge != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
|
||||
}
|
||||
}
|
||||
|
||||
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
|
||||
if errMerge != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
||||
}
|
||||
}
|
||||
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
|
||||
var errSet error
|
||||
normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
|
||||
if errSet != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("failed to merge websocket input: %w", errSet),
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
if modelName != "" {
|
||||
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||
if instructions.Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||
}
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
|
||||
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
||||
if len(attributes) > 0 {
|
||||
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
||||
parsed, errParse := strconv.ParseBool(raw)
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(metadata) == 0 {
|
||||
return false
|
||||
}
|
||||
raw, ok := metadata["websockets"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case bool:
|
||||
return value
|
||||
case string:
|
||||
parsed, errParse := strconv.ParseBool(strings.TrimSpace(value))
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
default:
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||
existingRaw = strings.TrimSpace(existingRaw)
|
||||
appendRaw = strings.TrimSpace(appendRaw)
|
||||
if existingRaw == "" {
|
||||
existingRaw = "[]"
|
||||
}
|
||||
if appendRaw == "" {
|
||||
appendRaw = "[]"
|
||||
}
|
||||
|
||||
var existing []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var appendItems []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
merged := append(existing, appendItems...)
|
||||
out, err := json.Marshal(merged)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func normalizeJSONArrayRaw(raw []byte) string {
|
||||
trimmed := strings.TrimSpace(string(raw))
|
||||
if trimmed == "" {
|
||||
return "[]"
|
||||
}
|
||||
result := gjson.Parse(trimmed)
|
||||
if result.Type == gjson.JSON && result.IsArray() {
|
||||
return trimmed
|
||||
}
|
||||
return "[]"
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
c *gin.Context,
|
||||
conn *websocket.Conn,
|
||||
cancel handlers.APIHandlerCancelFunc,
|
||||
data <-chan []byte,
|
||||
errs <-chan *interfaces.ErrorMessage,
|
||||
wsBodyLog *strings.Builder,
|
||||
sessionID string,
|
||||
) ([]byte, error) {
|
||||
completed := false
|
||||
completedOutput := []byte("[]")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return completedOutput, c.Request.Context().Err()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
errs = nil
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
sessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
// log.Warnf(
|
||||
// "responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
// sessionID,
|
||||
// websocketPayloadEventType(errorPayload),
|
||||
// errWrite,
|
||||
// )
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
}
|
||||
if errMsg != nil {
|
||||
cancel(errMsg.Error)
|
||||
} else {
|
||||
cancel(nil)
|
||||
}
|
||||
return completedOutput, nil
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
if !completed {
|
||||
errMsg := &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusRequestTimeout,
|
||||
Error: fmt.Errorf("stream closed before response.completed"),
|
||||
}
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
sessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
errWrite,
|
||||
)
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, nil
|
||||
}
|
||||
cancel(nil)
|
||||
return completedOutput, nil
|
||||
}
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
for i := range payloads {
|
||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||
if eventType == wsEventTypeCompleted {
|
||||
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
|
||||
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
|
||||
|
||||
completed = true
|
||||
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||
}
|
||||
markAPIResponseTimestamp(c)
|
||||
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
// sessionID,
|
||||
// websocket.TextMessage,
|
||||
// websocketPayloadEventType(payloads[i]),
|
||||
// websocketPayloadPreview(payloads[i]),
|
||||
// )
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(payloads[i]),
|
||||
errWrite,
|
||||
)
|
||||
cancel(errWrite)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func responseCompletedOutputFromPayload(payload []byte) []byte {
|
||||
output := gjson.GetBytes(payload, "response.output")
|
||||
if output.Exists() && output.IsArray() {
|
||||
return bytes.Clone([]byte(output.Raw))
|
||||
}
|
||||
return []byte("[]")
|
||||
}
|
||||
|
||||
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
|
||||
payloads := make([][]byte, 0, 2)
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
for i := range lines {
|
||||
line := bytes.TrimSpace(lines[i])
|
||||
if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(line, []byte("data:")) {
|
||||
line = bytes.TrimSpace(line[len("data:"):])
|
||||
}
|
||||
if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) {
|
||||
continue
|
||||
}
|
||||
if json.Valid(line) {
|
||||
payloads = append(payloads, bytes.Clone(line))
|
||||
}
|
||||
}
|
||||
|
||||
if len(payloads) > 0 {
|
||||
return payloads
|
||||
}
|
||||
|
||||
trimmed := bytes.TrimSpace(chunk)
|
||||
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
|
||||
}
|
||||
if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) {
|
||||
payloads = append(payloads, bytes.Clone(trimmed))
|
||||
}
|
||||
return payloads
|
||||
}
|
||||
|
||||
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
|
||||
status := http.StatusInternalServerError
|
||||
errText := http.StatusText(status)
|
||||
if errMsg != nil {
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
errText = http.StatusText(status)
|
||||
}
|
||||
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
}
|
||||
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
payload := map[string]any{
|
||||
"type": wsEventTypeError,
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if errMsg != nil && errMsg.Addon != nil {
|
||||
headers := map[string]any{}
|
||||
for key, values := range errMsg.Addon {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
headers[key] = values[0]
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
payload["headers"] = headers
|
||||
}
|
||||
}
|
||||
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
var decoded map[string]any
|
||||
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
||||
if inner, ok := decoded["error"]; ok {
|
||||
payload["error"] = inner
|
||||
} else {
|
||||
payload["error"] = decoded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := payload["error"]; !ok {
|
||||
payload["error"] = map[string]any{
|
||||
"type": "server_error",
|
||||
"message": errText,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString("websocket.")
|
||||
builder.WriteString(eventType)
|
||||
builder.WriteString("\n")
|
||||
builder.Write(trimmedPayload)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
func websocketPayloadEventType(payload []byte) string {
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||
if eventType == "" {
|
||||
return "-"
|
||||
}
|
||||
return eventType
|
||||
}
|
||||
|
||||
func websocketPayloadPreview(payload []byte) string {
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
preview := trimmedPayload
|
||||
if len(preview) > wsPayloadLogMaxSize {
|
||||
preview = preview[:wsPayloadLogMaxSize]
|
||||
}
|
||||
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
|
||||
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
|
||||
if len(trimmedPayload) > wsPayloadLogMaxSize {
|
||||
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
|
||||
}
|
||||
return previewText
|
||||
}
|
||||
|
||||
func setWebsocketRequestBody(c *gin.Context, body string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
trimmedBody := strings.TrimSpace(body)
|
||||
if trimmedBody == "" {
|
||||
return
|
||||
}
|
||||
c.Set(wsRequestBodyKey, []byte(trimmedBody))
|
||||
}
|
||||
|
||||
func markAPIResponseTimestamp(c *gin.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists {
|
||||
return
|
||||
}
|
||||
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||
}
|
||||
249
sdk/api/handlers/openai/openai_responses_websocket_test.go
Normal file
249
sdk/api/handlers/openai/openai_responses_websocket_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
normalized, last, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "type").Exists() {
|
||||
t.Fatalf("normalized create request must not include type field")
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "stream").Bool() {
|
||||
t.Fatalf("normalized create request must force stream=true")
|
||||
}
|
||||
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||
}
|
||||
if !bytes.Equal(last, normalized) {
|
||||
t.Fatalf("last request snapshot should match normalized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||
{"type":"message","id":"assistant-1"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "type").Exists() {
|
||||
t.Fatalf("normalized subsequent create request must not include type field")
|
||||
}
|
||||
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||
}
|
||||
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 4 {
|
||||
t.Fatalf("merged input len = %d, want 4", len(input))
|
||||
}
|
||||
if input[0].Get("id").String() != "msg-1" ||
|
||||
input[1].Get("id").String() != "fc-1" ||
|
||||
input[2].Get("id").String() != "assistant-1" ||
|
||||
input[3].Get("id").String() != "tool-out-1" {
|
||||
t.Fatalf("unexpected merged input order")
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match normalized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||
{"type":"message","id":"assistant-1"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "type").Exists() {
|
||||
t.Fatalf("normalized request must not include type field")
|
||||
}
|
||||
if gjson.GetBytes(normalized, "previous_response_id").String() != "resp-1" {
|
||||
t.Fatalf("previous_response_id must be preserved in incremental mode")
|
||||
}
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("incremental input len = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].Get("id").String() != "tool-out-1" {
|
||||
t.Fatalf("unexpected incremental input item id: %s", input[0].Get("id").String())
|
||||
}
|
||||
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||
}
|
||||
if gjson.GetBytes(normalized, "instructions").String() != "be helpful" {
|
||||
t.Fatalf("unexpected instructions: %s", gjson.GetBytes(normalized, "instructions").String())
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match normalized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||
{"type":"message","id":"assistant-1"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
|
||||
t.Fatalf("previous_response_id must be removed when incremental mode is disabled")
|
||||
}
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 4 {
|
||||
t.Fatalf("merged input len = %d, want 4", len(input))
|
||||
}
|
||||
if input[0].Get("id").String() != "msg-1" ||
|
||||
input[1].Get("id").String() != "fc-1" ||
|
||||
input[2].Get("id").String() != "assistant-1" ||
|
||||
input[3].Get("id").String() != "tool-out-1" {
|
||||
t.Fatalf("unexpected merged input order")
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match normalized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"message","id":"assistant-1"},
|
||||
{"type":"function_call_output","id":"tool-out-1"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 5 {
|
||||
t.Fatalf("merged input len = %d, want 5", len(input))
|
||||
}
|
||||
if input[0].Get("id").String() != "msg-1" ||
|
||||
input[1].Get("id").String() != "assistant-1" ||
|
||||
input[2].Get("id").String() != "tool-out-1" ||
|
||||
input[3].Get("id").String() != "msg-2" ||
|
||||
input[4].Get("id").String() != "msg-3" {
|
||||
t.Fatalf("unexpected merged input order")
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match normalized append request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) {
|
||||
raw := []byte(`{"type":"response.append","input":[]}`)
|
||||
|
||||
_, _, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
|
||||
if errMsg == nil {
|
||||
t.Fatalf("expected error for append without previous request")
|
||||
}
|
||||
if errMsg.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketJSONPayloadsFromChunk(t *testing.T) {
|
||||
chunk := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n")
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("payloads len = %d, want 1", len(payloads))
|
||||
}
|
||||
if gjson.GetBytes(payloads[0], "type").String() != "response.created" {
|
||||
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) {
|
||||
chunk := []byte(`{"type":"response.completed","response":{"id":"resp-1"}}`)
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("payloads len = %d, want 1", len(payloads))
|
||||
}
|
||||
if gjson.GetBytes(payloads[0], "type").String() != "response.completed" {
|
||||
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseCompletedOutputFromPayload(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`)
|
||||
|
||||
output := responseCompletedOutputFromPayload(payload)
|
||||
items := gjson.ParseBytes(output).Array()
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("output len = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].Get("id").String() != "out-1" {
|
||||
t.Fatalf("unexpected output id: %s", items[0].Get("id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendWebsocketEvent(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
|
||||
appendWebsocketEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"))
|
||||
appendWebsocketEvent(&builder, "response", []byte("{\"type\":\"response.created\"}"))
|
||||
|
||||
got := builder.String()
|
||||
if !strings.Contains(got, "websocket.request\n{\"type\":\"response.create\"}\n") {
|
||||
t.Fatalf("request event not found in body: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, "websocket.response\n{\"type\":\"response.created\"}\n") {
|
||||
t.Fatalf("response event not found in body: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
setWebsocketRequestBody(c, " \n ")
|
||||
if _, exists := c.Get(wsRequestBodyKey); exists {
|
||||
t.Fatalf("request body key should not be set for empty body")
|
||||
}
|
||||
|
||||
setWebsocketRequestBody(c, "event body")
|
||||
value, exists := c.Get(wsRequestBodyKey)
|
||||
if !exists {
|
||||
t.Fatalf("request body key not set")
|
||||
}
|
||||
bodyBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
t.Fatalf("request body key type mismatch")
|
||||
}
|
||||
if string(bodyBytes) != "event body" {
|
||||
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||
}
|
||||
}
|
||||
@@ -28,8 +28,7 @@ func (AntigravityAuthenticator) Provider() string { return "antigravity" }
|
||||
|
||||
// RefreshLead instructs the manager to refresh five minutes before expiry.
|
||||
func (AntigravityAuthenticator) RefreshLead() *time.Duration {
|
||||
lead := 5 * time.Minute
|
||||
return &lead
|
||||
return new(5 * time.Minute)
|
||||
}
|
||||
|
||||
// Login launches a local OAuth flow to obtain antigravity tokens and persists them.
|
||||
|
||||
@@ -32,8 +32,7 @@ func (a *ClaudeAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
func (a *ClaudeAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 4 * time.Hour
|
||||
return &d
|
||||
return new(4 * time.Hour)
|
||||
}
|
||||
|
||||
func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -34,8 +34,7 @@ func (a *CodexAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
func (a *CodexAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 5 * 24 * time.Hour
|
||||
return &d
|
||||
return new(5 * 24 * time.Hour)
|
||||
}
|
||||
|
||||
func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -186,15 +188,21 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
if provider == "" {
|
||||
provider = "unknown"
|
||||
}
|
||||
if provider == "antigravity" {
|
||||
if provider == "antigravity" || provider == "gemini" {
|
||||
projectID := ""
|
||||
if pid, ok := metadata["project_id"].(string); ok {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
if projectID == "" {
|
||||
accessToken := ""
|
||||
if token, ok := metadata["access_token"].(string); ok {
|
||||
accessToken = strings.TrimSpace(token)
|
||||
accessToken := extractAccessToken(metadata)
|
||||
// For gemini type, the stored access_token is likely expired (~1h lifetime).
|
||||
// Refresh it using the long-lived refresh_token before querying.
|
||||
if provider == "gemini" {
|
||||
if tokenMap, ok := metadata["token"].(map[string]any); ok {
|
||||
if refreshed, errRefresh := refreshGeminiAccessToken(tokenMap, http.DefaultClient); errRefresh == nil {
|
||||
accessToken = refreshed
|
||||
}
|
||||
}
|
||||
}
|
||||
if accessToken != "" {
|
||||
fetchedProjectID, errFetch := FetchAntigravityProjectID(context.Background(), accessToken, http.DefaultClient)
|
||||
@@ -304,6 +312,67 @@ func (s *FileTokenStore) baseDirSnapshot() string {
|
||||
return s.baseDir
|
||||
}
|
||||
|
||||
func extractAccessToken(metadata map[string]any) string {
|
||||
if at, ok := metadata["access_token"].(string); ok {
|
||||
if v := strings.TrimSpace(at); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
if tokenMap, ok := metadata["token"].(map[string]any); ok {
|
||||
if at, ok := tokenMap["access_token"].(string); ok {
|
||||
if v := strings.TrimSpace(at); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func refreshGeminiAccessToken(tokenMap map[string]any, httpClient *http.Client) (string, error) {
|
||||
refreshToken, _ := tokenMap["refresh_token"].(string)
|
||||
clientID, _ := tokenMap["client_id"].(string)
|
||||
clientSecret, _ := tokenMap["client_secret"].(string)
|
||||
tokenURI, _ := tokenMap["token_uri"].(string)
|
||||
|
||||
if refreshToken == "" || clientID == "" || clientSecret == "" {
|
||||
return "", fmt.Errorf("missing refresh credentials")
|
||||
}
|
||||
if tokenURI == "" {
|
||||
tokenURI = "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {clientID},
|
||||
"client_secret": {clientSecret},
|
||||
}
|
||||
|
||||
resp, err := httpClient.PostForm(tokenURI, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("refresh failed: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if errUnmarshal := json.Unmarshal(body, &result); errUnmarshal != nil {
|
||||
return "", fmt.Errorf("decode refresh response: %w", errUnmarshal)
|
||||
}
|
||||
|
||||
newAccessToken, _ := result["access_token"].(string)
|
||||
if newAccessToken == "" {
|
||||
return "", fmt.Errorf("no access_token in refresh response")
|
||||
}
|
||||
|
||||
tokenMap["access_token"] = newAccessToken
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// jsonEqual compares two JSON blobs by parsing them into Go objects and deep comparing.
|
||||
func jsonEqual(a, b []byte) bool {
|
||||
var objA any
|
||||
|
||||
80
sdk/auth/filestore_test.go
Normal file
80
sdk/auth/filestore_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package auth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExtractAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"antigravity top-level access_token",
|
||||
map[string]any{"access_token": "tok-abc"},
|
||||
"tok-abc",
|
||||
},
|
||||
{
|
||||
"gemini nested token.access_token",
|
||||
map[string]any{
|
||||
"token": map[string]any{"access_token": "tok-nested"},
|
||||
},
|
||||
"tok-nested",
|
||||
},
|
||||
{
|
||||
"top-level takes precedence over nested",
|
||||
map[string]any{
|
||||
"access_token": "tok-top",
|
||||
"token": map[string]any{"access_token": "tok-nested"},
|
||||
},
|
||||
"tok-top",
|
||||
},
|
||||
{
|
||||
"empty metadata",
|
||||
map[string]any{},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"whitespace-only access_token",
|
||||
map[string]any{"access_token": " "},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"wrong type access_token",
|
||||
map[string]any{"access_token": 12345},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"token is not a map",
|
||||
map[string]any{"token": "not-a-map"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"nested whitespace-only",
|
||||
map[string]any{
|
||||
"token": map[string]any{"access_token": " "},
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"fallback to nested when top-level empty",
|
||||
map[string]any{
|
||||
"access_token": "",
|
||||
"token": map[string]any{"access_token": "tok-fallback"},
|
||||
},
|
||||
"tok-fallback",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := extractAccessToken(tt.metadata)
|
||||
if got != tt.expected {
|
||||
t.Errorf("extractAccessToken() = %q, want %q", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -26,8 +26,7 @@ func (a *IFlowAuthenticator) Provider() string { return "iflow" }
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
func (a *IFlowAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 24 * time.Hour
|
||||
return &d
|
||||
return new(24 * time.Hour)
|
||||
}
|
||||
|
||||
// Login performs the OAuth code flow using a local callback server.
|
||||
|
||||
@@ -27,8 +27,7 @@ func (a *QwenAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
func (a *QwenAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 3 * time.Hour
|
||||
return &d
|
||||
return new(3 * time.Hour)
|
||||
}
|
||||
|
||||
func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -41,6 +41,17 @@ type ProviderExecutor interface {
|
||||
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// ExecutionSessionCloser allows executors to release per-session runtime resources.
|
||||
type ExecutionSessionCloser interface {
|
||||
CloseExecutionSession(sessionID string)
|
||||
}
|
||||
|
||||
const (
|
||||
// CloseAllExecutionSessionsID asks an executor to release all active execution sessions.
|
||||
// Executors that do not support this marker may ignore it.
|
||||
CloseAllExecutionSessionsID = "__all_execution_sessions__"
|
||||
)
|
||||
|
||||
// RefreshEvaluator allows runtime state to override refresh decisions.
|
||||
type RefreshEvaluator interface {
|
||||
ShouldRefresh(now time.Time, auth *Auth) bool
|
||||
@@ -389,9 +400,23 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
||||
if executor == nil {
|
||||
return
|
||||
}
|
||||
provider := strings.TrimSpace(executor.Identifier())
|
||||
if provider == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var replaced ProviderExecutor
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.executors[executor.Identifier()] = executor
|
||||
replaced = m.executors[provider]
|
||||
m.executors[provider] = executor
|
||||
m.mu.Unlock()
|
||||
|
||||
if replaced == nil || replaced == executor {
|
||||
return
|
||||
}
|
||||
if closer, ok := replaced.(ExecutionSessionCloser); ok && closer != nil {
|
||||
closer.CloseExecutionSession(CloseAllExecutionSessionsID)
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterExecutor removes the executor associated with the provider key.
|
||||
@@ -581,6 +606,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -599,8 +625,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
@@ -637,6 +662,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -655,8 +681,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
@@ -693,6 +718,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -710,8 +736,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
@@ -732,8 +757,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
@@ -798,6 +822,38 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func pinnedAuthIDFromMetadata(meta map[string]any) string {
|
||||
if len(meta) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := meta[cliproxyexecutor.PinnedAuthMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return ""
|
||||
}
|
||||
switch val := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(val)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(val))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func publishSelectedAuthMetadata(meta map[string]any, authID string) {
|
||||
if len(meta) == 0 {
|
||||
return
|
||||
}
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
meta[cliproxyexecutor.SelectedAuthMetadataKey] = authID
|
||||
if callback, ok := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string)); ok && callback != nil {
|
||||
callback(authID)
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteModelForAuth(model string, auth *Auth) string {
|
||||
if auth == nil || model == "" {
|
||||
return model
|
||||
@@ -1431,8 +1487,7 @@ func retryAfterFromError(err error) *time.Duration {
|
||||
if retryAfter == nil {
|
||||
return nil
|
||||
}
|
||||
val := *retryAfter
|
||||
return &val
|
||||
return new(*retryAfter)
|
||||
}
|
||||
|
||||
func statusCodeFromResult(err *Error) int {
|
||||
@@ -1555,7 +1610,56 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
|
||||
return auth.Clone(), true
|
||||
}
|
||||
|
||||
// Executor returns the registered provider executor for a provider key.
|
||||
func (m *Manager) Executor(provider string) (ProviderExecutor, bool) {
|
||||
if m == nil {
|
||||
return nil, false
|
||||
}
|
||||
provider = strings.TrimSpace(provider)
|
||||
if provider == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
executor, okExecutor := m.executors[provider]
|
||||
if !okExecutor {
|
||||
lowerProvider := strings.ToLower(provider)
|
||||
if lowerProvider != provider {
|
||||
executor, okExecutor = m.executors[lowerProvider]
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !okExecutor || executor == nil {
|
||||
return nil, false
|
||||
}
|
||||
return executor, true
|
||||
}
|
||||
|
||||
// CloseExecutionSession asks all registered executors to release the supplied execution session.
|
||||
func (m *Manager) CloseExecutionSession(sessionID string) {
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if m == nil || sessionID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
executors := make([]ProviderExecutor, 0, len(m.executors))
|
||||
for _, exec := range m.executors {
|
||||
executors = append(executors, exec)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
for i := range executors {
|
||||
if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil {
|
||||
closer.CloseExecutionSession(sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
|
||||
m.mu.RLock()
|
||||
executor, okExecutor := m.executors[provider]
|
||||
if !okExecutor {
|
||||
@@ -1576,6 +1680,9 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
if candidate.Provider != provider || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
|
||||
continue
|
||||
}
|
||||
if _, used := tried[candidate.ID]; used {
|
||||
continue
|
||||
}
|
||||
@@ -1611,6 +1718,8 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
p := strings.TrimSpace(strings.ToLower(provider))
|
||||
@@ -1638,6 +1747,9 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
|
||||
if candidate == nil || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
|
||||
100
sdk/cliproxy/auth/conductor_executor_replace_test.go
Normal file
100
sdk/cliproxy/auth/conductor_executor_replace_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type replaceAwareExecutor struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
closedSessionIDs []string
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) Identifier() string {
|
||||
return e.id
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
ch := make(chan cliproxyexecutor.StreamChunk)
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) CloseExecutionSession(sessionID string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.closedSessionIDs = append(e.closedSessionIDs, sessionID)
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) ClosedSessionIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.closedSessionIDs))
|
||||
copy(out, e.closedSessionIDs)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestManagerRegisterExecutorClosesReplacedExecutionSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, nil, nil)
|
||||
replaced := &replaceAwareExecutor{id: "codex"}
|
||||
current := &replaceAwareExecutor{id: "codex"}
|
||||
|
||||
manager.RegisterExecutor(replaced)
|
||||
manager.RegisterExecutor(current)
|
||||
|
||||
closed := replaced.ClosedSessionIDs()
|
||||
if len(closed) != 1 {
|
||||
t.Fatalf("expected replaced executor close calls = 1, got %d", len(closed))
|
||||
}
|
||||
if closed[0] != CloseAllExecutionSessionsID {
|
||||
t.Fatalf("expected close marker %q, got %q", CloseAllExecutionSessionsID, closed[0])
|
||||
}
|
||||
if len(current.ClosedSessionIDs()) != 0 {
|
||||
t.Fatalf("expected current executor to stay open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, nil, nil)
|
||||
current := &replaceAwareExecutor{id: "codex"}
|
||||
manager.RegisterExecutor(current)
|
||||
|
||||
resolved, okResolved := manager.Executor("CODEX")
|
||||
if !okResolved {
|
||||
t.Fatal("expected registered executor to be found")
|
||||
}
|
||||
if resolved != current {
|
||||
t.Fatal("expected resolved executor to match registered executor")
|
||||
}
|
||||
|
||||
_, okMissing := manager.Executor("unknown")
|
||||
if okMissing {
|
||||
t.Fatal("expected unknown provider lookup to fail")
|
||||
}
|
||||
}
|
||||
@@ -134,6 +134,62 @@ func canonicalModelKey(model string) string {
|
||||
return modelName
|
||||
}
|
||||
|
||||
func authWebsocketsEnabled(auth *Auth) bool {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
if len(auth.Attributes) > 0 {
|
||||
if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
|
||||
parsed, errParse := strconv.ParseBool(raw)
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(auth.Metadata) == 0 {
|
||||
return false
|
||||
}
|
||||
raw, ok := auth.Metadata["websockets"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
parsed, errParse := strconv.ParseBool(strings.TrimSpace(v))
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
default:
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth {
|
||||
if len(available) == 0 {
|
||||
return available
|
||||
}
|
||||
if !cliproxyexecutor.DownstreamWebsocket(ctx) {
|
||||
return available
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(provider), "codex") {
|
||||
return available
|
||||
}
|
||||
|
||||
wsEnabled := make([]*Auth, 0, len(available))
|
||||
for i := 0; i < len(available); i++ {
|
||||
candidate := available[i]
|
||||
if authWebsocketsEnabled(candidate) {
|
||||
wsEnabled = append(wsEnabled, candidate)
|
||||
}
|
||||
}
|
||||
if len(wsEnabled) > 0 {
|
||||
return wsEnabled
|
||||
}
|
||||
return available
|
||||
}
|
||||
|
||||
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||
available = make(map[int][]*Auth)
|
||||
for i := 0; i < len(auths); i++ {
|
||||
@@ -193,13 +249,13 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
||||
|
||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = ctx
|
||||
_ = opts
|
||||
now := time.Now()
|
||||
available, err := getAvailableAuths(auths, provider, model, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
available = preferCodexWebsocketAuths(ctx, provider, available)
|
||||
key := provider + ":" + canonicalModelKey(model)
|
||||
s.mu.Lock()
|
||||
if s.cursors == nil {
|
||||
@@ -226,13 +282,13 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
||||
|
||||
// Pick selects the first available auth for the provider in a deterministic manner.
|
||||
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = ctx
|
||||
_ = opts
|
||||
now := time.Now()
|
||||
available, err := getAvailableAuths(auths, provider, model, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
available = preferCodexWebsocketAuths(ctx, provider, available)
|
||||
return available[0], nil
|
||||
}
|
||||
|
||||
|
||||
@@ -213,6 +213,23 @@ func (a *Auth) DisableCoolingOverride() (bool, bool) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// ToolPrefixDisabled returns whether the proxy_ tool name prefix should be
|
||||
// skipped for this auth. When true, tool names are sent to Anthropic unchanged.
|
||||
// The value is read from metadata key "tool_prefix_disabled" (or "tool-prefix-disabled").
|
||||
func (a *Auth) ToolPrefixDisabled() bool {
|
||||
if a == nil || a.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
for _, key := range []string{"tool_prefix_disabled", "tool-prefix-disabled"} {
|
||||
if val, ok := a.Metadata[key]; ok {
|
||||
if parsed, okParse := parseBoolAny(val); okParse {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
|
||||
// The value is read from metadata key "request_retry" (or legacy "request-retry").
|
||||
func (a *Auth) RequestRetryOverride() (int, bool) {
|
||||
|
||||
35
sdk/cliproxy/auth/types_test.go
Normal file
35
sdk/cliproxy/auth/types_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package auth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestToolPrefixDisabled(t *testing.T) {
|
||||
var a *Auth
|
||||
if a.ToolPrefixDisabled() {
|
||||
t.Error("nil auth should return false")
|
||||
}
|
||||
|
||||
a = &Auth{}
|
||||
if a.ToolPrefixDisabled() {
|
||||
t.Error("empty auth should return false")
|
||||
}
|
||||
|
||||
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": true}}
|
||||
if !a.ToolPrefixDisabled() {
|
||||
t.Error("should return true when set to true")
|
||||
}
|
||||
|
||||
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": "true"}}
|
||||
if !a.ToolPrefixDisabled() {
|
||||
t.Error("should return true when set to string 'true'")
|
||||
}
|
||||
|
||||
a = &Auth{Metadata: map[string]any{"tool-prefix-disabled": true}}
|
||||
if !a.ToolPrefixDisabled() {
|
||||
t.Error("should return true with kebab-case key")
|
||||
}
|
||||
|
||||
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": false}}
|
||||
if a.ToolPrefixDisabled() {
|
||||
t.Error("should return false when set to false")
|
||||
}
|
||||
}
|
||||
23
sdk/cliproxy/executor/context.go
Normal file
23
sdk/cliproxy/executor/context.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package executor
|
||||
|
||||
import "context"
|
||||
|
||||
type downstreamWebsocketContextKey struct{}
|
||||
|
||||
// WithDownstreamWebsocket marks the current request as coming from a downstream websocket connection.
|
||||
func WithDownstreamWebsocket(ctx context.Context) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, downstreamWebsocketContextKey{}, true)
|
||||
}
|
||||
|
||||
// DownstreamWebsocket reports whether the current request originates from a downstream websocket connection.
|
||||
func DownstreamWebsocket(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return false
|
||||
}
|
||||
raw := ctx.Value(downstreamWebsocketContextKey{})
|
||||
enabled, ok := raw.(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
@@ -10,6 +10,17 @@ import (
|
||||
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
|
||||
const RequestedModelMetadataKey = "requested_model"
|
||||
|
||||
const (
|
||||
// PinnedAuthMetadataKey locks execution to a specific auth ID.
|
||||
PinnedAuthMetadataKey = "pinned_auth_id"
|
||||
// SelectedAuthMetadataKey stores the auth ID selected by the scheduler.
|
||||
SelectedAuthMetadataKey = "selected_auth_id"
|
||||
// SelectedAuthCallbackMetadataKey carries an optional callback invoked with the selected auth ID.
|
||||
SelectedAuthCallbackMetadataKey = "selected_auth_callback"
|
||||
// ExecutionSessionMetadataKey identifies a long-lived downstream execution session.
|
||||
ExecutionSessionMetadataKey = "execution_session_id"
|
||||
)
|
||||
|
||||
// Request encapsulates the translated payload that will be sent to a provider executor.
|
||||
type Request struct {
|
||||
// Model is the upstream model identifier after translation.
|
||||
|
||||
@@ -325,6 +325,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||
if _, err := s.coreManager.Update(ctx, existing); err != nil {
|
||||
log.Errorf("failed to disable auth %s: %v", id, err)
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
|
||||
s.ensureExecutorsForAuth(existing)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,7 +360,24 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
|
||||
}
|
||||
|
||||
func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
if s == nil || a == nil {
|
||||
s.ensureExecutorsForAuthWithMode(a, false)
|
||||
}
|
||||
|
||||
func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace bool) {
|
||||
if s == nil || s.coreManager == nil || a == nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(a.Provider), "codex") {
|
||||
if !forceReplace {
|
||||
existingExecutor, hasExecutor := s.coreManager.Executor("codex")
|
||||
if hasExecutor {
|
||||
_, isCodexAutoExecutor := existingExecutor.(*executor.CodexAutoExecutor)
|
||||
if isCodexAutoExecutor {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
s.coreManager.RegisterExecutor(executor.NewCodexAutoExecutor(s.cfg))
|
||||
return
|
||||
}
|
||||
// Skip disabled auth entries when (re)binding executors.
|
||||
@@ -392,8 +412,6 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
|
||||
case "claude":
|
||||
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
||||
case "codex":
|
||||
s.coreManager.RegisterExecutor(executor.NewCodexExecutor(s.cfg))
|
||||
case "qwen":
|
||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||
case "iflow":
|
||||
@@ -415,8 +433,15 @@ func (s *Service) rebindExecutors() {
|
||||
return
|
||||
}
|
||||
auths := s.coreManager.List()
|
||||
reboundCodex := false
|
||||
for _, auth := range auths {
|
||||
s.ensureExecutorsForAuth(auth)
|
||||
if auth != nil && strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
||||
if reboundCodex {
|
||||
continue
|
||||
}
|
||||
reboundCodex = true
|
||||
}
|
||||
s.ensureExecutorsForAuthWithMode(auth, true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -740,6 +765,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
provider = "openai-compatibility"
|
||||
}
|
||||
excluded := s.oauthExcludedModels(provider, authKind)
|
||||
// The synthesizer pre-merges per-account and global exclusions into the "excluded_models" attribute.
|
||||
// If this attribute is present, it represents the complete list of exclusions and overrides the global config.
|
||||
if a.Attributes != nil {
|
||||
if val, ok := a.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
|
||||
excluded = strings.Split(val, ",")
|
||||
}
|
||||
}
|
||||
var models []*ModelInfo
|
||||
switch provider {
|
||||
case "gemini":
|
||||
|
||||
64
sdk/cliproxy/service_codex_executor_binding_test.go
Normal file
64
sdk/cliproxy/service_codex_executor_binding_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "codex-auth-1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuth(auth)
|
||||
firstExecutor, okFirst := service.coreManager.Executor("codex")
|
||||
if !okFirst || firstExecutor == nil {
|
||||
t.Fatal("expected codex executor after first bind")
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuth(auth)
|
||||
secondExecutor, okSecond := service.coreManager.Executor("codex")
|
||||
if !okSecond || secondExecutor == nil {
|
||||
t.Fatal("expected codex executor after second bind")
|
||||
}
|
||||
|
||||
if firstExecutor != secondExecutor {
|
||||
t.Fatal("expected codex executor to stay unchanged in normal mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "codex-auth-2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuth(auth)
|
||||
firstExecutor, okFirst := service.coreManager.Executor("codex")
|
||||
if !okFirst || firstExecutor == nil {
|
||||
t.Fatal("expected codex executor after first bind")
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuthWithMode(auth, true)
|
||||
secondExecutor, okSecond := service.coreManager.Executor("codex")
|
||||
if !okSecond || secondExecutor == nil {
|
||||
t.Fatal("expected codex executor after forced rebind")
|
||||
}
|
||||
|
||||
if firstExecutor == secondExecutor {
|
||||
t.Fatal("expected codex executor replacement in force mode")
|
||||
}
|
||||
}
|
||||
65
sdk/cliproxy/service_excluded_models_test.go
Normal file
65
sdk/cliproxy/service_excluded_models_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"gemini-cli": {"gemini-2.5-pro"},
|
||||
},
|
||||
},
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-gemini-cli",
|
||||
Provider: "gemini-cli",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
"excluded_models": "gemini-2.5-flash",
|
||||
},
|
||||
}
|
||||
|
||||
registry := GlobalModelRegistry()
|
||||
registry.UnregisterClient(auth.ID)
|
||||
t.Cleanup(func() {
|
||||
registry.UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
service.registerModelsForAuth(auth)
|
||||
|
||||
models := registry.GetAvailableModelsByProvider("gemini-cli")
|
||||
if len(models) == 0 {
|
||||
t.Fatal("expected gemini-cli models to be registered")
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
modelID := strings.TrimSpace(model.ID)
|
||||
if strings.EqualFold(modelID, "gemini-2.5-flash") {
|
||||
t.Fatalf("expected model %q to be excluded by auth attribute", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
seenGlobalExcluded := false
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(model.ID), "gemini-2.5-pro") {
|
||||
seenGlobalExcluded = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !seenGlobalExcluded {
|
||||
t.Fatal("expected global excluded model to be present when attribute override is set")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user