Compare commits

...

29 Commits

Author SHA1 Message Date
Luis Pater
2662f91082 feat(management): add PostOAuthCallback handler to token requester interface 2026-01-07 10:47:32 +08:00
Luis Pater
c1db2c7d7c Merge pull request #888 from router-for-me/api-call-TOKEN-fix
fix(management): refresh antigravity token for api-call $TOKEN$
2026-01-07 01:19:24 +08:00
LTbinglingfeng
5e5d8142f9 fix(auth): error when antigravity refresh token missing during refresh 2026-01-07 01:09:50 +08:00
LTbinglingfeng
b01619b441 fix(management): refresh antigravity token for api-call $TOKEN$ 2026-01-07 00:14:02 +08:00
Luis Pater
f861bd6a94 docs: add 9Router to community projects in README 2026-01-06 23:15:28 +08:00
Luis Pater
6dbfdd140d Merge pull request #871 from decolua/patch-1
Update README.md
2026-01-06 22:58:53 +08:00
decolua
386ccffed4 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 20:54:33 +07:00
decolua
ffddd1c90a Update README.md 2026-01-05 20:29:26 +07:00
Luis Pater
8f8dfd081b Merge pull request #850 from can1357/main
feat(translator): add developer role support for Gemini translators
2026-01-05 11:27:24 +08:00
Luis Pater
9f1b445c7c docs: add ProxyPilot to community projects in Chinese README 2026-01-05 11:23:48 +08:00
Luis Pater
ae933dfe14 Merge pull request #858 from Finesssee/add-proxypilot
docs: add ProxyPilot to community projects
2026-01-05 11:20:52 +08:00
Luis Pater
e124db723b Merge pull request #862 from router-for-me/gemini
fix(gemini): abort default injection on existing thinking keys
2026-01-05 10:41:07 +08:00
hkfires
05444cf32d fix(gemini): abort default injection on existing thinking keys 2026-01-05 10:24:30 +08:00
Luis Pater
8edbda57cf feat(translator): add thoughtSignature to node parts for Gemini and Antigravity requests
Enhanced node structure by including `thoughtSignature` for inline data parts in Gemini OpenAI, Gemini CLI, and Antigravity request handlers to improve traceability of thought processes.
2026-01-05 09:25:17 +08:00
Finessse
821249a5ed docs: add ProxyPilot to community projects 2026-01-04 18:19:41 +07:00
Luis Pater
ee33863b47 Merge pull request #857 from router-for-me/management-update
Management update
2026-01-04 18:07:13 +08:00
Supra4E8C
cd22c849e2 feat(management): 更新OAuth模型映射的清理逻辑以增强数据安全性 2026-01-04 17:57:34 +08:00
Supra4E8C
f0e73efda2 feat(management): add vertex api key and oauth model mappings endpoints 2026-01-04 17:32:00 +08:00
Supra4E8C
3156109c71 feat(management): 支持管理接口调整日志大小/强制前缀/路由策略 2026-01-04 12:21:49 +08:00
can1357
6762e081f3 feat(translator): add developer role support for Gemini translators
Treat OpenAI's "developer" role the same as "system" role in request
translation for gemini, gemini-cli, and antigravity backends.
2026-01-03 21:01:01 +01:00
Luis Pater
7815ee338d fix(translator): adjust message_delta emission boundary in Claude-to-OpenAI conversion
Fixed incorrect boundary logic for `message_delta` emission, ensuring proper handling of usage updates and `emitMessageStopIfNeeded` within the response loop.
2026-01-04 01:36:51 +08:00
Luis Pater
44b6c872e2 feat(config): add support for Fork in OAuth model mappings with alias handling
Implemented `Fork` flag in `ModelNameMapping` to allow aliases as additional models while preserving the original model ID. Updated the `applyOAuthModelMappings` logic, added tests for `Fork` behavior, and updated documentation and examples accordingly.
2026-01-04 01:18:29 +08:00
Luis Pater
7a77b23f2d feat(executor): add token refresh timeout and improve context handling during refresh
Introduced `tokenRefreshTimeout` constant for token refresh operations and enhanced context propagation for `refreshToken` by embedding roundtrip information if available. Adjusted `refreshAuth` to ensure default context initialization and handle cancellation errors appropriately.
2026-01-04 00:26:08 +08:00
Luis Pater
672e8549c0 docs: reorganize README to adjust CodMate placement
Moved CodMate entry under ProxyPal in both English and Chinese README files for consistency in structure and better readability.
2026-01-03 21:31:53 +08:00
Luis Pater
66f5269a23 Merge pull request #837 from loocor/main
docs: add CodMate to community projects
2026-01-03 21:30:15 +08:00
Luis Pater
ebec293497 feat(api): integrate TokenStore for improved auth entry management
Replaced file-based auth entry counting with `TokenStore`-backed implementation, enhancing flexibility and context-aware token management. Updated related logic to reflect this change.
2026-01-03 04:53:47 +08:00
Luis Pater
e02ceecd35 feat(registry): introduce ModelRegistryHook for monitoring model registrations and unregistrations
Added support for external hooks to observe model registry events using the `ModelRegistryHook` interface. Implemented thread-safe, non-blocking execution of hooks with panic recovery. Comprehensive tests added to verify hook behavior during registration, unregistration, blocking, and panic scenarios.
2026-01-02 23:18:40 +08:00
Loocor
dca8d5ded8 Add CodMate app information to README
Added CodMate section to README with app details.
2026-01-02 17:15:38 +08:00
Loocor
2a7fd1e897 Add CodMate description to README_CN.md
添加 CodMate 应用的描述,提供 CLI AI 会话管理功能。
2026-01-02 17:15:09 +08:00
25 changed files with 1267 additions and 81 deletions

