Compare commits

...

17 Commits

Author SHA1 Message Date
Luis Pater
2d5d06c809 feat(registry): add Qwen3 Vision Model definition #164 2025-10-27 00:41:05 +08:00
Luis Pater
3e20b00357 Merge pull request #163 from router-for-me/nb
fix(gemini): map responseModalities to uppercase IMAGE/TEXT
2025-10-26 22:41:18 +08:00
hkfires
e370f86f63 fix(gemini-executor): uppercase responseModalities 2025-10-26 21:26:15 +08:00
hkfires
7f266aa19e fix(aistudio): ensure colon-spaced JSON in responses 2025-10-26 20:21:45 +08:00
hkfires
f3f31274e8 refactor(wsrelay): rename RoundTrip to NonStream 2025-10-26 20:01:46 +08:00
hkfires
7061cd6058 fix(gemini): map responseModalities to uppercase IMAGE/TEXT 2025-10-26 19:35:22 +08:00
Luis Pater
5da5674ae2 Merge pull request #161 from router-for-me/aistudio
Add websocket provider
2025-10-26 16:39:09 +08:00
hkfires
7459c2c81a fix(aistudio): remove generationConfig and tools when action is countTokens 2025-10-26 16:28:20 +08:00
Luis Pater
cd4706f60e fix(server): resolve incorrect variable usage in management asset paths
- Replaced `s.currentPath` with `s.configFilePath` for consistent handling of management asset paths.
- Adjusted calls to `managementasset.FilePath` and `StaticDir` to use the updated configuration path.
2025-10-26 12:44:57 +08:00
hkfires
359b8de44e feat(ws): add WebSocket auth 2025-10-26 07:46:04 +08:00
hkfires
ea6065f1b1 fix(aistudio): strip usage metadata from non-final stream chunks 2025-10-26 07:46:04 +08:00
hkfires
8aaed4cf09 feat(aistudio): support non-streaming responses 2025-10-26 07:46:04 +08:00
hkfires
c32e013605 feat(aistudio): track Gemini usage and improve stream errors 2025-10-26 07:46:04 +08:00
hkfires
3839d93ba0 feat: add websocket routing and executor unregister API
- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket
  upgrade handlers on the Gin engine.
- Track registered WS paths via wsRoutes with wsRouteMu to prevent
  duplicate registrations; initialize in NewServer and import sync.
- Add Manager.UnregisterExecutor(provider) for clean executor lifecycle
  management.
- Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum.

Motivation: enable services to expose WS endpoints through the core server
and allow removing auth executors dynamically while avoiding duplicate
route setup. No breaking changes.
2025-10-26 07:46:03 +08:00
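As a rough illustration of the API described above, the sketch below mounts a websocket upgrade handler through `AttachWebsocketRoute`. The echo handler and the `*Server` receiver are placeholders for the real wsrelay wiring; only `AttachWebsocketRoute(path, handler)` and its duplicate-registration guard come from this change.

```go
import (
	"net/http"

	"github.com/gorilla/websocket"
)

// attachEchoWS is a hypothetical caller; the real relay protocol lives in wsrelay.
func attachEchoWS(srv *Server) {
	upgrader := websocket.Upgrader{}
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return // Upgrade has already written the error response
		}
		defer conn.Close()
		for {
			mt, msg, errRead := conn.ReadMessage()
			if errRead != nil {
				return
			}
			if errWrite := conn.WriteMessage(mt, msg); errWrite != nil {
				return
			}
		}
	})
	srv.AttachWebsocketRoute("/v1/ws", handler)
	// A second registration of the same path is silently ignored by the wsRoutes guard.
	srv.AttachWebsocketRoute("/v1/ws", handler)
}
```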
Luis Pater
a552a45b81 Fixed: #140 #133 #80
feat(translator): add token counting functionality for Gemini, Claude, and CLI

- Introduced `TokenCount` handling across various Codex translators (Gemini, Claude, CLI) with respective implementations.
- Added utility methods for token counting and formatting responses.
- Integrated `tiktoken-go/tokenizer` library for tokenization.
- Updated CodexExecutor with token counting logic to support multiple models including GPT-5 variants.
- Refined go.mod and go.sum to include new dependencies.

feat(runtime): add token counting functionality across executors

- Implemented token counting in OpenAICompatExecutor, QwenExecutor, and IFlowExecutor.
- Added utilities for token counting and response formatting using `tiktoken-go/tokenizer`.
- Integrated token counting into translators for Gemini, Claude, and Gemini CLI.
- Enhanced multiple model support, including GPT-5 variants, for token counting.

docs: update environment variable instructions for multi-model support

- Added details for setting `ANTHROPIC_DEFAULT_OPUS_MODEL`, `ANTHROPIC_DEFAULT_SONNET_MODEL`, and `ANTHROPIC_DEFAULT_HAIKU_MODEL` for version 2.x.x.
- Clarified usage of `ANTHROPIC_MODEL` and `ANTHROPIC_SMALL_FAST_MODEL` for version 1.x.x.
- Expanded examples for setting environment variables across different models including Gemini, GPT-5, Claude, and Qwen3.
2025-10-26 05:39:15 +08:00
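The counting added here is approximate and local: the request is translated into an OpenAI/Codex-shaped payload, its text segments are joined, and `tiktoken-go/tokenizer` counts the result with zero completion tokens. A minimal standalone sketch of that idea, using only calls that appear in this change (`ForModel`, `Get`, `Count`); the prompt text is illustrative.

```go
package main

import (
	"fmt"
	"strings"

	"github.com/tiktoken-go/tokenizer"
)

func main() {
	// Choose an encoding by model family, falling back to cl100k_base,
	// mirroring the tokenizerForModel helpers introduced in this change.
	enc, err := tokenizer.ForModel(tokenizer.GPT4o)
	if err != nil {
		enc, _ = tokenizer.Get(tokenizer.Cl100kBase)
	}
	// Join the extracted prompt segments and count tokens over the whole text.
	segments := []string{"You are a helpful assistant.", "How many tokens is this request?"}
	count, err := enc.Count(strings.Join(segments, "\n"))
	if err != nil {
		panic(err)
	}
	// Shape mirrors buildOpenAIUsageJSON: completion tokens are always zero here.
	fmt.Printf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`+"\n", count, count)
}
```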
Luis Pater
f6cf784cd1 refactor(translator): remove unused log dependency and comment out debug logging
docs: add GPT-5 Codex guidelines for CLI usage

- Added detailed guidelines for GPT-5 Codex in Codex CLI.
- Expanded instructions on sandboxing, approvals, editing constraints, and style requirements.
- Included presentation and response formatting best practices.

fix(codex_instructions): update comparison logic to use prefix matching

- Changed system instructions comparison to use `strings.HasPrefix` for improved flexibility.
2025-10-24 12:15:15 +08:00
Luis Pater
e783923464 feat(executor): add debug logs for rate-limiting retries in Gemini CLI executor 2025-10-23 10:39:21 +08:00
60 changed files with 1956 additions and 49 deletions

View File

@@ -556,12 +556,17 @@ The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` req
## Claude Code with multiple account load balancing
Start the CLI Proxy API server, and then set the `ANTHROPIC_BASE_URL`, `ANTHROPIC_AUTH_TOKEN`, `ANTHROPIC_MODEL`, `ANTHROPIC_SMALL_FAST_MODEL` environment variables.
Start the CLI Proxy API server, and then set the `ANTHROPIC_BASE_URL`, `ANTHROPIC_AUTH_TOKEN`, `ANTHROPIC_DEFAULT_OPUS_MODEL`, `ANTHROPIC_DEFAULT_SONNET_MODEL`, `ANTHROPIC_DEFAULT_HAIKU_MODEL` (or `ANTHROPIC_MODEL`, `ANTHROPIC_SMALL_FAST_MODEL` for version 1.x.x) environment variables.
Using Gemini models:
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=gemini-2.5-pro
export ANTHROPIC_DEFAULT_SONNET_MODEL=gemini-2.5-flash
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gemini-2.5-flash-lite
# version 1.x.x
export ANTHROPIC_MODEL=gemini-2.5-pro
export ANTHROPIC_SMALL_FAST_MODEL=gemini-2.5-flash
```
@@ -570,6 +575,11 @@ Using OpenAI GPT 5 models:
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=gpt-5-high
export ANTHROPIC_DEFAULT_SONNET_MODEL=gpt-5-medium
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gpt-5-minimal
# version 1.x.x
export ANTHROPIC_MODEL=gpt-5
export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-minimal
```
@@ -578,6 +588,11 @@ Using OpenAI GPT 5 Codex models:
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=gpt-5-codex-high
export ANTHROPIC_DEFAULT_SONNET_MODEL=gpt-5-codex-medium
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gpt-5-codex-low
# version 1.x.x
export ANTHROPIC_MODEL=gpt-5-codex
export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-codex-low
```
@@ -586,6 +601,11 @@ Using Claude models:
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=claude-opus-4-1-20250805
export ANTHROPIC_DEFAULT_SONNET_MODEL=claude-sonnet-4-5-20250929
export ANTHROPIC_DEFAULT_HAIKU_MODEL=claude-3-5-haiku-20241022
# version 1.x.x
export ANTHROPIC_MODEL=claude-sonnet-4-20250514
export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022
```
@@ -594,6 +614,11 @@ Using Qwen models:
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=qwen3-coder-plus
export ANTHROPIC_DEFAULT_SONNET_MODEL=qwen3-coder-plus
export ANTHROPIC_DEFAULT_HAIKU_MODEL=qwen3-coder-flash
# version 1.x.x
export ANTHROPIC_MODEL=qwen3-coder-plus
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash
```
@@ -602,6 +627,11 @@ Using iFlow models:
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=qwen3-max
export ANTHROPIC_DEFAULT_SONNET_MODEL=qwen3-coder-plus
export ANTHROPIC_DEFAULT_HAIKU_MODEL=qwen3-235b-a22b-instruct
# version 1.x.x
export ANTHROPIC_MODEL=qwen3-max
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-235b-a22b-instruct
```

View File

@@ -564,12 +564,17 @@ export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
## How to use Claude Code
Start the CLI Proxy API server, then set the following environment variables: `ANTHROPIC_BASE_URL`, `ANTHROPIC_AUTH_TOKEN`, `ANTHROPIC_MODEL`, `ANTHROPIC_SMALL_FAST_MODEL`
Start the CLI Proxy API server, then set the following environment variables: `ANTHROPIC_BASE_URL`, `ANTHROPIC_AUTH_TOKEN`, `ANTHROPIC_DEFAULT_OPUS_MODEL`, `ANTHROPIC_DEFAULT_SONNET_MODEL`, `ANTHROPIC_DEFAULT_HAIKU_MODEL` (or `ANTHROPIC_MODEL`, `ANTHROPIC_SMALL_FAST_MODEL` for version 1.x.x)
Using Gemini models:
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=gemini-2.5-pro
export ANTHROPIC_DEFAULT_SONNET_MODEL=gemini-2.5-flash
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gemini-2.5-flash-lite
# version 1.x.x
export ANTHROPIC_MODEL=gemini-2.5-pro
export ANTHROPIC_SMALL_FAST_MODEL=gemini-2.5-flash
```
@@ -578,6 +583,11 @@ export ANTHROPIC_SMALL_FAST_MODEL=gemini-2.5-flash
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=gpt-5-high
export ANTHROPIC_DEFAULT_SONNET_MODEL=gpt-5-medium
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gpt-5-minimal
# version 1.x.x
export ANTHROPIC_MODEL=gpt-5
export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-minimal
```
@@ -586,15 +596,24 @@ export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-minimal
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=gpt-5-codex-high
export ANTHROPIC_DEFAULT_SONNET_MODEL=gpt-5-codex-medium
export ANTHROPIC_DEFAULT_HAIKU_MODEL=gpt-5-codex-low
# version 1.x.x
export ANTHROPIC_MODEL=gpt-5-codex
export ANTHROPIC_SMALL_FAST_MODEL=gpt-5-codex-low
```
Using Claude models:
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=claude-opus-4-1-20250805
export ANTHROPIC_DEFAULT_SONNET_MODEL=claude-sonnet-4-5-20250929
export ANTHROPIC_DEFAULT_HAIKU_MODEL=claude-3-5-haiku-20241022
# version 1.x.x
export ANTHROPIC_MODEL=claude-sonnet-4-20250514
export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022
```
@@ -603,6 +622,11 @@ export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=qwen3-coder-plus
export ANTHROPIC_DEFAULT_SONNET_MODEL=qwen3-coder-plus
export ANTHROPIC_DEFAULT_HAIKU_MODEL=qwen3-coder-flash
# version 1.x.x
export ANTHROPIC_MODEL=qwen3-coder-plus
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash
```
@@ -611,6 +635,11 @@ export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash
```bash
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
export ANTHROPIC_AUTH_TOKEN=sk-dummy
# version 2.x.x
export ANTHROPIC_DEFAULT_OPUS_MODEL=qwen3-max
export ANTHROPIC_DEFAULT_SONNET_MODEL=qwen3-coder-plus
export ANTHROPIC_DEFAULT_HAIKU_MODEL=qwen3-235b-a22b-instruct
# version 1.x.x
export ANTHROPIC_MODEL=qwen3-max
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-235b-a22b-instruct
```

View File

@@ -43,6 +43,9 @@ quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
# API keys for official Generative Language API
#generative-language-api-key:
# - "AIzaSy...01"

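When `ws-auth` is enabled, the websocket client must present one of the API keys the access manager already accepts; per the access-provider change later in this diff, a key can also travel in an `auth_token` query parameter. A hypothetical client-side connect sketch (host, port, path, and key are illustrative):

```go
package main

import (
	"log"
	"net/url"

	"github.com/gorilla/websocket"
)

func main() {
	apiKey := "sk-dummy" // assumed to match a key configured for the proxy
	u := url.URL{
		Scheme:   "ws",
		Host:     "127.0.0.1:8317",
		Path:     "/v1/ws",
		RawQuery: "auth_token=" + url.QueryEscape(apiKey),
	}
	conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	if err != nil {
		log.Fatalf("ws dial failed: %v", err)
	}
	defer conn.Close()
	// With ws-auth: false the query parameter is simply not checked.
	log.Println("connected")
}
```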
go.mod (5 lines changed)
View File

@@ -7,14 +7,16 @@ require (
github.com/gin-gonic/gin v1.10.1
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.7.6
github.com/joho/godotenv v1.5.1
github.com/klauspost/compress v1.17.4
github.com/minio/minio-go/v7 v7.0.66
github.com/sirupsen/logrus v1.9.3
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/tiktoken-go/tokenizer v0.7.0
golang.org/x/crypto v0.43.0
golang.org/x/net v0.46.0
golang.org/x/oauth2 v0.30.0
@@ -32,6 +34,7 @@ require (
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/cyphar/filepath-securejoin v0.4.1 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect

go.sum (8 lines changed)
View File

@@ -23,6 +23,8 @@ github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGL
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
@@ -64,6 +66,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -78,8 +82,6 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
@@ -147,6 +149,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw=
github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=

View File

@@ -57,10 +57,12 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
authHeaderAnthropic := r.Header.Get("X-Api-Key")
queryKey := ""
queryAuthToken := ""
if r.URL != nil {
queryKey = r.URL.Query().Get("key")
queryAuthToken = r.URL.Query().Get("auth_token")
}
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" {
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
return nil, sdkaccess.ErrNoCredentials
}
@@ -74,6 +76,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
{authHeaderGoogle, "x-goog-api-key"},
{authHeaderAnthropic, "x-api-key"},
{queryKey, "query-key"},
{queryAuthToken, "query-auth-token"},
}
for _, candidate := range candidates {

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
@@ -63,13 +64,11 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
// It captures the URL, method, headers, and body. The request body is read and then
// restored so that it can be processed by subsequent handlers.
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
// Capture URL
url := c.Request.URL.String()
if c.Request.URL.Path != "" {
url = c.Request.URL.Path
if c.Request.URL.RawQuery != "" {
url += "?" + c.Request.URL.RawQuery
}
// Capture URL with sensitive query parameters masked
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
url := c.Request.URL.Path
if maskedQuery != "" {
url += "?" + maskedQuery
}
// Capture method

View File

@@ -13,6 +13,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
@@ -138,6 +139,12 @@ type Server struct {
// currentPath is the absolute path to the current working directory.
currentPath string
// wsRoutes tracks registered websocket upgrade paths.
wsRouteMu sync.Mutex
wsRoutes map[string]struct{}
wsAuthChanged func(bool, bool)
wsAuthEnabled atomic.Bool
// management handler
mgmt *managementHandlers.Handler
@@ -228,7 +235,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
configFilePath: configFilePath,
currentPath: wd,
envManagementSecret: envManagementSecret,
wsRoutes: make(map[string]struct{}),
}
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
// Save initial YAML snapshot
s.oldConfigYaml, _ = yaml.Marshal(cfg)
s.applyAccessConfig(nil, cfg)
@@ -371,6 +380,43 @@ func (s *Server) setupRoutes() {
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
}
// AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine.
// The handler is served as-is without additional middleware beyond the standard stack already configured.
func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
if s == nil || s.engine == nil || handler == nil {
return
}
trimmed := strings.TrimSpace(path)
if trimmed == "" {
trimmed = "/v1/ws"
}
if !strings.HasPrefix(trimmed, "/") {
trimmed = "/" + trimmed
}
s.wsRouteMu.Lock()
if _, exists := s.wsRoutes[trimmed]; exists {
s.wsRouteMu.Unlock()
return
}
s.wsRoutes[trimmed] = struct{}{}
s.wsRouteMu.Unlock()
authMiddleware := AuthMiddleware(s.accessManager)
conditionalAuth := func(c *gin.Context) {
if !s.wsAuthEnabled.Load() {
c.Next()
return
}
authMiddleware(c)
}
finalHandler := func(c *gin.Context) {
handler.ServeHTTP(c.Writer, c.Request)
c.Abort()
}
s.engine.GET(trimmed, conditionalAuth, finalHandler)
}
func (s *Server) registerManagementRoutes() {
if s == nil || s.engine == nil || s.mgmt == nil {
return
@@ -479,7 +525,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
c.AbortWithStatus(http.StatusNotFound)
return
}
filePath := managementasset.FilePath(s.currentPath)
filePath := managementasset.FilePath(s.configFilePath)
if strings.TrimSpace(filePath) == "" {
c.AbortWithStatus(http.StatusNotFound)
return
@@ -487,7 +533,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
if _, err := os.Stat(filePath); err != nil {
if os.IsNotExist(err) {
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.currentPath), cfg.ProxyURL)
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL)
c.AbortWithStatus(http.StatusNotFound)
return
}
@@ -770,13 +816,17 @@ func (s *Server) UpdateClients(cfg *config.Config) {
s.applyAccessConfig(oldCfg, cfg)
s.cfg = cfg
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
}
managementasset.SetCurrentConfig(cfg)
// Save YAML snapshot for next comparison
s.oldConfigYaml, _ = yaml.Marshal(cfg)
s.handlers.UpdateClients(&cfg.SDKConfig)
if !cfg.RemoteManagement.DisableControlPanel {
staticDir := managementasset.StaticDir(s.currentPath)
staticDir := managementasset.StaticDir(s.configFilePath)
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL)
}
if s.mgmt != nil {
@@ -810,6 +860,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
)
}
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
if s == nil {
return
}
s.wsAuthChanged = fn
}
// (management handlers moved to internal/api/handlers/management)
// AuthMiddleware returns a Gin middleware handler that authenticates requests

View File

@@ -40,6 +40,9 @@ type Config struct {
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
// WebsocketAuth enables or disables authentication for the WebSocket API.
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
// GlAPIKey is the API key for the generative language API.
GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"`

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
@@ -23,7 +24,7 @@ func GinLogrusLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
c.Next()

View File

@@ -20,7 +20,7 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
// lastReviewPrompt := ""
for _, entry := range entries {
content, _ := codexInstructionsDir.ReadFile("codex_instructions/" + entry.Name())
if systemInstructions == string(content) {
if strings.HasPrefix(systemInstructions, string(content)) {
return true, ""
}
if strings.HasPrefix(entry.Name(), "gpt_5_codex_prompt.md") {

View File

@@ -385,6 +385,19 @@ func GetQwenModels() []*ModelInfo {
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "vision-model",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "qwen",
Type: "qwen",
Version: "3.0",
DisplayName: "Qwen3 Vision Model",
Description: "Qwen3 vision model",
ContextLength: 32768,
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
}
}

View File

@@ -0,0 +1,396 @@
package executor
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// AistudioExecutor routes AI Studio requests through a websocket-backed transport.
type AistudioExecutor struct {
provider string
relay *wsrelay.Manager
cfg *config.Config
}
// NewAistudioExecutor constructs a websocket executor for the provider name.
func NewAistudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AistudioExecutor {
return &AistudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
}
// Identifier returns the provider key served by this executor.
func (e *AistudioExecutor) Identifier() string { return e.provider }
// PrepareRequest is a no-op because websocket transport already injects headers.
func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
return nil
}
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, false)
if err != nil {
return resp, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload,
}
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
wsResp, err := e.relay.NonStream(ctx, e.provider, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
if len(wsResp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
}
if wsResp.Status < 200 || wsResp.Status >= 300 {
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
}
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
return resp, nil
}
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, true)
if err != nil {
return nil, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload,
}
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
wsStream, err := e.relay.Stream(ctx, e.provider, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
var param any
metadataLogged := false
for event := range wsStream {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return
}
switch event.Type {
case wsrelay.MessageTypeStreamStart:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
filtered := filterAistudioUsageMetadata(event.Payload)
if detail, ok := parseGeminiStreamUsage(filtered); ok {
reporter.publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
}
break
}
case wsrelay.MessageTypeStreamEnd:
return
case wsrelay.MessageTypeHTTPResp:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
}
reporter.publish(ctx, parseGeminiUsage(event.Payload))
return
case wsrelay.MessageTypeError:
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return
}
}
}()
return stream, nil
}
func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_, body, err := e.translateRequest(req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
}
body.payload, _ = sjson.DeleteBytes(body.payload, "generationConfig")
body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
endpoint := e.buildEndpoint(req.Model, "countTokens", "")
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload,
}
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
resp, err := e.relay.NonStream(ctx, e.provider, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
if len(resp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
}
if resp.Status < 200 || resp.Status >= 300 {
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
}
totalTokens := gjson.GetBytes(resp.Body, "totalTokens").Int()
if totalTokens <= 0 {
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
}
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
_ = ctx
return auth, nil
}
type translatedPayload struct {
payload []byte
action string
toFormat sdktranslator.Format
}
func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok {
payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
}
payload = disableGeminiThinkingConfig(payload, req.Model)
payload = fixGeminiImageAspectRatio(req.Model, payload)
metadataAction := "generateContent"
if req.Metadata != nil {
if action, _ := req.Metadata["action"].(string); action == "countTokens" {
metadataAction = action
}
}
action := metadataAction
if stream && action != "countTokens" {
action = "streamGenerateContent"
}
payload, _ = sjson.DeleteBytes(payload, "session_id")
return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil
}
func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string {
base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action)
if action == "streamGenerateContent" {
if alt == "" {
return base + "?alt=sse"
}
return base + "?$alt=" + url.QueryEscape(alt)
}
if alt != "" && action != "countTokens" {
return base + "?$alt=" + url.QueryEscape(alt)
}
return base
}
// filterAistudioUsageMetadata removes usageMetadata from intermediate SSE events so that
// only the terminal chunk retains token statistics.
func filterAistudioUsageMetadata(payload []byte) []byte {
if len(payload) == 0 {
return payload
}
lines := bytes.Split(payload, []byte("\n"))
modified := false
for idx, line := range lines {
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
dataIdx := bytes.Index(line, []byte("data:"))
if dataIdx < 0 {
continue
}
rawJSON := bytes.TrimSpace(line[dataIdx+5:])
cleaned, changed := stripUsageMetadataFromJSON(rawJSON)
if !changed {
continue
}
var rebuilt []byte
rebuilt = append(rebuilt, line[:dataIdx]...)
rebuilt = append(rebuilt, []byte("data:")...)
if len(cleaned) > 0 {
rebuilt = append(rebuilt, ' ')
rebuilt = append(rebuilt, cleaned...)
}
lines[idx] = rebuilt
modified = true
}
if !modified {
return payload
}
return bytes.Join(lines, []byte("\n"))
}
// stripUsageMetadataFromJSON drops usageMetadata when no finishReason is present.
func stripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
jsonBytes := bytes.TrimSpace(rawJSON)
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
return rawJSON, false
}
finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
if finishReason.Exists() && finishReason.String() != "" {
return rawJSON, false
}
if !gjson.GetBytes(jsonBytes, "usageMetadata").Exists() {
return rawJSON, false
}
cleaned, err := sjson.DeleteBytes(jsonBytes, "usageMetadata")
if err != nil {
return rawJSON, false
}
return cleaned, true
}
// ensureColonSpacedJSON normalizes JSON objects so that colons are followed by a single space while
// keeping the payload otherwise compact. Non-JSON inputs are returned unchanged.
func ensureColonSpacedJSON(payload []byte) []byte {
trimmed := bytes.TrimSpace(payload)
if len(trimmed) == 0 {
return payload
}
var decoded any
if err := json.Unmarshal(trimmed, &decoded); err != nil {
return payload
}
indented, err := json.MarshalIndent(decoded, "", " ")
if err != nil {
return payload
}
compacted := make([]byte, 0, len(indented))
inString := false
skipSpace := false
for i := 0; i < len(indented); i++ {
ch := indented[i]
if ch == '"' && (i == 0 || indented[i-1] != '\\') {
inString = !inString
}
if !inString {
if ch == '\n' || ch == '\r' {
skipSpace = true
continue
}
if skipSpace {
if ch == ' ' || ch == '\t' {
continue
}
skipSpace = false
}
}
compacted = append(compacted, ch)
}
return compacted
}

View File

@@ -20,6 +20,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tiktoken-go/tokenizer"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -277,7 +278,180 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented")
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
modelForCounting := req.Model
if util.InArray([]string{"gpt-5", "gpt-5-minimal", "gpt-5-low", "gpt-5-medium", "gpt-5-high"}, req.Model) {
modelForCounting = "gpt-5"
body, _ = sjson.SetBytes(body, "model", "gpt-5")
switch req.Model {
case "gpt-5-minimal":
body, _ = sjson.SetBytes(body, "reasoning.effort", "minimal")
case "gpt-5-low":
body, _ = sjson.SetBytes(body, "reasoning.effort", "low")
case "gpt-5-medium":
body, _ = sjson.SetBytes(body, "reasoning.effort", "medium")
case "gpt-5-high":
body, _ = sjson.SetBytes(body, "reasoning.effort", "high")
default:
body, _ = sjson.SetBytes(body, "reasoning.effort", "low")
}
} else if util.InArray([]string{"gpt-5-codex", "gpt-5-codex-low", "gpt-5-codex-medium", "gpt-5-codex-high"}, req.Model) {
modelForCounting = "gpt-5"
body, _ = sjson.SetBytes(body, "model", "gpt-5-codex")
switch req.Model {
case "gpt-5-codex-low":
body, _ = sjson.SetBytes(body, "reasoning.effort", "low")
case "gpt-5-codex-medium":
body, _ = sjson.SetBytes(body, "reasoning.effort", "medium")
case "gpt-5-codex-high":
body, _ = sjson.SetBytes(body, "reasoning.effort", "high")
default:
body, _ = sjson.SetBytes(body, "reasoning.effort", "low")
}
}
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "stream", false)
enc, err := tokenizerForCodexModel(modelForCounting)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
}
count, err := countCodexInputTokens(enc, body)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: token counting failed: %w", err)
}
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
switch {
case sanitized == "":
return tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5"):
return tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"):
return tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"):
return tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"):
return tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
return tokenizer.ForModel(tokenizer.GPT35Turbo)
default:
return tokenizer.Get(tokenizer.Cl100kBase)
}
}
func countCodexInputTokens(enc tokenizer.Codec, body []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
if len(body) == 0 {
return 0, nil
}
root := gjson.ParseBytes(body)
var segments []string
if inst := strings.TrimSpace(root.Get("instructions").String()); inst != "" {
segments = append(segments, inst)
}
inputItems := root.Get("input")
if inputItems.IsArray() {
arr := inputItems.Array()
for i := range arr {
item := arr[i]
switch item.Get("type").String() {
case "message":
content := item.Get("content")
if content.IsArray() {
parts := content.Array()
for j := range parts {
part := parts[j]
if text := strings.TrimSpace(part.Get("text").String()); text != "" {
segments = append(segments, text)
}
}
}
case "function_call":
if name := strings.TrimSpace(item.Get("name").String()); name != "" {
segments = append(segments, name)
}
if args := strings.TrimSpace(item.Get("arguments").String()); args != "" {
segments = append(segments, args)
}
case "function_call_output":
if out := strings.TrimSpace(item.Get("output").String()); out != "" {
segments = append(segments, out)
}
default:
if text := strings.TrimSpace(item.Get("text").String()); text != "" {
segments = append(segments, text)
}
}
}
}
tools := root.Get("tools")
if tools.IsArray() {
tarr := tools.Array()
for i := range tarr {
tool := tarr[i]
if name := strings.TrimSpace(tool.Get("name").String()); name != "" {
segments = append(segments, name)
}
if desc := strings.TrimSpace(tool.Get("description").String()); desc != "" {
segments = append(segments, desc)
}
if params := tool.Get("parameters"); params.Exists() {
val := params.Raw
if params.Type == gjson.String {
val = params.String()
}
if trimmed := strings.TrimSpace(val); trimmed != "" {
segments = append(segments, trimmed)
}
}
}
}
textFormat := root.Get("text.format")
if textFormat.Exists() {
if name := strings.TrimSpace(textFormat.Get("name").String()); name != "" {
segments = append(segments, name)
}
if schema := textFormat.Get("schema"); schema.Exists() {
val := schema.Raw
if schema.Type == gjson.String {
val = schema.String()
}
if trimmed := strings.TrimSpace(val); trimmed != "" {
segments = append(segments, trimmed)
}
}
}
text := strings.Join(segments, "\n")
if text == "" {
return 0, nil
}
count, err := enc.Count(text)
if err != nil {
return 0, err
}
return int64(count), nil
}
func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {

View File

@@ -166,6 +166,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
lastBody = append([]byte(nil), data...)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, string(data))
if httpResp.StatusCode == 429 {
log.Debugf("gemini cli executor: rate limited, retrying with next model")
continue
}
@@ -281,6 +282,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
lastBody = append([]byte(nil), data...)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, string(data))
if httpResp.StatusCode == 429 {
log.Debugf("gemini cli executor: rate limited, retrying with next model")
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
@@ -451,6 +453,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
lastStatus = resp.StatusCode
lastBody = append([]byte(nil), data...)
if resp.StatusCode == 429 {
log.Debugf("gemini cli executor: rate limited, retrying with next model")
continue
}
break
@@ -700,7 +703,7 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
}
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["Image", "Text"]`))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
}
}
rawJSON, _ = sjson.DeleteBytes(rawJSON, "request.generationConfig.imageConfig")

View File

@@ -494,7 +494,7 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
}
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["Image", "Text"]`))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
}
}
rawJSON, _ = sjson.DeleteBytes(rawJSON, "generationConfig.imageConfig")

View File

@@ -221,9 +221,24 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return stream, nil
}
// CountTokens is not implemented for iFlow.
func (e *IFlowExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{Payload: nil}, fmt.Errorf("not implemented")
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
enc, err := tokenizerForModel(req.Model)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, body)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
// Refresh refreshes OAuth tokens and updates the stored API key.

View File

@@ -219,7 +219,29 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
}
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented")
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
modelForCounting := req.Model
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
modelForCounting = modelOverride
}
enc, err := tokenizerForModel(modelForCounting)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, translated)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translatedUsage)}, nil
}
// Refresh is a no-op for API-key based compatibility providers.

View File

@@ -207,7 +207,28 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented")
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
modelName := gjson.GetBytes(body, "model").String()
if strings.TrimSpace(modelName) == "" {
modelName = req.Model
}
enc, err := tokenizerForModel(modelName)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
}
count, err := countOpenAIChatTokens(enc, body)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
}
usageJSON := buildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {

View File

@@ -0,0 +1,234 @@
package executor
import (
"fmt"
"strings"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
func tokenizerForModel(model string) (tokenizer.Codec, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
switch {
case sanitized == "":
return tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5"):
return tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"):
return tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"):
return tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"):
return tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
return tokenizer.ForModel(tokenizer.GPT35Turbo)
case strings.HasPrefix(sanitized, "o1"):
return tokenizer.ForModel(tokenizer.O1)
case strings.HasPrefix(sanitized, "o3"):
return tokenizer.ForModel(tokenizer.O3)
case strings.HasPrefix(sanitized, "o4"):
return tokenizer.ForModel(tokenizer.O4Mini)
default:
return tokenizer.Get(tokenizer.O200kBase)
}
}
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
if len(payload) == 0 {
return 0, nil
}
root := gjson.ParseBytes(payload)
segments := make([]string, 0, 32)
collectOpenAIMessages(root.Get("messages"), &segments)
collectOpenAITools(root.Get("tools"), &segments)
collectOpenAIFunctions(root.Get("functions"), &segments)
collectOpenAIToolChoice(root.Get("tool_choice"), &segments)
collectOpenAIResponseFormat(root.Get("response_format"), &segments)
addIfNotEmpty(&segments, root.Get("input").String())
addIfNotEmpty(&segments, root.Get("prompt").String())
joined := strings.TrimSpace(strings.Join(segments, "\n"))
if joined == "" {
return 0, nil
}
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
return int64(count), nil
}
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
func buildOpenAIUsageJSON(count int64) []byte {
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
}
func collectOpenAIMessages(messages gjson.Result, segments *[]string) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
addIfNotEmpty(segments, message.Get("name").String())
collectOpenAIContent(message.Get("content"), segments)
collectOpenAIToolCalls(message.Get("tool_calls"), segments)
collectOpenAIFunctionCall(message.Get("function_call"), segments)
return true
})
}
func collectOpenAIContent(content gjson.Result, segments *[]string) {
if !content.Exists() {
return
}
if content.Type == gjson.String {
addIfNotEmpty(segments, content.String())
return
}
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text", "input_text", "output_text":
addIfNotEmpty(segments, part.Get("text").String())
case "image_url":
addIfNotEmpty(segments, part.Get("image_url.url").String())
case "input_audio", "output_audio", "audio":
addIfNotEmpty(segments, part.Get("id").String())
case "tool_result":
addIfNotEmpty(segments, part.Get("name").String())
collectOpenAIContent(part.Get("content"), segments)
default:
if part.IsArray() {
collectOpenAIContent(part, segments)
return true
}
if part.Type == gjson.JSON {
addIfNotEmpty(segments, part.Raw)
return true
}
addIfNotEmpty(segments, part.String())
}
return true
})
return
}
if content.Type == gjson.JSON {
addIfNotEmpty(segments, content.Raw)
}
}
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
if !calls.Exists() || !calls.IsArray() {
return
}
calls.ForEach(func(_, call gjson.Result) bool {
addIfNotEmpty(segments, call.Get("id").String())
addIfNotEmpty(segments, call.Get("type").String())
function := call.Get("function")
if function.Exists() {
addIfNotEmpty(segments, function.Get("name").String())
addIfNotEmpty(segments, function.Get("description").String())
addIfNotEmpty(segments, function.Get("arguments").String())
if params := function.Get("parameters"); params.Exists() {
addIfNotEmpty(segments, params.Raw)
}
}
return true
})
}
func collectOpenAIFunctionCall(call gjson.Result, segments *[]string) {
if !call.Exists() {
return
}
addIfNotEmpty(segments, call.Get("name").String())
addIfNotEmpty(segments, call.Get("arguments").String())
}
func collectOpenAITools(tools gjson.Result, segments *[]string) {
if !tools.Exists() {
return
}
if tools.IsArray() {
tools.ForEach(func(_, tool gjson.Result) bool {
appendToolPayload(tool, segments)
return true
})
return
}
appendToolPayload(tools, segments)
}
func collectOpenAIFunctions(functions gjson.Result, segments *[]string) {
if !functions.Exists() || !functions.IsArray() {
return
}
functions.ForEach(func(_, function gjson.Result) bool {
addIfNotEmpty(segments, function.Get("name").String())
addIfNotEmpty(segments, function.Get("description").String())
if params := function.Get("parameters"); params.Exists() {
addIfNotEmpty(segments, params.Raw)
}
return true
})
}
func collectOpenAIToolChoice(choice gjson.Result, segments *[]string) {
if !choice.Exists() {
return
}
if choice.Type == gjson.String {
addIfNotEmpty(segments, choice.String())
return
}
addIfNotEmpty(segments, choice.Raw)
}
func collectOpenAIResponseFormat(format gjson.Result, segments *[]string) {
if !format.Exists() {
return
}
addIfNotEmpty(segments, format.Get("type").String())
addIfNotEmpty(segments, format.Get("name").String())
if schema := format.Get("json_schema"); schema.Exists() {
addIfNotEmpty(segments, schema.Raw)
}
if schema := format.Get("schema"); schema.Exists() {
addIfNotEmpty(segments, schema.Raw)
}
}
func appendToolPayload(tool gjson.Result, segments *[]string) {
if !tool.Exists() {
return
}
addIfNotEmpty(segments, tool.Get("type").String())
addIfNotEmpty(segments, tool.Get("name").String())
addIfNotEmpty(segments, tool.Get("description").String())
if function := tool.Get("function"); function.Exists() {
addIfNotEmpty(segments, function.Get("name").String())
addIfNotEmpty(segments, function.Get("description").String())
if params := function.Get("parameters"); params.Exists() {
addIfNotEmpty(segments, params.Raw)
}
}
}
func addIfNotEmpty(segments *[]string, value string) {
if segments == nil {
return
}
if trimmed := strings.TrimSpace(value); trimmed != "" {
*segments = append(*segments, trimmed)
}
}

View File

@@ -354,3 +354,7 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
}
return rev
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
}

View File

@@ -12,8 +12,9 @@ func init() {
Codex,
ConvertClaudeRequestToCodex,
interfaces.TranslateResponse{
Stream: ConvertCodexResponseToClaude,
NonStream: ConvertCodexResponseToClaudeNonStream,
Stream: ConvertCodexResponseToClaude,
NonStream: ConvertCodexResponseToClaudeNonStream,
TokenCount: ClaudeTokenCount,
},
)
}

View File

@@ -6,6 +6,7 @@ package geminiCLI
import (
"context"
"fmt"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
"github.com/tidwall/sjson"
@@ -54,3 +55,7 @@ func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName str
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
}
func GeminiCLITokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}

View File

@@ -12,8 +12,9 @@ func init() {
Codex,
ConvertGeminiCLIRequestToCodex,
interfaces.TranslateResponse{
Stream: ConvertCodexResponseToGeminiCLI,
NonStream: ConvertCodexResponseToGeminiCLINonStream,
Stream: ConvertCodexResponseToGeminiCLI,
NonStream: ConvertCodexResponseToGeminiCLINonStream,
TokenCount: GeminiCLITokenCount,
},
)
}

View File

@@ -8,6 +8,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"time"
"github.com/tidwall/gjson"
@@ -330,3 +331,7 @@ func mustMarshalJSON(v interface{}) string {
}
return string(data)
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}

View File

@@ -12,8 +12,9 @@ func init() {
Codex,
ConvertGeminiRequestToCodex,
interfaces.TranslateResponse{
Stream: ConvertCodexResponseToGemini,
NonStream: ConvertCodexResponseToGeminiNonStream,
Stream: ConvertCodexResponseToGemini,
NonStream: ConvertCodexResponseToGeminiNonStream,
TokenCount: GeminiTokenCount,
},
)
}

View File

@@ -6,7 +6,6 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -74,7 +73,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
if hasOfficialInstructions {
return rawJSON
}
log.Debugf("instructions not matched, %s\n", originalInstructions)
// log.Debugf("instructions not matched, %s\n", originalInstructions)
if len(inputResults) > 0 {
newInput := "[]"

View File

@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
}
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["Image", "Text"]
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
var responseMods []string
for _, m := range mods.Array() {
switch strings.ToLower(m.String()) {
case "text":
responseMods = append(responseMods, "Text")
responseMods = append(responseMods, "TEXT")
case "image":
responseMods = append(responseMods, "Image")
responseMods = append(responseMods, "IMAGE")
}
}
if len(responseMods) > 0 {

View File

@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
// Map OpenAI modalities -> Gemini generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["Image", "Text"]
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
var responseMods []string
for _, m := range mods.Array() {
switch strings.ToLower(m.String()) {
case "text":
responseMods = append(responseMods, "Text")
responseMods = append(responseMods, "TEXT")
case "image":
responseMods = append(responseMods, "Image")
responseMods = append(responseMods, "IMAGE")
}
}
if len(responseMods) > 0 {
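Both hunks above apply the same case change; a hedged before/after sketch, with any field not shown in the hunks treated as an assumption:
// Incoming OpenAI-style request fragment:
//   {"modalities": ["image", "text"], ...}
// Resulting generationConfig after the translator runs (other fields omitted):
//   {"responseModalities": ["IMAGE", "TEXT"]}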

View File

@@ -12,8 +12,9 @@ func init() {
OpenAI,
ConvertClaudeRequestToOpenAI,
interfaces.TranslateResponse{
Stream: ConvertOpenAIResponseToClaude,
NonStream: ConvertOpenAIResponseToClaudeNonStream,
Stream: ConvertOpenAIResponseToClaude,
NonStream: ConvertOpenAIResponseToClaudeNonStream,
TokenCount: ClaudeTokenCount,
},
)
}

View File

@@ -9,6 +9,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -630,3 +631,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
}
return string(responseJSON)
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
}

View File

@@ -12,8 +12,9 @@ func init() {
OpenAI,
ConvertGeminiCLIRequestToOpenAI,
interfaces.TranslateResponse{
Stream: ConvertOpenAIResponseToGeminiCLI,
NonStream: ConvertOpenAIResponseToGeminiCLINonStream,
Stream: ConvertOpenAIResponseToGeminiCLI,
NonStream: ConvertOpenAIResponseToGeminiCLINonStream,
TokenCount: GeminiCLITokenCount,
},
)
}

View File

@@ -7,6 +7,7 @@ package geminiCLI
import (
"context"
"fmt"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini"
"github.com/tidwall/sjson"
@@ -51,3 +52,7 @@ func ConvertOpenAIResponseToGeminiCLINonStream(ctx context.Context, modelName st
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
}
func GeminiCLITokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}

View File

@@ -12,8 +12,9 @@ func init() {
OpenAI,
ConvertGeminiRequestToOpenAI,
interfaces.TranslateResponse{
Stream: ConvertOpenAIResponseToGemini,
NonStream: ConvertOpenAIResponseToGeminiNonStream,
Stream: ConvertOpenAIResponseToGemini,
NonStream: ConvertOpenAIResponseToGeminiNonStream,
TokenCount: GeminiTokenCount,
},
)
}

View File

@@ -9,6 +9,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
@@ -609,3 +610,7 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
return out
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}

View File

@@ -4,6 +4,7 @@
package util
import (
"net/url"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -188,3 +189,56 @@ func MaskSensitiveHeaderValue(key, value string) string {
return value
}
}
// MaskSensitiveQuery masks sensitive query parameters, e.g. auth_token, within the raw query string.
func MaskSensitiveQuery(raw string) string {
if raw == "" {
return ""
}
parts := strings.Split(raw, "&")
changed := false
for i, part := range parts {
if part == "" {
continue
}
keyPart := part
valuePart := ""
if idx := strings.Index(part, "="); idx >= 0 {
keyPart = part[:idx]
valuePart = part[idx+1:]
}
decodedKey, err := url.QueryUnescape(keyPart)
if err != nil {
decodedKey = keyPart
}
if !shouldMaskQueryParam(decodedKey) {
continue
}
decodedValue, err := url.QueryUnescape(valuePart)
if err != nil {
decodedValue = valuePart
}
masked := HideAPIKey(strings.TrimSpace(decodedValue))
parts[i] = keyPart + "=" + url.QueryEscape(masked)
changed = true
}
if !changed {
return raw
}
return strings.Join(parts, "&")
}
func shouldMaskQueryParam(key string) bool {
key = strings.ToLower(strings.TrimSpace(key))
if key == "" {
return false
}
key = strings.TrimSuffix(key, "[]")
if key == "key" || strings.Contains(key, "api-key") || strings.Contains(key, "apikey") || strings.Contains(key, "api_key") {
return true
}
if strings.Contains(key, "token") || strings.Contains(key, "secret") {
return true
}
return false
}
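A short sketch of the masking behaviour, assuming HideAPIKey (defined elsewhere in util) abbreviates the credential value; the query string is illustrative:
raw := "alt=sse&key=AIzaSyExampleExampleExampleKey&foo=bar"
masked := MaskSensitiveQuery(raw)
// "key" trips shouldMaskQueryParam, so its value is replaced with the
// HideAPIKey form; "alt" and "foo" are left untouched.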

View File

@@ -1204,6 +1204,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.ProxyURL != newCfg.ProxyURL {
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
}
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
}
// Quota-exceeded behavior
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {

233
internal/wsrelay/http.go Normal file
View File

@@ -0,0 +1,233 @@
package wsrelay
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
)
// HTTPRequest represents a proxied HTTP request delivered to websocket clients.
type HTTPRequest struct {
Method string
URL string
Headers http.Header
Body []byte
}
// HTTPResponse captures the response relayed back from websocket clients.
type HTTPResponse struct {
Status int
Headers http.Header
Body []byte
}
// StreamEvent represents a streaming response event from clients.
type StreamEvent struct {
Type string
Payload []byte
Status int
Headers http.Header
Err error
}
// NonStream executes a non-streaming HTTP request using the websocket provider.
func (m *Manager) NonStream(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
if req == nil {
return nil, fmt.Errorf("wsrelay: request is nil")
}
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
respCh, err := m.Send(ctx, provider, msg)
if err != nil {
return nil, err
}
var (
streamMode bool
streamResp *HTTPResponse
streamBody bytes.Buffer
)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case msg, ok := <-respCh:
if !ok {
if streamMode {
if streamResp == nil {
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
} else if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
return streamResp, nil
}
return nil, errors.New("wsrelay: connection closed during response")
}
switch msg.Type {
case MessageTypeHTTPResp:
resp := decodeResponse(msg.Payload)
if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 {
resp.Body = append(resp.Body[:0], streamBody.Bytes()...)
}
return resp, nil
case MessageTypeError:
return nil, decodeError(msg.Payload)
case MessageTypeStreamStart, MessageTypeStreamChunk:
if msg.Type == MessageTypeStreamStart {
streamMode = true
streamResp = decodeResponse(msg.Payload)
if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamBody.Reset()
continue
}
if !streamMode {
streamMode = true
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
}
chunk := decodeChunk(msg.Payload)
if len(chunk) > 0 {
streamBody.Write(chunk)
}
case MessageTypeStreamEnd:
if !streamMode {
return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil
}
if streamResp == nil {
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
} else if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
return streamResp, nil
default:
}
}
}
}
// Stream executes a streaming HTTP request and returns a channel of stream events.
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) {
if req == nil {
return nil, fmt.Errorf("wsrelay: request is nil")
}
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
respCh, err := m.Send(ctx, provider, msg)
if err != nil {
return nil, err
}
out := make(chan StreamEvent)
go func() {
defer close(out)
for {
select {
case <-ctx.Done():
out <- StreamEvent{Err: ctx.Err()}
return
case msg, ok := <-respCh:
if !ok {
out <- StreamEvent{Err: errors.New("wsrelay: stream closed")}
return
}
switch msg.Type {
case MessageTypeStreamStart:
resp := decodeResponse(msg.Payload)
out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}
case MessageTypeStreamChunk:
chunk := decodeChunk(msg.Payload)
out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}
case MessageTypeStreamEnd:
out <- StreamEvent{Type: MessageTypeStreamEnd}
return
case MessageTypeError:
out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)}
return
case MessageTypeHTTPResp:
resp := decodeResponse(msg.Payload)
out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body}
return
default:
}
}
}
}()
return out, nil
}
func encodeRequest(req *HTTPRequest) map[string]any {
headers := make(map[string]any, len(req.Headers))
for key, values := range req.Headers {
copyValues := make([]string, len(values))
copy(copyValues, values)
headers[key] = copyValues
}
return map[string]any{
"method": req.Method,
"url": req.URL,
"headers": headers,
"body": string(req.Body),
"sent_at": time.Now().UTC().Format(time.RFC3339Nano),
}
}
func decodeResponse(payload map[string]any) *HTTPResponse {
if payload == nil {
return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)}
}
resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
if status, ok := payload["status"].(float64); ok {
resp.Status = int(status)
}
if headers, ok := payload["headers"].(map[string]any); ok {
for key, raw := range headers {
switch v := raw.(type) {
case []any:
for _, item := range v {
if str, ok := item.(string); ok {
resp.Headers.Add(key, str)
}
}
case []string:
for _, str := range v {
resp.Headers.Add(key, str)
}
case string:
resp.Headers.Set(key, v)
}
}
}
if body, ok := payload["body"].(string); ok {
resp.Body = []byte(body)
}
return resp
}
func decodeChunk(payload map[string]any) []byte {
if payload == nil {
return nil
}
if data, ok := payload["data"].(string); ok {
return []byte(data)
}
return nil
}
func decodeError(payload map[string]any) error {
if payload == nil {
return errors.New("wsrelay: unknown error")
}
message, _ := payload["error"].(string)
status := 0
if v, ok := payload["status"].(float64); ok {
status = int(v)
}
if message == "" {
message = "wsrelay: upstream error"
}
return fmt.Errorf("%s (status=%d)", message, status)
}
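A hedged sketch of how an executor might call into this relay; the provider name, URL, and payload are placeholders, and the required imports (context, net/http, wsrelay) are assumed:
func relayOnce(ctx context.Context, mgr *wsrelay.Manager) ([]byte, error) {
	req := &wsrelay.HTTPRequest{
		Method:  http.MethodPost,
		URL:     "https://example.invalid/v1beta/models/gemini-2.5-pro:generateContent",
		Headers: http.Header{"Content-Type": []string{"application/json"}},
		Body:    []byte(`{"contents":[{"parts":[{"text":"hi"}]}]}`),
	}
	resp, err := mgr.NonStream(ctx, "aistudio-example", req)
	if err != nil {
		return nil, err
	}
	// NonStream also flattens stream_start/stream_chunk/stream_end replies
	// into a single HTTPResponse, so resp.Body always holds the full payload.
	return resp.Body, nil
}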

205
internal/wsrelay/manager.go Normal file
View File

@@ -0,0 +1,205 @@
package wsrelay
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Manager exposes a websocket endpoint that proxies Gemini requests to
// connected clients.
type Manager struct {
path string
upgrader websocket.Upgrader
sessions map[string]*session
sessMutex sync.RWMutex
providerFactory func(*http.Request) (string, error)
onConnected func(string)
onDisconnected func(string, error)
logDebugf func(string, ...any)
logInfof func(string, ...any)
logWarnf func(string, ...any)
}
// Options configures a Manager instance.
type Options struct {
Path string
ProviderFactory func(*http.Request) (string, error)
OnConnected func(string)
OnDisconnected func(string, error)
LogDebugf func(string, ...any)
LogInfof func(string, ...any)
LogWarnf func(string, ...any)
}
// NewManager builds a websocket relay manager with the supplied options.
func NewManager(opts Options) *Manager {
path := strings.TrimSpace(opts.Path)
if path == "" {
path = "/v1/ws"
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
mgr := &Manager{
path: path,
sessions: make(map[string]*session),
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true
},
},
providerFactory: opts.ProviderFactory,
onConnected: opts.OnConnected,
onDisconnected: opts.OnDisconnected,
logDebugf: opts.LogDebugf,
logInfof: opts.LogInfof,
logWarnf: opts.LogWarnf,
}
if mgr.logDebugf == nil {
mgr.logDebugf = func(string, ...any) {}
}
if mgr.logInfof == nil {
mgr.logInfof = func(string, ...any) {}
}
if mgr.logWarnf == nil {
mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) }
}
return mgr
}
// Path returns the HTTP path the manager expects for websocket upgrades.
func (m *Manager) Path() string {
if m == nil {
return "/v1/ws"
}
return m.path
}
// Handler exposes an http.Handler that upgrades connections to websocket sessions.
func (m *Manager) Handler() http.Handler {
return http.HandlerFunc(m.handleWebsocket)
}
// Stop gracefully closes all active websocket sessions.
func (m *Manager) Stop(_ context.Context) error {
m.sessMutex.Lock()
sessions := make([]*session, 0, len(m.sessions))
for _, sess := range m.sessions {
sessions = append(sessions, sess)
}
m.sessions = make(map[string]*session)
m.sessMutex.Unlock()
for _, sess := range sessions {
if sess != nil {
sess.cleanup(errors.New("wsrelay: manager stopped"))
}
}
return nil
}
// handleWebsocket upgrades the connection and wires the session into the pool.
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
expectedPath := m.Path()
if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
http.NotFound(w, r)
return
}
if !strings.EqualFold(r.Method, http.MethodGet) {
w.Header().Set("Allow", http.MethodGet)
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
conn, err := m.upgrader.Upgrade(w, r, nil)
if err != nil {
m.logWarnf("wsrelay: upgrade failed: %v", err)
return
}
s := newSession(conn, m, randomProviderName())
if m.providerFactory != nil {
name, err := m.providerFactory(r)
if err != nil {
s.cleanup(err)
return
}
if strings.TrimSpace(name) != "" {
s.provider = strings.ToLower(name)
}
}
if s.provider == "" {
s.provider = strings.ToLower(s.id)
}
m.sessMutex.Lock()
var replaced *session
if existing, ok := m.sessions[s.provider]; ok {
replaced = existing
}
m.sessions[s.provider] = s
m.sessMutex.Unlock()
if replaced != nil {
replaced.cleanup(errors.New("replaced by new connection"))
}
if m.onConnected != nil {
m.onConnected(s.provider)
}
go s.run(context.Background())
}
// Send forwards the message to the specific provider connection and returns a channel
// yielding response messages.
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
s := m.session(provider)
if s == nil {
return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
}
return s.request(ctx, msg)
}
func (m *Manager) session(provider string) *session {
key := strings.ToLower(strings.TrimSpace(provider))
m.sessMutex.RLock()
s := m.sessions[key]
m.sessMutex.RUnlock()
return s
}
func (m *Manager) handleSessionClosed(s *session, cause error) {
if s == nil {
return
}
key := strings.ToLower(strings.TrimSpace(s.provider))
m.sessMutex.Lock()
if cur, ok := m.sessions[key]; ok && cur == s {
delete(m.sessions, key)
}
m.sessMutex.Unlock()
if m.onDisconnected != nil {
m.onDisconnected(s.provider, cause)
}
}
func randomProviderName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
buf := make([]byte, 16)
if _, err := rand.Read(buf); err != nil {
return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
}
for i := range buf {
buf[i] = alphabet[int(buf[i])%len(alphabet)]
}
return "aistudio-" + string(buf)
}
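A sketch of wiring the manager into the HTTP layer; the server variable and logrus hooks are assumptions, though AttachWebsocketRoute is used the same way later in this change:
mgr := wsrelay.NewManager(wsrelay.Options{
	Path:        "/v1/ws",
	OnConnected: func(provider string) { log.Infof("ws provider connected: %s", provider) },
	LogWarnf:    log.Warnf,
})
srv.AttachWebsocketRoute(mgr.Path(), mgr.Handler())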

View File

@@ -0,0 +1,27 @@
package wsrelay
// Message represents the JSON payload exchanged with websocket clients.
type Message struct {
ID string `json:"id"`
Type string `json:"type"`
Payload map[string]any `json:"payload,omitempty"`
}
const (
// MessageTypeHTTPReq identifies an HTTP-style request envelope.
MessageTypeHTTPReq = "http_request"
// MessageTypeHTTPResp identifies a non-streaming HTTP response envelope.
MessageTypeHTTPResp = "http_response"
// MessageTypeStreamStart marks the beginning of a streaming response.
MessageTypeStreamStart = "stream_start"
// MessageTypeStreamChunk carries a streaming response chunk.
MessageTypeStreamChunk = "stream_chunk"
// MessageTypeStreamEnd marks the completion of a streaming response.
MessageTypeStreamEnd = "stream_end"
// MessageTypeError carries an error response.
MessageTypeError = "error"
// MessageTypePing represents ping messages from clients.
MessageTypePing = "ping"
// MessageTypePong represents pong responses back to clients.
MessageTypePong = "pong"
)
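For orientation, the JSON envelopes these constants label, matching the encode/decode helpers in http.go; the id and body values are illustrative:
// Server -> client request:
//   {"id":"42","type":"http_request","payload":{"method":"POST","url":"…","headers":{…},"body":"…","sent_at":"…"}}
// Client -> server non-streaming reply (terminates the pending request):
//   {"id":"42","type":"http_response","payload":{"status":200,"headers":{…},"body":"…"}}
// Streaming replies send stream_start, any number of stream_chunk frames with
// payload {"data":"…"}, then stream_end; failures use {"type":"error","payload":{"error":"…","status":502}}.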

188
internal/wsrelay/session.go Normal file
View File

@@ -0,0 +1,188 @@
package wsrelay
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/gorilla/websocket"
)
const (
readTimeout = 60 * time.Second
writeTimeout = 10 * time.Second
maxInboundMessageLen = 64 << 20 // 64 MiB
heartbeatInterval = 30 * time.Second
)
var errClosed = errors.New("websocket session closed")
type pendingRequest struct {
ch chan Message
closeOnce sync.Once
}
func (pr *pendingRequest) close() {
if pr == nil {
return
}
pr.closeOnce.Do(func() {
close(pr.ch)
})
}
type session struct {
conn *websocket.Conn
manager *Manager
provider string
id string
closed chan struct{}
closeOnce sync.Once
writeMutex sync.Mutex
pending sync.Map // map[string]*pendingRequest
}
func newSession(conn *websocket.Conn, mgr *Manager, id string) *session {
s := &session{
conn: conn,
manager: mgr,
provider: "",
id: id,
closed: make(chan struct{}),
}
conn.SetReadLimit(maxInboundMessageLen)
conn.SetReadDeadline(time.Now().Add(readTimeout))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(readTimeout))
return nil
})
s.startHeartbeat()
return s
}
func (s *session) startHeartbeat() {
if s == nil || s.conn == nil {
return
}
ticker := time.NewTicker(heartbeatInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-s.closed:
return
case <-ticker.C:
s.writeMutex.Lock()
err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout))
s.writeMutex.Unlock()
if err != nil {
s.cleanup(err)
return
}
}
}
}()
}
func (s *session) run(ctx context.Context) {
defer s.cleanup(errClosed)
for {
var msg Message
if err := s.conn.ReadJSON(&msg); err != nil {
s.cleanup(err)
return
}
s.dispatch(msg)
}
}
func (s *session) dispatch(msg Message) {
if msg.Type == MessageTypePing {
_ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong})
return
}
if value, ok := s.pending.Load(msg.ID); ok {
req := value.(*pendingRequest)
select {
case req.ch <- msg:
default:
}
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
actual.(*pendingRequest).close()
}
}
return
}
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider)
}
}
func (s *session) send(ctx context.Context, msg Message) error {
select {
case <-s.closed:
return errClosed
default:
}
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
if err := s.conn.WriteJSON(msg); err != nil {
return fmt.Errorf("write json: %w", err)
}
return nil
}
func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) {
if msg.ID == "" {
return nil, fmt.Errorf("wsrelay: message id is required")
}
if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded {
return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID)
}
value, _ := s.pending.Load(msg.ID)
req := value.(*pendingRequest)
if err := s.send(ctx, msg); err != nil {
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
req := actual.(*pendingRequest)
req.close()
}
return nil, err
}
go func() {
select {
case <-ctx.Done():
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
actual.(*pendingRequest).close()
}
case <-s.closed:
}
}()
return req.ch, nil
}
func (s *session) cleanup(cause error) {
s.closeOnce.Do(func() {
close(s.closed)
s.pending.Range(func(key, value any) bool {
req := value.(*pendingRequest)
msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}}
select {
case req.ch <- msg:
default:
}
req.close()
return true
})
s.pending = sync.Map{}
_ = s.conn.Close()
if s.manager != nil {
s.manager.handleSessionClosed(s, cause)
}
})
}
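A minimal client-side sketch under this protocol, assuming gorilla/websocket on the client and a locally mirrored envelope struct; the dial address is an assumption, and ping/stream handling is omitted (gorilla's default ping handler answers the server's heartbeat automatically):
type envelope struct {
	ID      string         `json:"id"`
	Type    string         `json:"type"`
	Payload map[string]any `json:"payload,omitempty"`
}
func runClient(addr string) error {
	conn, _, err := websocket.DefaultDialer.Dial(addr, nil) // e.g. "ws://127.0.0.1:8317/v1/ws"
	if err != nil {
		return err
	}
	defer conn.Close()
	for {
		var in envelope
		if err := conn.ReadJSON(&in); err != nil {
			return err
		}
		if in.Type != "http_request" {
			continue
		}
		// Reply with a canned non-streaming response keyed to the request id.
		out := envelope{ID: in.ID, Type: "http_response", Payload: map[string]any{
			"status":  200,
			"headers": map[string][]string{"Content-Type": {"application/json"}},
			"body":    `{"ok":true}`,
		}}
		if err := conn.WriteJSON(out); err != nil {
			return err
		}
	}
}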

View File

@@ -153,6 +153,17 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
m.executors[executor.Identifier()] = executor
}
// UnregisterExecutor removes the executor associated with the provider key.
func (m *Manager) UnregisterExecutor(provider string) {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" {
return
}
m.mu.Lock()
delete(m.executors, provider)
m.mu.Unlock()
}
// Register inserts a new auth entry into the manager.
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
if auth == nil {

View File

@@ -156,7 +156,17 @@ func (a *Auth) AccountInfo() (string, string) {
if v, ok := a.Metadata["email"].(string); ok {
return "oauth", v
}
} else if a.Attributes != nil {
}
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
if label := strings.TrimSpace(a.Label); label != "" {
return "oauth", label
}
if id := strings.TrimSpace(a.ID); id != "" {
return "oauth", id
}
return "oauth", "aistudio"
}
if a.Attributes != nil {
if v := a.Attributes["api_key"]; v != "" {
return "api_key", v
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -82,6 +83,9 @@ type Service struct {
// shutdownOnce ensures shutdown is called only once.
shutdownOnce sync.Once
// wsGateway manages websocket Gemini providers.
wsGateway *wsrelay.Manager
}
// RegisterUsagePlugin registers a usage plugin on the global usage manager.
@@ -172,6 +176,72 @@ func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdat
}
}
func (s *Service) ensureWebsocketGateway() {
if s == nil {
return
}
if s.wsGateway != nil {
return
}
opts := wsrelay.Options{
Path: "/v1/ws",
OnConnected: s.wsOnConnected,
OnDisconnected: s.wsOnDisconnected,
LogDebugf: log.Debugf,
LogInfof: log.Infof,
LogWarnf: log.Warnf,
}
s.wsGateway = wsrelay.NewManager(opts)
}
func (s *Service) wsOnConnected(provider string) {
if s == nil || provider == "" {
return
}
if !strings.HasPrefix(strings.ToLower(provider), "aistudio-") {
return
}
if s.coreManager != nil {
if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil {
if !existing.Disabled && existing.Status == coreauth.StatusActive {
return
}
}
}
now := time.Now().UTC()
auth := &coreauth.Auth{
ID: provider,
Provider: provider,
Label: provider,
Status: coreauth.StatusActive,
CreatedAt: now,
UpdatedAt: now,
Attributes: map[string]string{"ws_provider": "gemini"},
}
log.Infof("websocket provider connected: %s", provider)
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
}
func (s *Service) wsOnDisconnected(provider string, reason error) {
if s == nil || provider == "" {
return
}
if reason != nil {
if strings.Contains(reason.Error(), "replaced by new connection") {
log.Infof("websocket provider replaced: %s", provider)
return
}
log.Warnf("websocket provider disconnected: %s (%v)", provider, reason)
} else {
log.Infof("websocket provider disconnected: %s", provider)
}
ctx := context.Background()
s.applyCoreAuthRemoval(ctx, provider)
if s.coreManager != nil {
s.coreManager.UnregisterExecutor(provider)
}
}
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
if s == nil || auth == nil || auth.ID == "" {
return
@@ -247,6 +317,12 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg))
return
}
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
if s.wsGateway != nil {
s.coreManager.RegisterExecutor(executor.NewAistudioExecutor(s.cfg, a.Provider, s.wsGateway))
}
return
}
switch strings.ToLower(a.Provider) {
case "gemini":
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
@@ -342,6 +418,27 @@ func (s *Service) Run(ctx context.Context) error {
s.authManager = newDefaultAuthManager()
}
s.ensureWebsocketGateway()
if s.server != nil && s.wsGateway != nil {
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) {
if oldEnabled == newEnabled {
return
}
if !oldEnabled && newEnabled {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if errStop := s.wsGateway.Stop(ctx); errStop != nil {
log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop)
return
}
log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication")
return
}
log.Debugf("ws-auth disabled; existing websocket sessions remain connected")
})
}
if s.hooks.OnBeforeStart != nil {
s.hooks.OnBeforeStart(s.cfg)
}
@@ -379,7 +476,6 @@ func (s *Service) Run(ctx context.Context) error {
s.cfg = newCfg
s.cfgMu.Unlock()
s.rebindExecutors()
}
watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)
@@ -449,6 +545,14 @@ func (s *Service) Shutdown(ctx context.Context) error {
shutdownErr = err
}
}
if s.wsGateway != nil {
if err := s.wsGateway.Stop(ctx); err != nil {
log.Errorf("failed to stop websocket gateway: %v", err)
if shutdownErr == nil {
shutdownErr = err
}
}
}
if s.authQueueStop != nil {
s.authQueueStop()
s.authQueueStop = nil
@@ -505,6 +609,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
provider := strings.ToLower(strings.TrimSpace(a.Provider))
compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a)
if a.Attributes != nil {
if strings.EqualFold(a.Attributes["ws_provider"], "gemini") {
models := mergeGeminiModels()
GlobalModelRegistry().RegisterClient(a.ID, provider, models)
return
}
}
if compatDetected {
provider = "openai-compatibility"
}
@@ -611,3 +722,24 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
GlobalModelRegistry().RegisterClient(a.ID, key, models)
}
}
func mergeGeminiModels() []*ModelInfo {
models := make([]*ModelInfo, 0, 16)
seen := make(map[string]struct{})
appendModels := func(items []*ModelInfo) {
for i := range items {
m := items[i]
if m == nil || m.ID == "" {
continue
}
if _, ok := seen[m.ID]; ok {
continue
}
seen[m.ID] = struct{}{}
models = append(models, m)
}
}
appendModels(registry.GetGeminiModels())
appendModels(registry.GetGeminiCLIModels())
return models
}