View File

@@ -118,9 +118,28 @@ Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings,
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
### [CodMate](https://github.com/loocor/CodMate)
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
## More choices
Those projects are ports of CLIProxyAPI or inspired by it:
### [9Router](https://github.com/decolua/9router)
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

View File

@@ -117,9 +117,28 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
### [CodMate](https://github.com/loocor/CodMate)
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话Claude Code、Codex、Gemini CLI提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR拉取请求将其添加到此列表中。
## 更多选择
以下项目是 CLIProxyAPI 的移植版或受其启发:
### [9Router](https://github.com/decolua/9router)
基于 Next.js 的实现,灵感来自 CLIProxyAPI易于安装使用自研格式转换OpenAI/Claude/Gemini/Ollama、组合系统与自动回退、多账户管理指数退避、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -206,6 +206,7 @@ ws-auth: false
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
# vertex:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"

View File

@@ -33,6 +33,13 @@ var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/userinfo.profile",
}
const (
antigravityOAuthClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityOAuthClientSecret = "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"
)
var antigravityOAuthTokenURL = "https://oauth2.googleapis.com/token"
type apiCallRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
@@ -251,6 +258,10 @@ func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth)
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
return token, errToken
}
if provider == "antigravity" {
token, errToken := h.refreshAntigravityOAuthAccessToken(ctx, auth)
return token, errToken
}
return tokenValueForAuth(auth), nil
}
@@ -325,6 +336,161 @@ func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *corea
return strings.TrimSpace(currentToken.AccessToken), nil
}
func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata := auth.Metadata
if len(metadata) == 0 {
return "", fmt.Errorf("antigravity oauth metadata missing")
}
current := strings.TrimSpace(tokenValueFromMetadata(metadata))
if current != "" && !antigravityTokenNeedsRefresh(metadata) {
return current, nil
}
refreshToken := stringValue(metadata, "refresh_token")
if refreshToken == "" {
return "", fmt.Errorf("antigravity refresh token missing")
}
tokenURL := strings.TrimSpace(antigravityOAuthTokenURL)
if tokenURL == "" {
tokenURL = "https://oauth2.googleapis.com/token"
}
form := url.Values{}
form.Set("client_id", antigravityOAuthClientID)
form.Set("client_secret", antigravityOAuthClientSecret)
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
if errReq != nil {
return "", errReq
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
resp, errDo := httpClient.Do(req)
if errDo != nil {
return "", errDo
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", errRead
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("antigravity oauth token refresh failed: status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
return "", errUnmarshal
}
if strings.TrimSpace(tokenResp.AccessToken) == "" {
return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token")
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
now := time.Now()
auth.Metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
if strings.TrimSpace(tokenResp.RefreshToken) != "" {
auth.Metadata["refresh_token"] = strings.TrimSpace(tokenResp.RefreshToken)
}
if tokenResp.ExpiresIn > 0 {
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
auth.Metadata["timestamp"] = now.UnixMilli()
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
}
auth.Metadata["type"] = "antigravity"
if h != nil && h.authManager != nil {
auth.LastRefreshedAt = now
auth.UpdatedAt = now
_, _ = h.authManager.Update(ctx, auth)
}
return strings.TrimSpace(tokenResp.AccessToken), nil
}
func antigravityTokenNeedsRefresh(metadata map[string]any) bool {
// Refresh a bit early to avoid requests racing token expiry.
const skew = 30 * time.Second
if metadata == nil {
return true
}
if expStr, ok := metadata["expired"].(string); ok {
if ts, errParse := time.Parse(time.RFC3339, strings.TrimSpace(expStr)); errParse == nil {
return !ts.After(time.Now().Add(skew))
}
}
expiresIn := int64Value(metadata["expires_in"])
timestampMs := int64Value(metadata["timestamp"])
if expiresIn > 0 && timestampMs > 0 {
exp := time.UnixMilli(timestampMs).Add(time.Duration(expiresIn) * time.Second)
return !exp.After(time.Now().Add(skew))
}
return true
}
func int64Value(raw any) int64 {
switch typed := raw.(type) {
case int:
return int64(typed)
case int32:
return int64(typed)
case int64:
return typed
case uint:
return int64(typed)
case uint32:
return int64(typed)
case uint64:
if typed > uint64(^uint64(0)>>1) {
return 0
}
return int64(typed)
case float32:
return int64(typed)
case float64:
return int64(typed)
case json.Number:
if i, errParse := typed.Int64(); errParse == nil {
return i
}
case string:
if s := strings.TrimSpace(typed); s != "" {
if i, errParse := json.Number(s).Int64(); errParse == nil {
return i
}
}
}
return 0
}
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
if auth == nil {
return nil, nil

View File

@@ -0,0 +1,173 @@
package management
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
_ = ctx
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, a := range s.items {
out = append(out, a.Clone())
}
return out, nil
}
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
_ = ctx
if auth == nil {
return "", nil
}
s.mu.Lock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth.Clone()
s.mu.Unlock()
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
_ = ctx
s.mu.Lock()
delete(s.items, id)
s.mu.Unlock()
return nil
}
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
t.Fatalf("unexpected content-type: %s", ct)
}
bodyBytes, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
values, err := url.ParseQuery(string(bodyBytes))
if err != nil {
t.Fatalf("parse form: %v", err)
}
if values.Get("grant_type") != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
}
if values.Get("refresh_token") != "rt" {
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
}
if values.Get("client_id") != antigravityOAuthClientID {
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
}
if values.Get("client_secret") != antigravityOAuthClientSecret {
t.Fatalf("unexpected client_secret")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-token",
"refresh_token": "rt2",
"expires_in": int64(3600),
"token_type": "Bearer",
})
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
auth := &coreauth.Auth{
ID: "antigravity-test.json",
FileName: "antigravity-test.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "old-token",
"refresh_token": "rt",
"expires_in": int64(3600),
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
h := &Handler{authManager: manager}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
}
if token != "new-token" {
t.Fatalf("expected refreshed token, got %q", token)
}
if callCount != 1 {
t.Fatalf("expected 1 refresh call, got %d", callCount)
}
updated, ok := manager.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth in manager after update")
}
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
t.Fatalf("expected manager metadata updated, got %q", got)
}
}
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
auth := &coreauth.Auth{
ID: "antigravity-valid.json",
FileName: "antigravity-valid.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "ok-token",
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
h := &Handler{}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
}
if token != "ok-token" {
t.Fatalf("expected existing token, got %q", token)
}
if callCount != 0 {
t.Fatalf("expected no refresh calls, got %d", callCount)
}
}

View File

@@ -202,6 +202,26 @@ func (h *Handler) PutLoggingToFile(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
}
// LogsMaxTotalSizeMB
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
c.JSON(200, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB})
}
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
var body struct {
Value *int `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
value := *body.Value
if value < 0 {
value = 0
}
h.cfg.LogsMaxTotalSizeMB = value
h.persist(c)
}
// Request log
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
func (h *Handler) PutRequestLog(c *gin.Context) {
@@ -232,6 +252,52 @@ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
}
// ForceModelPrefix
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
c.JSON(200, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix})
}
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.ForceModelPrefix = v })
}
func normalizeRoutingStrategy(strategy string) (string, bool) {
normalized := strings.ToLower(strings.TrimSpace(strategy))
switch normalized {
case "", "round-robin", "roundrobin", "rr":
return "round-robin", true
case "fill-first", "fillfirst", "ff":
return "fill-first", true
default:
return "", false
}
}
// RoutingStrategy
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
strategy, ok := normalizeRoutingStrategy(h.cfg.Routing.Strategy)
if !ok {
c.JSON(200, gin.H{"strategy": strings.TrimSpace(h.cfg.Routing.Strategy)})
return
}
c.JSON(200, gin.H{"strategy": strategy})
}
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
var body struct {
Value *string `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
normalized, ok := normalizeRoutingStrategy(*body.Value)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
return
}
h.cfg.Routing.Strategy = normalized
h.persist(c)
}
// Proxy URL
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
func (h *Handler) PutProxyURL(c *gin.Context) {

View File

@@ -487,6 +487,137 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
c.JSON(400, gin.H{"error": "missing name or index"})
}
// vertex-api-key: []VertexCompatKey
func (h *Handler) GetVertexCompatKeys(c *gin.Context) {
c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey})
}
func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var arr []config.VertexCompatKey
if err = json.Unmarshal(data, &arr); err != nil {
var obj struct {
Items []config.VertexCompatKey `json:"items"`
}
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
arr = obj.Items
}
for i := range arr {
normalizeVertexCompatKey(&arr[i])
}
h.cfg.VertexCompatAPIKey = arr
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
}
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
type vertexCompatPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
Models *[]config.VertexCompatModel `json:"models"`
}
var body struct {
Index *int `json:"index"`
Match *string `json:"match"`
Value *vertexCompatPatch `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
targetIndex := -1
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) {
targetIndex = *body.Index
}
if targetIndex == -1 && body.Match != nil {
match := strings.TrimSpace(*body.Match)
if match != "" {
for i := range h.cfg.VertexCompatAPIKey {
if h.cfg.VertexCompatAPIKey[i].APIKey == match {
targetIndex = i
break
}
}
}
}
if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.VertexCompatAPIKey[targetIndex]
if body.Value.APIKey != nil {
trimmed := strings.TrimSpace(*body.Value.APIKey)
if trimmed == "" {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
entry.APIKey = trimmed
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
trimmed := strings.TrimSpace(*body.Value.BaseURL)
if trimmed == "" {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
entry.BaseURL = trimmed
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.Models != nil {
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
}
normalizeVertexCompatKey(&entry)
h.cfg.VertexCompatAPIKey[targetIndex] = entry
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
}
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
for _, v := range h.cfg.VertexCompatAPIKey {
if v.APIKey != val {
out = append(out, v)
}
}
h.cfg.VertexCompatAPIKey = out
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
if idxStr := c.Query("index"); idxStr != "" {
var idx int
_, errScan := fmt.Sscanf(idxStr, "%d", &idx)
if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
}
c.JSON(400, gin.H{"error": "missing api-key or index"})
}
// oauth-excluded-models: map[string][]string
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
@@ -572,6 +703,103 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
h.persist(c)
}
// oauth-model-mappings: map[string][]ModelNameMapping
func (h *Handler) GetOAuthModelMappings(c *gin.Context) {
c.JSON(200, gin.H{"oauth-model-mappings": sanitizedOAuthModelMappings(h.cfg.OAuthModelMappings)})
}
func (h *Handler) PutOAuthModelMappings(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var entries map[string][]config.ModelNameMapping
if err = json.Unmarshal(data, &entries); err != nil {
var wrapper struct {
Items map[string][]config.ModelNameMapping `json:"items"`
}
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
entries = wrapper.Items
}
h.cfg.OAuthModelMappings = sanitizedOAuthModelMappings(entries)
h.persist(c)
}
func (h *Handler) PatchOAuthModelMappings(c *gin.Context) {
var body struct {
Provider *string `json:"provider"`
Channel *string `json:"channel"`
Mappings []config.ModelNameMapping `json:"mappings"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
channelRaw := ""
if body.Channel != nil {
channelRaw = *body.Channel
} else if body.Provider != nil {
channelRaw = *body.Provider
}
channel := strings.ToLower(strings.TrimSpace(channelRaw))
if channel == "" {
c.JSON(400, gin.H{"error": "invalid channel"})
return
}
normalizedMap := sanitizedOAuthModelMappings(map[string][]config.ModelNameMapping{channel: body.Mappings})
normalized := normalizedMap[channel]
if len(normalized) == 0 {
if h.cfg.OAuthModelMappings == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelMappings[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelMappings, channel)
if len(h.cfg.OAuthModelMappings) == 0 {
h.cfg.OAuthModelMappings = nil
}
h.persist(c)
return
}
if h.cfg.OAuthModelMappings == nil {
h.cfg.OAuthModelMappings = make(map[string][]config.ModelNameMapping)
}
h.cfg.OAuthModelMappings[channel] = normalized
h.persist(c)
}
func (h *Handler) DeleteOAuthModelMappings(c *gin.Context) {
channel := strings.ToLower(strings.TrimSpace(c.Query("channel")))
if channel == "" {
channel = strings.ToLower(strings.TrimSpace(c.Query("provider")))
}
if channel == "" {
c.JSON(400, gin.H{"error": "missing channel"})
return
}
if h.cfg.OAuthModelMappings == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelMappings[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelMappings, channel)
if len(h.cfg.OAuthModelMappings) == 0 {
h.cfg.OAuthModelMappings = nil
}
h.persist(c)
}
// codex-api-key: []CodexKey
func (h *Handler) GetCodexKeys(c *gin.Context) {
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
@@ -789,6 +1017,53 @@ func normalizeCodexKey(entry *config.CodexKey) {
entry.Models = normalized
}
func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
if entry == nil {
return
}
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.Prefix = strings.TrimSpace(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
if len(entry.Models) == 0 {
return
}
normalized := make([]config.VertexCompatModel, 0, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
model.Name = strings.TrimSpace(model.Name)
model.Alias = strings.TrimSpace(model.Alias)
if model.Name == "" || model.Alias == "" {
continue
}
normalized = append(normalized, model)
}
entry.Models = normalized
}
func sanitizedOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping {
if len(entries) == 0 {
return nil
}
copied := make(map[string][]config.ModelNameMapping, len(entries))
for channel, mappings := range entries {
if len(mappings) == 0 {
continue
}
copied[channel] = append([]config.ModelNameMapping(nil), mappings...)
}
if len(copied) == 0 {
return nil
}
cfg := config.Config{OAuthModelMappings: copied}
cfg.SanitizeOAuthModelMappings()
if len(cfg.OAuthModelMappings) == 0 {
return nil
}
return cfg.OAuthModelMappings
}
// GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil {

View File

@@ -33,6 +33,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
@@ -491,6 +492,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
@@ -563,6 +568,14 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
@@ -578,11 +591,21 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
mgmt.GET("/oauth-model-mappings", s.mgmt.GetOAuthModelMappings)
mgmt.PUT("/oauth-model-mappings", s.mgmt.PutOAuthModelMappings)
mgmt.PATCH("/oauth-model-mappings", s.mgmt.PatchOAuthModelMappings)
mgmt.DELETE("/oauth-model-mappings", s.mgmt.DeleteOAuthModelMappings)
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
@@ -966,8 +989,12 @@ func (s *Server) UpdateClients(cfg *config.Config) {
log.Warnf("amp module is nil, skipping config update")
}
// Count client sources from configuration and auth directory
authFiles := util.CountAuthFiles(cfg.AuthDir)
// Count client sources from configuration and auth store.
tokenStore := sdkAuth.GetTokenStore()
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
dirSetter.SetBaseDir(cfg.AuthDir)
}
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
geminiAPIKeyCount := len(cfg.GeminiKey)
claudeAPIKeyCount := len(cfg.ClaudeKey)
codexAPIKeyCount := len(cfg.CodexKey)
@@ -978,10 +1005,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
openAICompatCount += len(entry.APIKeyEntries)
}
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total,
authFiles,
authEntries,
geminiAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,

View File

@@ -145,11 +145,14 @@ type RoutingConfig struct {
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// ModelNameMapping defines a model ID rename mapping for a specific channel.
// It maps the original model name (Name) to the client-visible alias (Alias).
// ModelNameMapping defines a model ID mapping for a specific channel.
// It maps the upstream model name (Name) to the client-visible alias (Alias).
// When Fork is true, the alias is added as an additional model in listings while
// keeping the original model ID available.
type ModelNameMapping struct {
Name string `yaml:"name" json:"name"`
Alias string `yaml:"alias" json:"alias"`
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests.
@@ -551,7 +554,7 @@ func (cfg *Config) SanitizeOAuthModelMappings() {
}
seenName[nameKey] = struct{}{}
seenAlias[aliasKey] = struct{}{}
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
clean = append(clean, ModelNameMapping{Name: name, Alias: alias, Fork: mapping.Fork})
}
if len(clean) > 0 {
out[channel] = clean

View File

@@ -0,0 +1,27 @@
package config
import "testing"
func TestSanitizeOAuthModelMappings_PreservesForkFlag(t *testing.T) {
cfg := &Config{
OAuthModelMappings: map[string][]ModelNameMapping{
" CoDeX ": {
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
{Name: "gpt-6", Alias: "g6"},
},
},
}
cfg.SanitizeOAuthModelMappings()
mappings := cfg.OAuthModelMappings["codex"]
if len(mappings) != 2 {
t.Fatalf("expected 2 sanitized mappings, got %d", len(mappings))
}
if mappings[0].Name != "gpt-5" || mappings[0].Alias != "g5" || !mappings[0].Fork {
t.Fatalf("expected first mapping to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", mappings[0].Name, mappings[0].Alias, mappings[0].Fork)
}
if mappings[1].Name != "gpt-6" || mappings[1].Alias != "g6" || mappings[1].Fork {
t.Fatalf("expected second mapping to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", mappings[1].Name, mappings[1].Alias, mappings[1].Fork)
}
}

View File

@@ -4,6 +4,7 @@
package registry
import (
"context"
"fmt"
"sort"
"strings"
@@ -84,6 +85,13 @@ type ModelRegistration struct {
SuspendedClients map[string]string
}
// ModelRegistryHook provides optional callbacks for external integrations to track model list changes.
// Hook implementations must be non-blocking and resilient; calls are executed asynchronously and panics are recovered.
type ModelRegistryHook interface {
OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo)
OnModelsUnregistered(ctx context.Context, provider, clientID string)
}
// ModelRegistry manages the global registry of available models
type ModelRegistry struct {
// models maps model ID to registration information
@@ -97,6 +105,8 @@ type ModelRegistry struct {
clientProviders map[string]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
// hook is an optional callback sink for model registration changes
hook ModelRegistryHook
}
// Global model registry instance
@@ -117,6 +127,53 @@ func GetGlobalRegistry() *ModelRegistry {
return globalRegistry
}
// SetHook sets an optional hook for observing model registration changes.
func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
if r == nil {
return
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.hook = hook
}
const defaultModelRegistryHookTimeout = 5 * time.Second
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
hook := r.hook
if hook == nil {
return
}
modelsCopy := cloneModelInfosUnique(models)
go func() {
defer func() {
if recovered := recover(); recovered != nil {
log.Errorf("model registry hook OnModelsRegistered panic: %v", recovered)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
defer cancel()
hook.OnModelsRegistered(ctx, provider, clientID, modelsCopy)
}()
}
func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
hook := r.hook
if hook == nil {
return
}
go func() {
defer func() {
if recovered := recover(); recovered != nil {
log.Errorf("model registry hook OnModelsUnregistered panic: %v", recovered)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
defer cancel()
hook.OnModelsUnregistered(ctx, provider, clientID)
}()
}
// RegisterClient registers a client and its supported models
// Parameters:
// - clientID: Unique identifier for the client
@@ -177,6 +234,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
} else {
delete(r.clientProviders, clientID)
}
r.triggerModelsRegistered(provider, clientID, models)
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
misc.LogCredentialSeparator()
return
@@ -310,6 +368,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientProviders, clientID)
}
r.triggerModelsRegistered(provider, clientID, models)
if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output.
return
@@ -400,6 +459,25 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
return &copyModel
}
func cloneModelInfosUnique(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
cloned := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil || model.ID == "" {
continue
}
if _, exists := seen[model.ID]; exists {
continue
}
seen[model.ID] = struct{}{}
cloned = append(cloned, cloneModelInfo(model))
}
return cloned
}
// UnregisterClient removes a client and decrements counts for its models
// Parameters:
// - clientID: Unique identifier for the client to remove
@@ -460,6 +538,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
log.Debugf("Unregistered client %s", clientID)
// Separator line after completing client unregistration (after the summary line)
misc.LogCredentialSeparator()
r.triggerModelsUnregistered(provider, clientID)
}
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client

View File

@@ -0,0 +1,204 @@
package registry
import (
"context"
"sync"
"testing"
"time"
)
func newTestModelRegistry() *ModelRegistry {
return &ModelRegistry{
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
mutex: &sync.RWMutex{},
}
}
type registeredCall struct {
provider string
clientID string
models []*ModelInfo
}
type unregisteredCall struct {
provider string
clientID string
}
type capturingHook struct {
registeredCh chan registeredCall
unregisteredCh chan unregisteredCall
}
func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
h.registeredCh <- registeredCall{provider: provider, clientID: clientID, models: models}
}
func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
h.unregisteredCh <- unregisteredCall{provider: provider, clientID: clientID}
}
func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) {
r := newTestModelRegistry()
hook := &capturingHook{
registeredCh: make(chan registeredCall, 1),
unregisteredCh: make(chan unregisteredCall, 1),
}
r.SetHook(hook)
inputModels := []*ModelInfo{
{ID: "m1", DisplayName: "Model One"},
{ID: "m2", DisplayName: "Model Two"},
}
r.RegisterClient("client-1", "OpenAI", inputModels)
select {
case call := <-hook.registeredCh:
if call.provider != "openai" {
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
}
if call.clientID != "client-1" {
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
}
if len(call.models) != 2 {
t.Fatalf("models length mismatch: got %d, want %d", len(call.models), 2)
}
if call.models[0] == nil || call.models[0].ID != "m1" {
t.Fatalf("models[0] mismatch: got %#v, want ID=%q", call.models[0], "m1")
}
if call.models[1] == nil || call.models[1].ID != "m2" {
t.Fatalf("models[1] mismatch: got %#v, want ID=%q", call.models[1], "m2")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
}
func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) {
r := newTestModelRegistry()
hook := &capturingHook{
registeredCh: make(chan registeredCall, 1),
unregisteredCh: make(chan unregisteredCall, 1),
}
r.SetHook(hook)
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
select {
case <-hook.registeredCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
r.UnregisterClient("client-1")
select {
case call := <-hook.unregisteredCh:
if call.provider != "openai" {
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
}
if call.clientID != "client-1" {
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
}
}
type blockingHook struct {
started chan struct{}
unblock chan struct{}
}
func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
select {
case <-h.started:
default:
close(h.started)
}
<-h.unblock
}
func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {}
func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) {
r := newTestModelRegistry()
hook := &blockingHook{
started: make(chan struct{}),
unblock: make(chan struct{}),
}
r.SetHook(hook)
defer close(hook.unblock)
done := make(chan struct{})
go func() {
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
close(done)
}()
select {
case <-hook.started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for hook to start")
}
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatal("RegisterClient appears to be blocked by hook")
}
if !r.ClientSupportsModel("client-1", "m1") {
t.Fatal("model registration failed; expected client to support model")
}
}
type panicHook struct {
registeredCalled chan struct{}
unregisteredCalled chan struct{}
}
func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
if h.registeredCalled != nil {
h.registeredCalled <- struct{}{}
}
panic("boom")
}
func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
if h.unregisteredCalled != nil {
h.unregisteredCalled <- struct{}{}
}
panic("boom")
}
func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) {
r := newTestModelRegistry()
hook := &panicHook{
registeredCalled: make(chan struct{}, 1),
unregisteredCalled: make(chan struct{}, 1),
}
r.SetHook(hook)
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
select {
case <-hook.registeredCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
if !r.ClientSupportsModel("client-1", "m1") {
t.Fatal("model registration failed; expected client to support model")
}
r.UnregisterClient("client-1")
select {
case <-hook.unregisteredCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
}
}

View File

@@ -45,6 +45,7 @@ const (
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
tokenRefreshTimeout = 30 * time.Second
)
var (
@@ -914,7 +915,13 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr
if accessToken != "" && expiry.After(time.Now().Add(refreshSkew)) {
return accessToken, nil, nil
}
updated, errRefresh := e.refreshToken(ctx, auth.Clone())
refreshCtx := context.Background()
if ctx != nil {
if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil {
refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", rt)
}
}
updated, errRefresh := e.refreshToken(refreshCtx, auth.Clone())
if errRefresh != nil {
return "", nil, errRefresh
}
@@ -944,7 +951,7 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
httpReq.Header.Set("User-Agent", defaultAntigravityAgent)
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, tokenRefreshTimeout)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
return auth, errDo

View File

@@ -184,7 +184,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> request.systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -223,6 +223,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}
@@ -266,6 +267,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}

View File

@@ -152,7 +152,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> request.systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
@@ -169,7 +169,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -191,6 +191,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}
@@ -236,6 +237,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}

View File

@@ -170,7 +170,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> system_instruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -209,6 +209,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
}
}
@@ -253,6 +254,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
}
}

View File

@@ -299,17 +299,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
inputTokens = promptTokens.Int()
outputTokens = completionTokens.Int()
}
// Send message_delta with usage
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
param.MessageDeltaSent = true
emitMessageStopIfNeeded(param, &results)
}
// Send message_delta with usage
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
param.MessageDeltaSent = true
emitMessageStopIfNeeded(param, &results)
}
return results

View File

@@ -71,10 +71,13 @@ func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool)
incl = &defaultInclude
}
if incl != nil {
valuePath := "generationConfig.thinkingConfig.include_thoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
if !gjson.GetBytes(updated, "generationConfig.thinkingConfig.includeThoughts").Exists() &&
!gjson.GetBytes(updated, "generationConfig.thinkingConfig.include_thoughts").Exists() {
valuePath := "generationConfig.thinkingConfig.include_thoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
}
}
}
return updated
@@ -99,10 +102,13 @@ func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *boo
incl = &defaultInclude
}
if incl != nil {
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
if !gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.includeThoughts").Exists() &&
!gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts").Exists() {
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
}
}
}
return updated
@@ -130,15 +136,15 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool)
incl = &defaultInclude
}
if incl != nil {
valuePath := "generationConfig.thinkingConfig.includeThoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
if !gjson.GetBytes(updated, "generationConfig.thinkingConfig.includeThoughts").Exists() &&
!gjson.GetBytes(updated, "generationConfig.thinkingConfig.include_thoughts").Exists() {
valuePath := "generationConfig.thinkingConfig.includeThoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
}
}
}
if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts")
}
if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget")
}
@@ -167,15 +173,15 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo
incl = &defaultInclude
}
if incl != nil {
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
if !gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.includeThoughts").Exists() &&
!gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts").Exists() {
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
}
}
}
if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts")
}
if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget")
}

View File

@@ -4,8 +4,8 @@
package util
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"regexp"
@@ -93,36 +93,23 @@ func ResolveAuthDir(authDir string) (string, error) {
return filepath.Clean(authDir), nil
}
// CountAuthFiles returns the number of JSON auth files located under the provided directory.
// The function resolves leading tildes to the user's home directory and performs a case-insensitive
// match on the ".json" suffix so that files saved with uppercase extensions are also counted.
func CountAuthFiles(authDir string) int {
dir, err := ResolveAuthDir(authDir)
// CountAuthFiles returns the number of auth records available through the provided Store.
// For filesystem-backed stores, this reflects the number of JSON auth files under the configured directory.
func CountAuthFiles[T any](ctx context.Context, store interface {
List(context.Context) ([]T, error)
}) int {
if store == nil {
return 0
}
if ctx == nil {
ctx = context.Background()
}
entries, err := store.List(ctx)
if err != nil {
log.Debugf("countAuthFiles: failed to resolve auth directory: %v", err)
log.Debugf("countAuthFiles: failed to list auth records: %v", err)
return 0
}
if dir == "" {
return 0
}
count := 0
walkErr := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
log.Debugf("countAuthFiles: error accessing %s: %v", path, err)
return nil
}
if d.IsDir() {
return nil
}
if strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
count++
}
return nil
})
if walkErr != nil {
log.Debugf("countAuthFiles: walk error: %v", walkErr)
}
return count
return len(entries)
}
// WritablePath returns the cleaned WRITABLE_PATH environment variable when it is set.

View File

@@ -80,6 +80,9 @@ func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMa
continue
}
key := name + "->" + alias
if mapping.Fork {
key += "|fork"
}
if _, exists := seen[key]; exists {
continue
}

View File

@@ -21,6 +21,7 @@ type ManagementTokenRequester interface {
RequestIFlowToken(*gin.Context)
RequestIFlowCookieToken(*gin.Context)
GetAuthStatus(c *gin.Context)
PostOAuthCallback(c *gin.Context)
}
type managementTokenRequester struct {
@@ -65,3 +66,7 @@ func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
func (m *managementTokenRequester) GetAuthStatus(c *gin.Context) {
m.handler.GetAuthStatus(c)
}
func (m *managementTokenRequester) PostOAuthCallback(c *gin.Context) {
m.handler.PostOAuthCallback(c)
}

View File

@@ -1536,6 +1536,9 @@ func (m *Manager) markRefreshPending(id string, now time.Time) bool {
}
func (m *Manager) refreshAuth(ctx context.Context, id string) {
if ctx == nil {
ctx = context.Background()
}
m.mu.RLock()
auth := m.auths[id]
var exec ProviderExecutor
@@ -1548,6 +1551,10 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
}
cloned := auth.Clone()
updated, err := exec.Refresh(ctx, cloned)
if err != nil && errors.Is(err, context.Canceled) {
log.Debugf("refresh canceled for %s, %s", auth.Provider, auth.ID)
return
}
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
now := time.Now()
if err != nil {

View File

@@ -5,6 +5,9 @@ import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
// ModelInfo re-exports the registry model info structure.
type ModelInfo = registry.ModelInfo
// ModelRegistryHook re-exports the registry hook interface for external integrations.
type ModelRegistryHook = registry.ModelRegistryHook
// ModelRegistry describes registry operations consumed by external callers.
type ModelRegistry interface {
RegisterClient(clientID, clientProvider string, models []*ModelInfo)
@@ -20,3 +23,8 @@ type ModelRegistry interface {
func GlobalModelRegistry() ModelRegistry {
return registry.GetGlobalRegistry()
}
// SetGlobalModelRegistryHook registers an optional hook on the shared global registry instance.
func SetGlobalModelRegistryHook(hook ModelRegistryHook) {
registry.GetGlobalRegistry().SetHook(hook)
}

View File

@@ -1231,7 +1231,13 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode
if len(mappings) == 0 {
return models
}
forward := make(map[string]string, len(mappings))
type mappingEntry struct {
alias string
fork bool
}
forward := make(map[string]mappingEntry, len(mappings))
for i := range mappings {
name := strings.TrimSpace(mappings[i].Name)
alias := strings.TrimSpace(mappings[i].Alias)
@@ -1245,7 +1251,7 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode
if _, exists := forward[key]; exists {
continue
}
forward[key] = alias
forward[key] = mappingEntry{alias: alias, fork: mappings[i].Fork}
}
if len(forward) == 0 {
return models
@@ -1260,10 +1266,45 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode
if id == "" {
continue
}
mappedID := id
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
mappedID = strings.TrimSpace(to)
key := strings.ToLower(id)
entry, ok := forward[key]
if !ok {
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
out = append(out, model)
continue
}
mappedID := strings.TrimSpace(entry.alias)
if mappedID == "" {
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
out = append(out, model)
continue
}
if entry.fork {
if _, exists := seen[key]; !exists {
seen[key] = struct{}{}
out = append(out, model)
}
aliasKey := strings.ToLower(mappedID)
if _, exists := seen[aliasKey]; exists {
continue
}
seen[aliasKey] = struct{}{}
clone := *model
clone.ID = mappedID
if clone.Name != "" {
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
}
out = append(out, &clone)
continue
}
uniqueKey := strings.ToLower(mappedID)
if _, exists := seen[uniqueKey]; exists {
continue

View File

@@ -0,0 +1,58 @@
package cliproxy
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestApplyOAuthModelMappings_Rename(t *testing.T) {
cfg := &config.Config{
OAuthModelMappings: map[string][]config.ModelNameMapping{
"codex": {
{Name: "gpt-5", Alias: "g5"},
},
},
}
models := []*ModelInfo{
{ID: "gpt-5", Name: "models/gpt-5"},
}
out := applyOAuthModelMappings(cfg, "codex", "oauth", models)
if len(out) != 1 {
t.Fatalf("expected 1 model, got %d", len(out))
}
if out[0].ID != "g5" {
t.Fatalf("expected model id %q, got %q", "g5", out[0].ID)
}
if out[0].Name != "models/g5" {
t.Fatalf("expected model name %q, got %q", "models/g5", out[0].Name)
}
}
func TestApplyOAuthModelMappings_ForkAddsAlias(t *testing.T) {
cfg := &config.Config{
OAuthModelMappings: map[string][]config.ModelNameMapping{
"codex": {
{Name: "gpt-5", Alias: "g5", Fork: true},
},
},
}
models := []*ModelInfo{
{ID: "gpt-5", Name: "models/gpt-5"},
}
out := applyOAuthModelMappings(cfg, "codex", "oauth", models)
if len(out) != 2 {
t.Fatalf("expected 2 models, got %d", len(out))
}
if out[0].ID != "gpt-5" {
t.Fatalf("expected first model id %q, got %q", "gpt-5", out[0].ID)
}
if out[1].ID != "g5" {
t.Fatalf("expected second model id %q, got %q", "g5", out[1].ID)
}
if out[1].Name != "models/g5" {
t.Fatalf("expected forked model name %q, got %q", "models/g5", out[1].Name)
}
}