mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 12:30:50 +08:00
Compare commits
76 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
750b930679 | ||
|
|
3902fd7501 | ||
|
|
4fc3d5e935 | ||
|
|
2d2f4572a7 | ||
|
|
8f4c46f38d | ||
|
|
b6ba51bc2a | ||
|
|
6a66d32d37 | ||
|
|
8d15723195 | ||
|
|
736e0aae86 | ||
|
|
8bf3305b2b | ||
|
|
d00e3ea973 | ||
|
|
89db4e9481 | ||
|
|
e332419081 | ||
|
|
e998b1229a | ||
|
|
bbed134bd1 | ||
|
|
47b9503112 | ||
|
|
3b9253c2be | ||
|
|
d241359153 | ||
|
|
f4d4249ba5 | ||
|
|
cb56cb250e | ||
|
|
e0381a6ae0 | ||
|
|
2c01b2ef64 | ||
|
|
e947266743 | ||
|
|
c6b0e85b54 | ||
|
|
26efbed05c | ||
|
|
96340bf136 | ||
|
|
b055e00c1a | ||
|
|
857c880f99 | ||
|
|
ce7474d953 | ||
|
|
70fdd70b84 | ||
|
|
08ab6a7d77 | ||
|
|
9fa2a7e9df | ||
|
|
d443c86620 | ||
|
|
7be3f1c36c | ||
|
|
f6ab6d97b9 | ||
|
|
bc866bac49 | ||
|
|
50e6d845f4 | ||
|
|
a8cb01819d | ||
|
|
530273906b | ||
|
|
06ddf575d9 | ||
|
|
3099114cbb | ||
|
|
44b63f0767 | ||
|
|
6705d20194 | ||
|
|
a38a9c0b0f | ||
|
|
8286caa366 | ||
|
|
bd1ec8424d | ||
|
|
225e2c6797 | ||
|
|
d8fc485513 | ||
|
|
f137eb0ac4 | ||
|
|
f39a460487 | ||
|
|
ee171bc563 | ||
|
|
a95428f204 | ||
|
|
3ca5fb1046 | ||
|
|
a091d12f4e | ||
|
|
457924828a | ||
|
|
aca2ef6359 | ||
|
|
ade7194792 | ||
|
|
3a436e116a | ||
|
|
336867853b | ||
|
|
6403ff4ec4 | ||
|
|
d222469b44 | ||
|
|
7646a2b877 | ||
|
|
62090f2568 | ||
|
|
c281f4cbaf | ||
|
|
09455f9e85 | ||
|
|
c8e72ba0dc | ||
|
|
375ef252ab | ||
|
|
ee552f8720 | ||
|
|
2e88c4858e | ||
|
|
3f50da85c1 | ||
|
|
8be06255f7 | ||
|
|
72274099aa | ||
|
|
dcae098e23 | ||
|
|
2eb05ec640 | ||
|
|
3ce0d76aa4 | ||
|
|
a00b79d9be |
@@ -13,8 +13,6 @@ Dockerfile
|
||||
docs/*
|
||||
README.md
|
||||
README_CN.md
|
||||
MANAGEMENT_API.md
|
||||
MANAGEMENT_API_CN.md
|
||||
LICENSE
|
||||
|
||||
# Runtime data folders (should be mounted as volumes)
|
||||
@@ -25,10 +23,14 @@ config.yaml
|
||||
|
||||
# Development/editor
|
||||
bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.codex/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -7,6 +7,13 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is it a request payload issue?**
|
||||
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
|
||||
[ ] No, it's another issue.
|
||||
|
||||
**If it's a request payload issue, you MUST know**
|
||||
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -11,11 +11,15 @@ bin/*
|
||||
logs/*
|
||||
conv/*
|
||||
temp/*
|
||||
refs/*
|
||||
|
||||
# Storage backends
|
||||
pgstore/*
|
||||
gitstore/*
|
||||
objectstore/*
|
||||
|
||||
# Static assets
|
||||
static/*
|
||||
refs/*
|
||||
|
||||
# Authentication data
|
||||
auths/*
|
||||
@@ -29,12 +33,17 @@ GEMINI.md
|
||||
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.codex/*
|
||||
.claude/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
@@ -114,6 +114,10 @@ CLI wrapper for instant switching between multiple Claude accounts and alternati
|
||||
|
||||
Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
|
||||
|
||||
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
||||
|
||||
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
@@ -113,6 +113,10 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
|
||||
|
||||
基于 macOS 平台的原生 CLIProxyAPI GUI:配置供应商、模型映射以及OAuth端点,无需 API 密钥。
|
||||
|
||||
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
||||
|
||||
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
|
||||
@@ -405,7 +405,7 @@ func main() {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
if err = logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to configure log output: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ auth-dir: "~/.cli-proxy-api"
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
- "your-api-key-3"
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
@@ -89,6 +90,9 @@ ws-auth: false
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# models:
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
|
||||
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
||||
@@ -104,6 +108,9 @@ ws-auth: false
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "gpt-5-codex" # upstream model name
|
||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gpt-5.1" # exclude specific models (exact match)
|
||||
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
|
||||
@@ -121,7 +128,7 @@ ws-auth: false
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
|
||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||
@@ -152,9 +159,9 @@ ws-auth: false
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# models: # optional: map aliases to upstream model names
|
||||
# - name: "gemini-2.0-flash" # upstream model name
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "vertex-flash" # client-visible alias
|
||||
# - name: "gemini-1.5-pro"
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "vertex-pro"
|
||||
|
||||
# Amp Integration
|
||||
@@ -163,6 +170,18 @@ ws-auth: false
|
||||
# upstream-url: "https://ampcode.com"
|
||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||
# upstream-api-key: ""
|
||||
# # Per-client upstream API key mapping
|
||||
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
|
||||
# # Useful when different clients need to use different Amp accounts/quotas.
|
||||
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
|
||||
# upstream-api-keys:
|
||||
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
|
||||
# api-keys: # Client keys that use this upstream key
|
||||
# - "your-api-key-1"
|
||||
# - "your-api-key-2"
|
||||
# - upstream-api-key: "amp_key_for_team_b"
|
||||
# api-keys:
|
||||
# - "your-api-key-3"
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||
# restrict-management-to-localhost: false
|
||||
# # Force model mappings to run before checking local API keys (default: false)
|
||||
@@ -172,12 +191,42 @@ ws-auth: false
|
||||
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||
# # but you have a similar model available (e.g., Claude Sonnet 4).
|
||||
# model-mappings:
|
||||
# - from: "claude-opus-4.5" # Model requested by Amp CLI
|
||||
# to: "claude-sonnet-4" # Route to this available model instead
|
||||
# - from: "gpt-5"
|
||||
# to: "gemini-2.5-pro"
|
||||
# - from: "claude-3-opus-20240229"
|
||||
# to: "claude-3-5-sonnet-20241022"
|
||||
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
|
||||
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
|
||||
# - from: "claude-sonnet-4-5-20250929"
|
||||
# to: "gemini-claude-sonnet-4-5-thinking"
|
||||
# - from: "claude-haiku-4-5-20251001"
|
||||
# to: "gemini-2.5-flash"
|
||||
|
||||
# Global OAuth model name mappings (per channel)
|
||||
# These mappings rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# oauth-model-mappings:
|
||||
# gemini-cli:
|
||||
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||
# alias: "g2.5p" # client-visible alias
|
||||
# vertex:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# antigravity:
|
||||
# - name: "gemini-3-pro-preview"
|
||||
# alias: "g3p"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
# codex:
|
||||
# - name: "gpt-5"
|
||||
# alias: "g5"
|
||||
# qwen:
|
||||
# - name: "qwen3-coder-plus"
|
||||
# alias: "qwen-plus"
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
|
||||
# OAuth provider excluded models
|
||||
# oauth-excluded-models:
|
||||
|
||||
538
internal/api/handlers/management/api_tools.go
Normal file
538
internal/api/handlers/management/api_tools.go
Normal file
@@ -0,0 +1,538 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const defaultAPICallTimeout = 60 * time.Second
|
||||
|
||||
const (
|
||||
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOAuthClientSecret = "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"
|
||||
)
|
||||
|
||||
var geminiOAuthScopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
type apiCallRequest struct {
|
||||
AuthIndexSnake *string `json:"auth_index"`
|
||||
AuthIndexCamel *string `json:"authIndex"`
|
||||
AuthIndexPascal *string `json:"AuthIndex"`
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Header map[string]string `json:"header"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type apiCallResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Header map[string][]string `json:"header"`
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
// APICall makes a generic HTTP request on behalf of the management API caller.
|
||||
// It is protected by the management middleware.
|
||||
//
|
||||
// Endpoint:
|
||||
//
|
||||
// POST /v0/management/api-call
|
||||
//
|
||||
// Authentication:
|
||||
//
|
||||
// Same as other management APIs (requires a management key and remote-management rules).
|
||||
// You can provide the key via:
|
||||
// - Authorization: Bearer <key>
|
||||
// - X-Management-Key: <key>
|
||||
//
|
||||
// Request JSON:
|
||||
// - auth_index / authIndex / AuthIndex (optional):
|
||||
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
|
||||
// If omitted or not found, credential-specific proxy/token substitution is skipped.
|
||||
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
|
||||
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
|
||||
// - header (optional): Request headers map.
|
||||
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
|
||||
// 1) metadata.access_token
|
||||
// 2) attributes.api_key
|
||||
// 3) metadata.token / metadata.id_token / metadata.cookie
|
||||
// Example: {"Authorization":"Bearer $TOKEN$"}.
|
||||
// Note: if you need to override the HTTP Host header, set header["Host"].
|
||||
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
|
||||
//
|
||||
// Proxy selection (highest priority first):
|
||||
// 1. Selected credential proxy_url
|
||||
// 2. Global config proxy-url
|
||||
// 3. Direct connect (environment proxies are not used)
|
||||
//
|
||||
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
|
||||
// - status_code: Upstream HTTP status code.
|
||||
// - header: Upstream response headers.
|
||||
// - body: Upstream response body as string.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
|
||||
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
|
||||
// -H "Content-Type: application/json" \
|
||||
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
|
||||
//
|
||||
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
|
||||
// -H "Authorization: Bearer 831227" \
|
||||
// -H "Content-Type: application/json" \
|
||||
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
|
||||
func (h *Handler) APICall(c *gin.Context) {
|
||||
var body apiCallRequest
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
method := strings.ToUpper(strings.TrimSpace(body.Method))
|
||||
if method == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
|
||||
return
|
||||
}
|
||||
|
||||
urlStr := strings.TrimSpace(body.URL)
|
||||
if urlStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
|
||||
return
|
||||
}
|
||||
parsedURL, errParseURL := url.Parse(urlStr)
|
||||
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
|
||||
return
|
||||
}
|
||||
|
||||
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
|
||||
auth := h.authByIndex(authIndex)
|
||||
|
||||
reqHeaders := body.Header
|
||||
if reqHeaders == nil {
|
||||
reqHeaders = map[string]string{}
|
||||
}
|
||||
|
||||
var hostOverride string
|
||||
var token string
|
||||
var tokenResolved bool
|
||||
var tokenErr error
|
||||
for key, value := range reqHeaders {
|
||||
if !strings.Contains(value, "$TOKEN$") {
|
||||
continue
|
||||
}
|
||||
if !tokenResolved {
|
||||
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
|
||||
tokenResolved = true
|
||||
}
|
||||
if auth != nil && token == "" {
|
||||
if tokenErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
|
||||
return
|
||||
}
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if body.Data != "" {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
}
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
|
||||
if errNewRequest != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
|
||||
return
|
||||
}
|
||||
|
||||
for key, value := range reqHeaders {
|
||||
if strings.EqualFold(key, "host") {
|
||||
hostOverride = strings.TrimSpace(value)
|
||||
continue
|
||||
}
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
if hostOverride != "" {
|
||||
req.Host = hostOverride
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
}
|
||||
httpClient.Transport = h.apiCallTransport(auth)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.WithError(errDo).Debug("management APICall request failed")
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
respBody, errReadAll := io.ReadAll(resp.Body)
|
||||
if errReadAll != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, apiCallResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header,
|
||||
Body: string(respBody),
|
||||
})
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...*string) string {
|
||||
for _, v := range values {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
if out := strings.TrimSpace(*v); out != "" {
|
||||
return out
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func tokenValueForAuth(auth *coreauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
|
||||
return v
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
||||
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if provider == "gemini-cli" {
|
||||
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
|
||||
return token, errToken
|
||||
}
|
||||
|
||||
return tokenValueForAuth(auth), nil
|
||||
}
|
||||
|
||||
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
metadata, updater := geminiOAuthMetadata(auth)
|
||||
if len(metadata) == 0 {
|
||||
return "", fmt.Errorf("gemini oauth metadata missing")
|
||||
}
|
||||
|
||||
base := make(map[string]any)
|
||||
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
|
||||
base = cloneMap(tokenRaw)
|
||||
}
|
||||
|
||||
var token oauth2.Token
|
||||
if len(base) > 0 {
|
||||
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
|
||||
_ = json.Unmarshal(raw, &token)
|
||||
}
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
token.AccessToken = stringValue(metadata, "access_token")
|
||||
}
|
||||
if token.RefreshToken == "" {
|
||||
token.RefreshToken = stringValue(metadata, "refresh_token")
|
||||
}
|
||||
if token.TokenType == "" {
|
||||
token.TokenType = stringValue(metadata, "token_type")
|
||||
}
|
||||
if token.Expiry.IsZero() {
|
||||
if expiry := stringValue(metadata, "expiry"); expiry != "" {
|
||||
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
|
||||
token.Expiry = ts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conf := &oauth2.Config{
|
||||
ClientID: geminiOAuthClientID,
|
||||
ClientSecret: geminiOAuthClientSecret,
|
||||
Scopes: geminiOAuthScopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
ctxToken := ctx
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
Transport: h.apiCallTransport(auth),
|
||||
}
|
||||
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
|
||||
|
||||
src := conf.TokenSource(ctxToken, &token)
|
||||
currentToken, errToken := src.Token()
|
||||
if errToken != nil {
|
||||
return "", errToken
|
||||
}
|
||||
|
||||
merged := buildOAuthTokenMap(base, currentToken)
|
||||
fields := buildOAuthTokenFields(currentToken, merged)
|
||||
if updater != nil {
|
||||
updater(fields)
|
||||
}
|
||||
return strings.TrimSpace(currentToken.AccessToken), nil
|
||||
}
|
||||
|
||||
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
|
||||
if auth == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
||||
snapshot := shared.MetadataSnapshot()
|
||||
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
|
||||
}
|
||||
return auth.Metadata, func(fields map[string]any) {
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
for k, v := range fields {
|
||||
auth.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(metadata map[string]any, key string) string {
|
||||
if len(metadata) == 0 || key == "" {
|
||||
return ""
|
||||
}
|
||||
if v, ok := metadata[key].(string); ok {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
|
||||
merged := cloneMap(base)
|
||||
if merged == nil {
|
||||
merged = make(map[string]any)
|
||||
}
|
||||
if tok == nil {
|
||||
return merged
|
||||
}
|
||||
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
|
||||
var tokenMap map[string]any
|
||||
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
|
||||
for k, v := range tokenMap {
|
||||
merged[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
|
||||
fields := make(map[string]any, 5)
|
||||
if tok != nil && tok.AccessToken != "" {
|
||||
fields["access_token"] = tok.AccessToken
|
||||
}
|
||||
if tok != nil && tok.TokenType != "" {
|
||||
fields["token_type"] = tok.TokenType
|
||||
}
|
||||
if tok != nil && tok.RefreshToken != "" {
|
||||
fields["refresh_token"] = tok.RefreshToken
|
||||
}
|
||||
if tok != nil && !tok.Expiry.IsZero() {
|
||||
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
|
||||
}
|
||||
if len(merged) > 0 {
|
||||
fields["token"] = cloneMap(merged)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func tokenValueFromMetadata(metadata map[string]any) string {
|
||||
if len(metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
|
||||
switch typed := tokenRaw.(type) {
|
||||
case string:
|
||||
if v := strings.TrimSpace(typed); v != "" {
|
||||
return v
|
||||
}
|
||||
case map[string]any:
|
||||
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
case map[string]string:
|
||||
if v := strings.TrimSpace(typed["access_token"]); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
|
||||
authIndex = strings.TrimSpace(authIndex)
|
||||
if authIndex == "" || h == nil || h.authManager == nil {
|
||||
return nil
|
||||
}
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
if auth.Index == authIndex {
|
||||
return auth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
||||
var proxyCandidates []string
|
||||
if auth != nil {
|
||||
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||
}
|
||||
}
|
||||
if h != nil && h.cfg != nil {
|
||||
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, proxyStr := range proxyCandidates {
|
||||
if transport := buildProxyTransport(proxyStr); transport != nil {
|
||||
return transport
|
||||
}
|
||||
}
|
||||
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
return &http.Transport{Proxy: nil}
|
||||
}
|
||||
clone := transport.Clone()
|
||||
clone.Proxy = nil
|
||||
return clone
|
||||
}
|
||||
|
||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
proxyStr = strings.TrimSpace(proxyStr)
|
||||
if proxyStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
proxyURL, errParse := url.Parse(proxyStr)
|
||||
if errParse != nil {
|
||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
||||
return nil
|
||||
}
|
||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
||||
log.Debug("proxy URL missing scheme/host")
|
||||
return nil
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
var proxyAuth *proxy.Auth
|
||||
if proxyURL.User != nil {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
||||
return nil
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
return nil
|
||||
}
|
||||
@@ -427,9 +427,46 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
log.WithError(err).Warnf("failed to stat auth file %s", path)
|
||||
}
|
||||
}
|
||||
if claims := extractCodexIDTokenClaims(auth); claims != nil {
|
||||
entry["id_token"] = claims
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
||||
return nil
|
||||
}
|
||||
idTokenRaw, ok := auth.Metadata["id_token"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
idToken := strings.TrimSpace(idTokenRaw)
|
||||
if idToken == "" {
|
||||
return nil
|
||||
}
|
||||
claims, err := codex.ParseJWTToken(idToken)
|
||||
if err != nil || claims == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := gin.H{}
|
||||
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
|
||||
result["chatgpt_account_id"] = v
|
||||
}
|
||||
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
|
||||
result["plan_type"] = v
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func authEmail(auth *coreauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
|
||||
@@ -597,11 +597,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
filtered := make([]config.CodexKey, 0, len(arr))
|
||||
for i := range arr {
|
||||
entry := arr[i]
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
normalizeCodexKey(&entry)
|
||||
if entry.BaseURL == "" {
|
||||
continue
|
||||
}
|
||||
@@ -613,12 +609,13 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
}
|
||||
func (h *Handler) PatchCodexKey(c *gin.Context) {
|
||||
type codexKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Models *[]config.CodexModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
@@ -667,12 +664,16 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
normalizeCodexKey(&entry)
|
||||
h.cfg.CodexKey[targetIndex] = entry
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
@@ -762,6 +763,32 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
func normalizeCodexKey(entry *config.CodexKey) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.Prefix = strings.TrimSpace(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
normalized := make([]config.CodexModel, 0, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
model.Name = strings.TrimSpace(model.Name)
|
||||
model.Alias = strings.TrimSpace(model.Alias)
|
||||
if model.Name == "" && model.Alias == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, model)
|
||||
}
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
// GetAmpCode returns the complete ampcode configuration.
|
||||
func (h *Handler) GetAmpCode(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
@@ -913,3 +940,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
|
||||
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
||||
}
|
||||
|
||||
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
|
||||
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
|
||||
}
|
||||
|
||||
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
|
||||
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
// Normalize entries: trim whitespace, filter empty
|
||||
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
|
||||
// Matching is done by upstream-api-key value.
|
||||
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
existing := make(map[string]int)
|
||||
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
|
||||
}
|
||||
|
||||
for _, newEntry := range body.Value {
|
||||
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
|
||||
UpstreamAPIKey: upstreamKey,
|
||||
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
|
||||
}
|
||||
if idx, ok := existing[upstreamKey]; ok {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
|
||||
} else {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
|
||||
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
|
||||
}
|
||||
}
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
|
||||
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
|
||||
// If "value" is an empty array, clears all entries.
|
||||
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
|
||||
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []string `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
if body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "missing value"})
|
||||
return
|
||||
}
|
||||
|
||||
// Empty array means clear all
|
||||
if len(body.Value) == 0 {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = nil
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
|
||||
toRemove := make(map[string]bool)
|
||||
for _, key := range body.Value {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
toRemove[trimmed] = true
|
||||
}
|
||||
if len(toRemove) == 0 {
|
||||
c.JSON(400, gin.H{"error": "empty value"})
|
||||
return
|
||||
}
|
||||
|
||||
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
|
||||
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
|
||||
newEntries = append(newEntries, entry)
|
||||
}
|
||||
}
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
|
||||
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
apiKeys := normalizeAPIKeysList(entry.APIKeys)
|
||||
out = append(out, config.AmpUpstreamAPIKeyEntry{
|
||||
UpstreamAPIKey: upstreamKey,
|
||||
APIKeys: apiKeys,
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
|
||||
func normalizeAPIKeysList(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed != "" {
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -59,6 +59,11 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
|
||||
return NewHandler(cfg, "", manager)
|
||||
}
|
||||
|
||||
// SetConfig updates the in-memory config reference when the server hot-reloads.
|
||||
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }
|
||||
|
||||
|
||||
@@ -1,12 +1,25 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
)
|
||||
|
||||
type usageExportPayload struct {
|
||||
Version int `json:"version"`
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
type usageImportPayload struct {
|
||||
Version int `json:"version"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
||||
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
|
||||
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
|
||||
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
if h != nil && h.usageStats != nil {
|
||||
snapshot = h.usageStats.Snapshot()
|
||||
}
|
||||
c.JSON(http.StatusOK, usageExportPayload{
|
||||
Version: 1,
|
||||
ExportedAt: time.Now().UTC(),
|
||||
Usage: snapshot,
|
||||
})
|
||||
}
|
||||
|
||||
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
|
||||
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
|
||||
if h == nil || h.usageStats == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
var payload usageImportPayload
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
|
||||
return
|
||||
}
|
||||
if payload.Version != 0 && payload.Version != 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
|
||||
return
|
||||
}
|
||||
|
||||
result := h.usageStats.MergeSnapshot(payload.Usage)
|
||||
snapshot := h.usageStats.Snapshot()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"added": result.Added,
|
||||
"skipped": result.Skipped,
|
||||
"total_requests": snapshot.TotalRequests,
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check API key change
|
||||
// Check API key change (both default and per-client mappings)
|
||||
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
||||
if apiKeyChanged {
|
||||
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
|
||||
if apiKeyChanged || upstreamAPIKeysChanged {
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||
if apiKeyChanged {
|
||||
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
if upstreamAPIKeysChanged {
|
||||
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
|
||||
}
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
|
||||
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
||||
if m.secretSource == nil {
|
||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
|
||||
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
mappedSource := NewMappedSecretSource(defaultSource)
|
||||
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
m.secretSource = mappedSource
|
||||
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
ms.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
|
||||
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
mappedSource := NewMappedSecretSource(ms)
|
||||
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
m.secretSource = mappedSource
|
||||
}
|
||||
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
@@ -313,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
|
||||
return oldKey != newKey
|
||||
}
|
||||
|
||||
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
|
||||
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||
if old == nil {
|
||||
return len(new.UpstreamAPIKeys) > 0
|
||||
}
|
||||
|
||||
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for comparison: upstreamKey -> set of clientKeys
|
||||
type entryInfo struct {
|
||||
upstreamKey string
|
||||
clientKeys map[string]struct{}
|
||||
}
|
||||
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
|
||||
for i, entry := range old.UpstreamAPIKeys {
|
||||
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
|
||||
for _, k := range entry.APIKeys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
clientKeys[trimmed] = struct{}{}
|
||||
}
|
||||
oldEntries[i] = entryInfo{
|
||||
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
|
||||
clientKeys: clientKeys,
|
||||
}
|
||||
}
|
||||
|
||||
for i, newEntry := range new.UpstreamAPIKeys {
|
||||
if i >= len(oldEntries) {
|
||||
return true
|
||||
}
|
||||
oldE := oldEntries[i]
|
||||
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
|
||||
return true
|
||||
}
|
||||
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
|
||||
for _, k := range newEntry.APIKeys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
newKeys[trimmed] = struct{}{}
|
||||
}
|
||||
if len(newKeys) != len(oldE.clientKeys) {
|
||||
return true
|
||||
}
|
||||
for k := range newKeys {
|
||||
if _, ok := oldE.clientKeys[k]; !ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||
return m.modelMapper
|
||||
|
||||
@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
|
||||
m := &AmpModule{}
|
||||
|
||||
oldCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||
},
|
||||
}
|
||||
newCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
|
||||
},
|
||||
}
|
||||
|
||||
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
|
||||
m := &AmpModule{}
|
||||
|
||||
oldCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||
},
|
||||
}
|
||||
newCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
|
||||
},
|
||||
}
|
||||
|
||||
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||
t.Fatal("expected no change when only whitespace/empty entries differ")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,33 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func removeQueryValuesMatching(req *http.Request, key string, match string) {
|
||||
if req == nil || req.URL == nil || match == "" {
|
||||
return
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
values, ok := q[key]
|
||||
if !ok || len(values) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
kept := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
if v == match {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, v)
|
||||
}
|
||||
|
||||
if len(kept) == 0 {
|
||||
q.Del(key)
|
||||
} else {
|
||||
q[key] = kept
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
// readCloser wraps a reader and forwards Close to a separate closer.
|
||||
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
||||
type readCloser struct {
|
||||
@@ -45,6 +72,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// We will set our own Authorization using the configured upstream-api-key
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Del("X-Api-Key")
|
||||
req.Header.Del("X-Goog-Api-Key")
|
||||
|
||||
// Remove query-based credentials if they match the authenticated client API key.
|
||||
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||
// breaking unrelated upstream query parameters.
|
||||
clientKey := getClientAPIKeyFromContext(req.Context())
|
||||
removeQueryValuesMatching(req, "key", clientKey)
|
||||
removeQueryValuesMatching(req, "auth_token", clientKey)
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
|
||||
@@ -3,11 +3,15 @@ package amp
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// Helper: compress data with gzip
|
||||
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
|
||||
type captured struct {
|
||||
headers http.Header
|
||||
query string
|
||||
}
|
||||
got := make(chan captured, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer client-key")
|
||||
req.Header.Set("X-Api-Key", "client-key")
|
||||
req.Header.Set("X-Goog-Api-Key", "client-key")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
c := <-got
|
||||
|
||||
// These are client-provided credentials and must not reach the upstream.
|
||||
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
|
||||
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
|
||||
}
|
||||
|
||||
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
|
||||
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
|
||||
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
|
||||
}
|
||||
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
|
||||
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
|
||||
}
|
||||
|
||||
// Query-based credentials should be stripped only when they match the authenticated client key.
|
||||
// Should keep unrelated values and parameters.
|
||||
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
|
||||
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
|
||||
}
|
||||
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
|
||||
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
mapped := NewMappedSecretSource(defaultSource)
|
||||
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "u1" {
|
||||
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer u1" {
|
||||
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
mapped := NewMappedSecretSource(defaultSource)
|
||||
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "default" {
|
||||
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer default" {
|
||||
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
// Point proxy to a non-routable address to trigger error
|
||||
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -16,6 +17,37 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// clientAPIKeyContextKey is the context key used to pass the client API key
|
||||
// from gin.Context to the request context for SecretSource lookup.
|
||||
type clientAPIKeyContextKey struct{}
|
||||
|
||||
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
|
||||
// into the request context so that SecretSource can look it up for per-client upstream routing.
|
||||
func clientAPIKeyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Extract the client API key from gin context (set by AuthMiddleware)
|
||||
if apiKey, exists := c.Get("apiKey"); exists {
|
||||
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
|
||||
// Inject into request context for SecretSource.Get(ctx) to read
|
||||
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// getClientAPIKeyFromContext retrieves the client API key from request context.
|
||||
// Returns empty string if not present.
|
||||
func getClientAPIKeyFromContext(ctx context.Context) string {
|
||||
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
|
||||
if keyStr, ok := val.(string); ok {
|
||||
return keyStr
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
||||
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
||||
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
||||
}
|
||||
|
||||
// Inject client API key into request context for per-client upstream routing
|
||||
ampAPI.Use(clientAPIKeyMiddleware())
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
if authWithBypass != nil {
|
||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||
}
|
||||
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
|
||||
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
|
||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
if auth != nil {
|
||||
ampProviders.Use(auth)
|
||||
}
|
||||
// Inject client API key into request context for per-client upstream routing
|
||||
ampProviders.Use(clientAPIKeyMiddleware())
|
||||
|
||||
provider := ampProviders.Group("/:provider")
|
||||
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SecretSource provides Amp API keys with configurable precedence and caching
|
||||
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
|
||||
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
|
||||
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
|
||||
// When a request context contains a client API key that matches a configured mapping,
|
||||
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
|
||||
type MappedSecretSource struct {
|
||||
defaultSource SecretSource
|
||||
mu sync.RWMutex
|
||||
lookup map[string]string // clientKey -> upstreamKey
|
||||
}
|
||||
|
||||
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
|
||||
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
|
||||
return &MappedSecretSource{
|
||||
defaultSource: defaultSource,
|
||||
lookup: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves the Amp API key, checking per-client mappings first.
|
||||
// If the request context contains a client API key that matches a configured mapping,
|
||||
// returns the corresponding upstream key. Otherwise, falls back to the default source.
|
||||
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
|
||||
// Try to get client API key from request context
|
||||
clientKey := getClientAPIKeyFromContext(ctx)
|
||||
if clientKey != "" {
|
||||
s.mu.RLock()
|
||||
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
|
||||
s.mu.RUnlock()
|
||||
return upstreamKey, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
|
||||
// Fall back to default source
|
||||
return s.defaultSource.Get(ctx)
|
||||
}
|
||||
|
||||
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
|
||||
// If the same client key appears in multiple entries, logs a warning and uses the first one.
|
||||
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
|
||||
newLookup := make(map[string]string)
|
||||
|
||||
for _, entry := range entries {
|
||||
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
for _, clientKey := range entry.APIKeys {
|
||||
trimmedKey := strings.TrimSpace(clientKey)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := newLookup[trimmedKey]; exists {
|
||||
// Log warning for duplicate client key, first one wins
|
||||
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
|
||||
continue
|
||||
}
|
||||
newLookup[trimmedKey] = upstreamKey
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.lookup = newLookup
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
|
||||
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
|
||||
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(key)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
|
||||
func (s *MappedSecretSource) InvalidateCache() {
|
||||
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
)
|
||||
|
||||
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
||||
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
||||
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "u1" {
|
||||
t.Fatalf("want u1, got %q", got)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
|
||||
got, err = s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "default" {
|
||||
t.Fatalf("want default fallback, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
{
|
||||
UpstreamAPIKey: "u2",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "u1" {
|
||||
t.Fatalf("want u1 (first wins), got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
|
||||
hook := test.NewLocal(log.StandardLogger())
|
||||
defer hook.Reset()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
{
|
||||
UpstreamAPIKey: "u2",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
foundWarning := false
|
||||
for _, entry := range hook.AllEntries() {
|
||||
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
|
||||
foundWarning = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundWarning {
|
||||
t.Fatal("expected warning log for duplicate client key, but none was found")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -476,6 +476,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
||||
{
|
||||
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
||||
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
|
||||
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
|
||||
mgmt.GET("/config", s.mgmt.GetConfig)
|
||||
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||
@@ -498,6 +500,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
|
||||
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
|
||||
|
||||
mgmt.POST("/api-call", s.mgmt.APICall)
|
||||
|
||||
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
|
||||
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
||||
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
||||
@@ -547,6 +551,10 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
||||
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
|
||||
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
|
||||
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
|
||||
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
|
||||
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
@@ -848,7 +856,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
|
||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
if err := logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to reconfigure log output: %v", err)
|
||||
} else {
|
||||
if oldCfg == nil {
|
||||
|
||||
@@ -91,6 +91,14 @@ type Config struct {
|
||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||
|
||||
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
|
||||
// These mappings affect both model listing and model routing for supported channels:
|
||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
//
|
||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
|
||||
|
||||
// Payload defines default and override rules for provider payload parameters.
|
||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||
|
||||
@@ -137,6 +145,13 @@ type RoutingConfig struct {
|
||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||
}
|
||||
|
||||
// ModelNameMapping defines a model ID rename mapping for a specific channel.
|
||||
// It maps the original model name (Name) to the client-visible alias (Alias).
|
||||
type ModelNameMapping struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||
// When Amp requests a model that isn't available locally, this mapping
|
||||
// allows routing to an alternative model that IS available.
|
||||
@@ -163,6 +178,11 @@ type AmpCode struct {
|
||||
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
|
||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||
|
||||
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
||||
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
||||
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||
|
||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
||||
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
||||
@@ -178,6 +198,17 @@ type AmpCode struct {
|
||||
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
||||
}
|
||||
|
||||
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
|
||||
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||
// is used for the upstream Amp request.
|
||||
type AmpUpstreamAPIKeyEntry struct {
|
||||
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
|
||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||
|
||||
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
}
|
||||
|
||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||
type PayloadConfig struct {
|
||||
// Default defines rules that only set parameters when they are missing in the payload.
|
||||
@@ -237,6 +268,9 @@ type ClaudeModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m ClaudeModel) GetName() string { return m.Name }
|
||||
func (m ClaudeModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// CodexKey represents the configuration for a Codex API key,
|
||||
// including the API key itself and an optional base URL for the API endpoint.
|
||||
type CodexKey struct {
|
||||
@@ -253,6 +287,9 @@ type CodexKey struct {
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// Models defines upstream model names and aliases for request routing.
|
||||
Models []CodexModel `yaml:"models" json:"models"`
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
@@ -260,6 +297,18 @@ type CodexKey struct {
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// CodexModel describes a mapping between an alias and the actual upstream model name.
|
||||
type CodexModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the client-facing model name that maps to Name.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m CodexModel) GetName() string { return m.Name }
|
||||
func (m CodexModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// GeminiKey represents the configuration for a Gemini API key,
|
||||
// including optional overrides for upstream base URL, proxy routing, and headers.
|
||||
type GeminiKey struct {
|
||||
@@ -275,6 +324,9 @@ type GeminiKey struct {
|
||||
// ProxyURL optionally overrides the global proxy for this API key.
|
||||
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||
|
||||
// Models defines upstream model names and aliases for request routing.
|
||||
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
@@ -282,6 +334,18 @@ type GeminiKey struct {
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiModel describes a mapping between an alias and the actual upstream model name.
|
||||
type GeminiModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the client-facing model name that maps to Name.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m GeminiModel) GetName() string { return m.Name }
|
||||
func (m GeminiModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
||||
type OpenAICompatibility struct {
|
||||
@@ -433,6 +497,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Normalize OAuth provider model exclusion map.
|
||||
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
||||
|
||||
// Normalize global OAuth model name mappings.
|
||||
cfg.SanitizeOAuthModelMappings()
|
||||
|
||||
if cfg.legacyMigrationPending {
|
||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
if !optional && configFile != "" {
|
||||
@@ -449,6 +516,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// and ensures (From, To) pairs are unique within each channel.
|
||||
func (cfg *Config) SanitizeOAuthModelMappings() {
|
||||
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
|
||||
return
|
||||
}
|
||||
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
|
||||
for rawChannel, mappings := range cfg.OAuthModelMappings {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(mappings) == 0 {
|
||||
continue
|
||||
}
|
||||
seenName := make(map[string]struct{}, len(mappings))
|
||||
seenAlias := make(map[string]struct{}, len(mappings))
|
||||
clean := make([]ModelNameMapping, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
name := strings.TrimSpace(mapping.Name)
|
||||
alias := strings.TrimSpace(mapping.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
nameKey := strings.ToLower(name)
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, ok := seenName[nameKey]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := seenAlias[aliasKey]; ok {
|
||||
continue
|
||||
}
|
||||
seenName[nameKey] = struct{}{}
|
||||
seenAlias[aliasKey] = struct{}{}
|
||||
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
|
||||
}
|
||||
if len(clean) > 0 {
|
||||
out[channel] = clean
|
||||
}
|
||||
}
|
||||
cfg.OAuthModelMappings = out
|
||||
}
|
||||
|
||||
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
|
||||
// not actionable, specifically those missing a BaseURL. It trims whitespace before
|
||||
// evaluation and preserves the relative order of remaining entries.
|
||||
@@ -817,8 +928,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
||||
}
|
||||
|
||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||
// key order and comments of existing keys in dst. Unknown keys from src are appended
|
||||
// to dst at the end, copying their node structure from src.
|
||||
// key order and comments of existing keys in dst. New keys are only added if their
|
||||
// value is non-zero to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
@@ -829,20 +940,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
copyNodeShallow(dst, src)
|
||||
return
|
||||
}
|
||||
// Build a lookup of existing keys in dst
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
sk := src.Content[i]
|
||||
sv := src.Content[i+1]
|
||||
idx := findMapKeyIndex(dst, sk.Value)
|
||||
if idx >= 0 {
|
||||
// Merge into existing value node
|
||||
// Merge into existing value node (always update, even to zero values)
|
||||
dv := dst.Content[idx+1]
|
||||
mergeNodePreserve(dv, sv)
|
||||
} else {
|
||||
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) {
|
||||
// New key: only add if value is non-zero to avoid polluting config with defaults
|
||||
if isZeroValueNode(sv) {
|
||||
continue
|
||||
}
|
||||
// Append new key/value pair by deep-copying from src
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
||||
}
|
||||
}
|
||||
@@ -925,32 +1035,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
|
||||
switch key {
|
||||
case "generative-language-api-key",
|
||||
"gemini-api-key",
|
||||
"vertex-api-key",
|
||||
"claude-api-key",
|
||||
"codex-api-key",
|
||||
"openai-compatibility":
|
||||
return isEmptyCollectionNode(node)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isEmptyCollectionNode(node *yaml.Node) bool {
|
||||
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||
// that should not be written as a new key to preserve config cleanliness.
|
||||
// For mappings and sequences, recursively checks if all children are zero values.
|
||||
func isZeroValueNode(node *yaml.Node) bool {
|
||||
if node == nil {
|
||||
return true
|
||||
}
|
||||
switch node.Kind {
|
||||
case yaml.SequenceNode:
|
||||
return len(node.Content) == 0
|
||||
case yaml.ScalarNode:
|
||||
return node.Tag == "!!null"
|
||||
default:
|
||||
return false
|
||||
switch node.Tag {
|
||||
case "!!bool":
|
||||
return node.Value == "false"
|
||||
case "!!int", "!!float":
|
||||
return node.Value == "0" || node.Value == "0.0"
|
||||
case "!!str":
|
||||
return node.Value == ""
|
||||
case "!!null":
|
||||
return true
|
||||
}
|
||||
case yaml.SequenceNode:
|
||||
if len(node.Content) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check if all elements are zero values
|
||||
for _, child := range node.Content {
|
||||
if !isZeroValueNode(child) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case yaml.MappingNode:
|
||||
if len(node.Content) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check if all values are zero values (values are at odd indices)
|
||||
for i := 1; i < len(node.Content); i += 2 {
|
||||
if !isZeroValueNode(node.Content[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// deepCopyNode creates a deep copy of a yaml.Node graph.
|
||||
|
||||
@@ -30,13 +30,13 @@ type SDKConfig struct {
|
||||
// StreamingConfig holds server streaming behavior configuration.
|
||||
type StreamingConfig struct {
|
||||
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
|
||||
// nil means default (15 seconds). <= 0 disables keep-alives.
|
||||
KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
||||
// <= 0 disables keep-alives. Default is 0.
|
||||
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
||||
|
||||
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
|
||||
// to allow auth rotation / transient recovery.
|
||||
// nil means default (2). 0 disables bootstrap retries.
|
||||
BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
// <= 0 disables bootstrap retries. Default is 0.
|
||||
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
|
||||
@@ -42,6 +42,9 @@ type VertexCompatModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m VertexCompatModel) GetName() string { return m.Name }
|
||||
func (m VertexCompatModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
|
||||
func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if cfg == nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@@ -83,10 +84,30 @@ func SetupBaseLogger() {
|
||||
})
|
||||
}
|
||||
|
||||
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
|
||||
func isDirWritable(dir string) bool {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil || !info.IsDir() {
|
||||
return false
|
||||
}
|
||||
|
||||
testFile := filepath.Join(dir, ".perm_test")
|
||||
f, err := os.Create(testFile)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(testFile)
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||
// until the total size is within the limit.
|
||||
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
func ConfigureLogOutput(cfg *config.Config) error {
|
||||
SetupBaseLogger()
|
||||
|
||||
writerMu.Lock()
|
||||
@@ -95,10 +116,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
} else if !isDirWritable(logDir) {
|
||||
logDir = filepath.Join(cfg.AuthDir, "logs")
|
||||
}
|
||||
|
||||
protectedPath := ""
|
||||
if loggingToFile {
|
||||
if cfg.LoggingToFile {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
||||
}
|
||||
@@ -122,7 +145,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
|
||||
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -24,10 +24,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -198,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
return
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
@@ -210,15 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if err := os.MkdirAll(staticDir, 0o755); err != nil {
|
||||
log.WithError(err).Warn("failed to prepare static directory for management asset")
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
@@ -229,6 +239,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
@@ -240,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
@@ -256,6 +280,22 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
}
|
||||
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to download fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
|
||||
return true
|
||||
}
|
||||
|
||||
func resolveReleaseURL(repo string) string {
|
||||
repo = strings.TrimSpace(repo)
|
||||
if repo == "" {
|
||||
|
||||
@@ -740,8 +740,8 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -773,7 +773,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
return map[string]*AntigravityModelConfig{
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
|
||||
@@ -781,3 +781,29 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
}
|
||||
}
|
||||
|
||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||
// Returns nil if no matching model is found.
|
||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
allModels := [][]*ModelInfo{
|
||||
GetClaudeModels(),
|
||||
GetGeminiModels(),
|
||||
GetGeminiVertexModels(),
|
||||
GetGeminiCLIModels(),
|
||||
GetAIStudioModels(),
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
if m != nil && m.ID == modelID {
|
||||
return m
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -625,6 +625,131 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
return models
|
||||
}
|
||||
|
||||
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
||||
// Parameters:
|
||||
// - provider: Provider identifier (e.g., "codex", "gemini", "antigravity")
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: List of available models for the provider
|
||||
func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
type providerModel struct {
|
||||
count int
|
||||
info *ModelInfo
|
||||
}
|
||||
|
||||
providerModels := make(map[string]*providerModel)
|
||||
|
||||
for clientID, clientProvider := range r.clientProviders {
|
||||
if clientProvider != provider {
|
||||
continue
|
||||
}
|
||||
modelIDs := r.clientModels[clientID]
|
||||
if len(modelIDs) == 0 {
|
||||
continue
|
||||
}
|
||||
clientInfos := r.clientModelInfos[clientID]
|
||||
for _, modelID := range modelIDs {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
entry := providerModels[modelID]
|
||||
if entry == nil {
|
||||
entry = &providerModel{}
|
||||
providerModels[modelID] = entry
|
||||
}
|
||||
entry.count++
|
||||
if entry.info == nil {
|
||||
if clientInfos != nil {
|
||||
if info := clientInfos[modelID]; info != nil {
|
||||
entry.info = info
|
||||
}
|
||||
}
|
||||
if entry.info == nil {
|
||||
if reg, ok := r.models[modelID]; ok && reg != nil && reg.Info != nil {
|
||||
entry.info = reg.Info
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(providerModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
now := time.Now()
|
||||
result := make([]*ModelInfo, 0, len(providerModels))
|
||||
|
||||
for modelID, entry := range providerModels {
|
||||
if entry == nil || entry.count <= 0 {
|
||||
continue
|
||||
}
|
||||
registration, ok := r.models[modelID]
|
||||
|
||||
expiredClients := 0
|
||||
cooldownSuspended := 0
|
||||
otherSuspended := 0
|
||||
if ok && registration != nil {
|
||||
if registration.QuotaExceededClients != nil {
|
||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||
continue
|
||||
}
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
}
|
||||
if registration.SuspendedClients != nil {
|
||||
for clientID, reason := range registration.SuspendedClients {
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(reason, "quota") {
|
||||
cooldownSuspended++
|
||||
continue
|
||||
}
|
||||
otherSuspended++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
availableClients := entry.count
|
||||
effectiveClients := availableClients - expiredClients - otherSuspended
|
||||
if effectiveClients < 0 {
|
||||
effectiveClients = 0
|
||||
}
|
||||
|
||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||
if entry.info != nil {
|
||||
result = append(result, entry.info)
|
||||
continue
|
||||
}
|
||||
if ok && registration != nil && registration.Info != nil {
|
||||
result = append(result, registration.Info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetModelCount returns the number of available clients for a specific model
|
||||
// Parameters:
|
||||
// - modelID: The model ID to check
|
||||
|
||||
@@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: http.MethodPost,
|
||||
@@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: http.MethodPost,
|
||||
|
||||
@@ -76,7 +76,8 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
|
||||
|
||||
// Execute performs a non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if strings.Contains(req.Model, "claude") {
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
if isClaude {
|
||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
@@ -95,10 +96,10 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
@@ -190,10 +191,10 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, true)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
@@ -520,14 +521,16 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
@@ -676,6 +679,8 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
@@ -692,9 +697,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload)
|
||||
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
@@ -1308,7 +1313,7 @@ func alias2ModelName(modelName string) string {
|
||||
|
||||
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
|
||||
// For Claude models, it additionally ensures thinking budget < max_tokens.
|
||||
func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
||||
func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
|
||||
payload = util.StripThinkingConfigIfUnsupported(model, payload)
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return payload
|
||||
@@ -1320,7 +1325,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
||||
raw := int(budget.Int())
|
||||
normalized := util.NormalizeThinkingBudget(model, raw)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(model), "claude")
|
||||
if isClaude {
|
||||
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
|
||||
if effectiveMax > 0 && normalized >= effectiveMax {
|
||||
|
||||
@@ -49,33 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
// Inject thinking config based on model metadata for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
|
||||
body = e.injectThinkingConfig(model, req.Metadata, body)
|
||||
|
||||
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(req.Model, body)
|
||||
body = ensureMaxTokensForThinking(model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
@@ -167,26 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
// Inject thinking config based on model metadata for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
|
||||
body = e.injectThinkingConfig(model, req.Metadata, body)
|
||||
body = checkSystemInstructions(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(req.Model, body)
|
||||
body = ensureMaxTokensForThinking(model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
@@ -310,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
|
||||
@@ -461,6 +446,19 @@ func (e *ClaudeExecutor) injectThinkingConfig(modelName string, metadata map[str
|
||||
return util.ApplyClaudeThinkingConfig(body, budget)
|
||||
}
|
||||
|
||||
// disableThinkingIfToolChoiceForced checks if tool_choice forces tool use and disables thinking.
|
||||
// Anthropic API does not allow thinking when tool_choice is set to "any" or a specific tool.
|
||||
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
|
||||
func disableThinkingIfToolChoiceForced(body []byte) []byte {
|
||||
toolChoiceType := gjson.GetBytes(body, "tool_choice.type").String()
|
||||
// "auto" is allowed with thinking, but "any" or "tool" (specific tool) are not
|
||||
if toolChoiceType == "any" || toolChoiceType == "tool" {
|
||||
// Remove thinking configuration entirely to avoid API error
|
||||
body, _ = sjson.DeleteBytes(body, "thinking")
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
|
||||
// Anthropic API requires this constraint; violating it returns a 400 error.
|
||||
// This function should be called after all thinking configuration is finalized.
|
||||
|
||||
@@ -49,18 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
|
||||
@@ -146,20 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -246,20 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
|
||||
modelForCounting := req.Model
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
enc, err := tokenizerForCodexModel(modelForCounting)
|
||||
enc, err := tokenizerForCodexModel(model)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
@@ -520,3 +527,87 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
trimmed := strings.TrimSpace(alias)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entry := e.resolveCodexConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
||||
|
||||
// Candidate names to match against configured aliases/names.
|
||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
||||
candidates = append(candidates, original)
|
||||
}
|
||||
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.CodexKey {
|
||||
entry := &e.cfg.CodexKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.CodexKey {
|
||||
entry := &e.cfg.CodexKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -78,9 +78,9 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||
@@ -217,9 +217,9 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||
@@ -318,7 +318,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response, reqBody []byte, attempt string) {
|
||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
@@ -336,14 +336,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -365,12 +365,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -417,15 +417,17 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
|
||||
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
||||
// Gemini CLI endpoint when iterating fallback variants.
|
||||
for _, attemptModel := range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload)
|
||||
payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
|
||||
|
||||
tok, errTok := tokenSource.Token()
|
||||
if errTok != nil {
|
||||
|
||||
@@ -77,19 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -98,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
}
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -173,21 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -287,19 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
apiKey, bearer := geminiCreds(auth)
|
||||
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
|
||||
|
||||
requestBody := bytes.NewReader(translatedReq)
|
||||
|
||||
@@ -398,6 +410,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
|
||||
return base
|
||||
}
|
||||
|
||||
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
trimmed := strings.TrimSpace(alias)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entry := e.resolveGeminiConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
||||
|
||||
// Candidate names to match against configured aliases/names.
|
||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
||||
candidates = append(candidates, original)
|
||||
}
|
||||
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.GeminiKey {
|
||||
entry := &e.cfg.GeminiKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.GeminiKey {
|
||||
entry := &e.cfg.GeminiKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
|
||||
@@ -120,8 +120,6 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
@@ -137,7 +135,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -146,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
}
|
||||
}
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -220,24 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -250,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -321,8 +322,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -338,10 +337,10 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -438,30 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -552,8 +554,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
|
||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
@@ -566,14 +566,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
@@ -641,21 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
@@ -665,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
@@ -808,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
|
||||
}
|
||||
return tok.AccessToken, nil
|
||||
}
|
||||
|
||||
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
|
||||
// It matches the requested model alias against configured models and returns the actual upstream name.
|
||||
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
trimmed := strings.TrimSpace(alias)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entry := e.resolveVertexConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
||||
|
||||
// Candidate names to match against configured aliases/names.
|
||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
||||
candidates = append(candidates, original)
|
||||
}
|
||||
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
|
||||
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.VertexCompatAPIKey {
|
||||
entry := &e.cfg.VertexCompatAPIKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.VertexCompatAPIKey {
|
||||
entry := &e.cfg.VertexCompatAPIKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -58,15 +58,13 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
@@ -150,15 +148,13 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
@@ -445,20 +441,85 @@ func ensureToolsArray(body []byte) []byte {
|
||||
return updated
|
||||
}
|
||||
|
||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
|
||||
// preserveReasoningContentInMessages checks if reasoning_content from assistant messages
|
||||
// is preserved in conversation history for iFlow models that support thinking.
|
||||
// This is helpful for multi-turn conversations where the model may benefit from seeing
|
||||
// its previous reasoning to maintain coherent thought chains.
|
||||
//
|
||||
// For GLM-4.6/4.7 and MiniMax M2/M2.1, it is recommended to include the full assistant
|
||||
// response (including reasoning_content) in message history for better context continuity.
|
||||
func preserveReasoningContentInMessages(body []byte) []byte {
|
||||
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||
|
||||
// Only apply to models that support thinking with history preservation
|
||||
needsPreservation := strings.HasPrefix(model, "glm-4") || strings.HasPrefix(model, "minimax-m2")
|
||||
|
||||
if !needsPreservation {
|
||||
return body
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
// Check if any assistant message already has reasoning_content preserved
|
||||
hasReasoningContent := false
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
role := msg.Get("role").String()
|
||||
if role == "assistant" {
|
||||
rc := msg.Get("reasoning_content")
|
||||
if rc.Exists() && rc.String() != "" {
|
||||
hasReasoningContent = true
|
||||
return false // stop iteration
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// If reasoning content is already present, the messages are properly formatted
|
||||
// No need to modify - the client has correctly preserved reasoning in history
|
||||
if hasReasoningContent {
|
||||
log.Debugf("iflow executor: reasoning_content found in message history for %s", model)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
|
||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
||||
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking.
|
||||
//
|
||||
// Model-specific handling:
|
||||
// - GLM-4.6/4.7: Uses chat_template_kwargs.enable_thinking (boolean) and chat_template_kwargs.clear_thinking=false
|
||||
// - MiniMax M2/M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
|
||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
||||
if !effort.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||
val := strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
enableThinking := val != "none" && val != ""
|
||||
|
||||
// Remove reasoning_effort as we'll convert to model-specific format
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||
body, _ = sjson.DeleteBytes(body, "thinking")
|
||||
|
||||
// GLM-4.6/4.7: Use chat_template_kwargs
|
||||
if strings.HasPrefix(model, "glm-4") {
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||
if enableThinking {
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.clear_thinking", false)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// MiniMax M2/M2.1: Use reasoning_split
|
||||
if strings.HasPrefix(model, "minimax-m2") {
|
||||
body, _ = sjson.SetBytes(body, "reasoning_split", enableThinking)
|
||||
return body
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
|
||||
@@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
|
||||
|
||||
@@ -14,32 +14,54 @@ import (
|
||||
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
|
||||
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
|
||||
// Use the alias from metadata if available, as it's registered in the global registry
|
||||
// with thinking metadata; the upstream model name may not be registered.
|
||||
lookupModel := util.ResolveOriginalModel(model, metadata)
|
||||
|
||||
// Determine which model to use for thinking support check.
|
||||
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
|
||||
thinkingModel := lookupModel
|
||||
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
|
||||
thinkingModel = model
|
||||
}
|
||||
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
|
||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
||||
return payload
|
||||
}
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
if !util.ModelSupportsThinking(thinkingModel) {
|
||||
return payload
|
||||
}
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
|
||||
// applyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// ApplyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
|
||||
func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
|
||||
func ApplyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
|
||||
// Use the alias from metadata if available, as it's registered in the global registry
|
||||
// with thinking metadata; the upstream model name may not be registered.
|
||||
lookupModel := util.ResolveOriginalModel(model, metadata)
|
||||
|
||||
// Determine which model to use for thinking support check.
|
||||
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
|
||||
thinkingModel := lookupModel
|
||||
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
|
||||
thinkingModel = model
|
||||
}
|
||||
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
|
||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
||||
return payload
|
||||
}
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
if !util.ModelSupportsThinking(thinkingModel) {
|
||||
return payload
|
||||
}
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -52,12 +51,9 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -132,12 +128,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
|
||||
@@ -19,7 +19,7 @@ type usageReporter struct {
|
||||
provider string
|
||||
model string
|
||||
authID string
|
||||
authIndex uint64
|
||||
authIndex string
|
||||
apiKey string
|
||||
source string
|
||||
requestedAt time.Time
|
||||
@@ -482,12 +482,16 @@ func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
|
||||
cleaned := jsonBytes
|
||||
var changed bool
|
||||
|
||||
if gjson.GetBytes(cleaned, "usageMetadata").Exists() {
|
||||
if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
|
||||
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
|
||||
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
|
||||
changed = true
|
||||
}
|
||||
|
||||
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() {
|
||||
if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
|
||||
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
|
||||
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
|
||||
changed = true
|
||||
}
|
||||
|
||||
@@ -99,6 +99,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// This follows the Claude Code API specification for streaming message initialization
|
||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||
|
||||
// Use cpaUsageMetadata within the message_start event for Claude.
|
||||
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
||||
}
|
||||
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
||||
}
|
||||
|
||||
// Override default values with actual response metadata if available from the Gemini CLI response
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
|
||||
@@ -247,7 +247,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else if role == "assistant" {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
if content.Type == gjson.String {
|
||||
if content.Type == gjson.String && content.String() != "" {
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
|
||||
@@ -205,9 +205,12 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens)
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
}
|
||||
return []string{template}
|
||||
|
||||
@@ -281,8 +284,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
var messageID string
|
||||
var model string
|
||||
var createdAt int64
|
||||
var inputTokens, outputTokens int64
|
||||
var reasoningTokens int64
|
||||
var stopReason string
|
||||
var contentParts []string
|
||||
var reasoningParts []string
|
||||
@@ -299,9 +300,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
messageID = message.Get("id").String()
|
||||
model = message.Get("model").String()
|
||||
createdAt = time.Now().Unix()
|
||||
if usage := message.Get("usage"); usage.Exists() {
|
||||
inputTokens = usage.Get("input_tokens").Int()
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
@@ -364,11 +362,14 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
}
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
outputTokens = usage.Get("output_tokens").Int()
|
||||
// Estimate reasoning tokens from accumulated thinking content
|
||||
if len(reasoningParts) > 0 {
|
||||
reasoningTokens = int64(len(strings.Join(reasoningParts, "")) / 4) // Rough estimation
|
||||
}
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -427,16 +428,5 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
}
|
||||
|
||||
// Set usage information including prompt tokens, completion tokens, and total tokens
|
||||
totalTokens := inputTokens + outputTokens
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.total_tokens", totalTokens)
|
||||
|
||||
// Add reasoning tokens to usage details if any reasoning content was processed
|
||||
if reasoningTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.completion_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -114,13 +114,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
var builder strings.Builder
|
||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
text := part.Get("text").String()
|
||||
textResult := part.Get("text")
|
||||
text := textResult.String()
|
||||
if builder.Len() > 0 && text != "" {
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
builder.WriteString(text)
|
||||
return true
|
||||
})
|
||||
} else if parts.Type == gjson.String {
|
||||
builder.WriteString(parts.String())
|
||||
}
|
||||
instructionsText = builder.String()
|
||||
if instructionsText != "" {
|
||||
@@ -207,6 +210,8 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if parts.Type == gjson.String {
|
||||
textAggregate.WriteString(parts.String())
|
||||
}
|
||||
|
||||
// Fallback to given role if content types not decisive
|
||||
|
||||
@@ -56,7 +56,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
|
||||
}
|
||||
} else if systemResult.Type == gjson.String {
|
||||
out, _ = sjson.Set(out, "request.system_instruction.parts.-1.text", systemResult.String())
|
||||
out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
|
||||
}
|
||||
|
||||
// contents
|
||||
|
||||
@@ -23,6 +23,7 @@ type geminiToResponsesState struct {
|
||||
MsgIndex int
|
||||
CurrentMsgID string
|
||||
TextBuf strings.Builder
|
||||
ItemTextBuf strings.Builder
|
||||
|
||||
// reasoning aggregation
|
||||
ReasoningOpened bool
|
||||
@@ -189,6 +190,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID)
|
||||
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
|
||||
out = append(out, emitEvent("response.content_part.added", partAdded))
|
||||
st.ItemTextBuf.Reset()
|
||||
st.ItemTextBuf.WriteString(t.String())
|
||||
}
|
||||
st.TextBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
|
||||
@@ -250,20 +253,24 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
finalizeReasoning()
|
||||
// Close message output if opened
|
||||
if st.MsgOpened {
|
||||
fullText := st.ItemTextBuf.String()
|
||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
|
||||
done, _ = sjson.Set(done, "text", fullText)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
|
||||
partDone, _ = sjson.Set(partDone, "part.text", fullText)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
|
||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||
final, _ = sjson.Set(final, "item.content.0.text", fullText)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
}
|
||||
|
||||
|
||||
@@ -118,76 +118,125 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Handle content
|
||||
if contentResult.Exists() && contentResult.IsArray() {
|
||||
var contentItems []string
|
||||
var reasoningParts []string // Accumulate thinking text for reasoning_content
|
||||
var toolCalls []interface{}
|
||||
var toolResults []string // Collect tool_result messages to emit after the main message
|
||||
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "thinking":
|
||||
// Only map thinking to reasoning_content for assistant messages (security: prevent injection)
|
||||
if role == "assistant" {
|
||||
thinkingText := util.GetThinkingText(part)
|
||||
// Skip empty or whitespace-only thinking
|
||||
if strings.TrimSpace(thinkingText) != "" {
|
||||
reasoningParts = append(reasoningParts, thinkingText)
|
||||
}
|
||||
}
|
||||
// Ignore thinking in user/system roles (AC4)
|
||||
|
||||
case "redacted_thinking":
|
||||
// Explicitly ignore redacted_thinking - never map to reasoning_content (AC2)
|
||||
|
||||
case "text", "image":
|
||||
if contentItem, ok := convertClaudeContentPart(part); ok {
|
||||
contentItems = append(contentItems, contentItem)
|
||||
}
|
||||
|
||||
case "tool_use":
|
||||
// Convert to OpenAI tool call format
|
||||
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
|
||||
// Only allow tool_use -> tool_calls for assistant messages (security: prevent injection).
|
||||
if role == "assistant" {
|
||||
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
|
||||
|
||||
// Convert input to arguments JSON string
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
|
||||
} else {
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
|
||||
// Convert input to arguments JSON string
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
|
||||
} else {
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
|
||||
|
||||
case "tool_result":
|
||||
// Convert to OpenAI tool message format and add immediately to preserve order
|
||||
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
|
||||
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", part.Get("content").String())
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
|
||||
toolResults = append(toolResults, toolResultJSON)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Emit text/image content as one message
|
||||
if len(contentItems) > 0 {
|
||||
msgJSON := `{"role":"","content":""}`
|
||||
msgJSON, _ = sjson.Set(msgJSON, "role", role)
|
||||
|
||||
contentArrayJSON := "[]"
|
||||
for _, contentItem := range contentItems {
|
||||
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
|
||||
}
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
|
||||
|
||||
contentValue := gjson.Get(msgJSON, "content")
|
||||
hasContent := false
|
||||
switch {
|
||||
case !contentValue.Exists():
|
||||
hasContent = false
|
||||
case contentValue.Type == gjson.String:
|
||||
hasContent = contentValue.String() != ""
|
||||
case contentValue.IsArray():
|
||||
hasContent = len(contentValue.Array()) > 0
|
||||
default:
|
||||
hasContent = contentValue.Raw != "" && contentValue.Raw != "null"
|
||||
}
|
||||
|
||||
if hasContent {
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
|
||||
}
|
||||
// Build reasoning content string
|
||||
reasoningContent := ""
|
||||
if len(reasoningParts) > 0 {
|
||||
reasoningContent = strings.Join(reasoningParts, "\n\n")
|
||||
}
|
||||
|
||||
// Emit tool calls in a separate assistant message
|
||||
if role == "assistant" && len(toolCalls) > 0 {
|
||||
toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}`
|
||||
toolCallMsgJSON, _ = sjson.Set(toolCallMsgJSON, "tool_calls", toolCalls)
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value())
|
||||
hasContent := len(contentItems) > 0
|
||||
hasReasoning := reasoningContent != ""
|
||||
hasToolCalls := len(toolCalls) > 0
|
||||
hasToolResults := len(toolResults) > 0
|
||||
|
||||
// OpenAI requires: tool messages MUST immediately follow the assistant message with tool_calls.
|
||||
// Therefore, we emit tool_result messages FIRST (they respond to the previous assistant's tool_calls),
|
||||
// then emit the current message's content.
|
||||
for _, toolResultJSON := range toolResults {
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
|
||||
}
|
||||
|
||||
// For assistant messages: emit a single unified message with content, tool_calls, and reasoning_content
|
||||
// This avoids splitting into multiple assistant messages which breaks OpenAI tool-call adjacency
|
||||
if role == "assistant" {
|
||||
if hasContent || hasReasoning || hasToolCalls {
|
||||
msgJSON := `{"role":"assistant"}`
|
||||
|
||||
// Add content (as array if we have items, empty string if reasoning-only)
|
||||
if hasContent {
|
||||
contentArrayJSON := "[]"
|
||||
for _, contentItem := range contentItems {
|
||||
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
|
||||
}
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
|
||||
} else {
|
||||
// Ensure content field exists for OpenAI compatibility
|
||||
msgJSON, _ = sjson.Set(msgJSON, "content", "")
|
||||
}
|
||||
|
||||
// Add reasoning_content if present
|
||||
if hasReasoning {
|
||||
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
|
||||
}
|
||||
|
||||
// Add tool_calls if present (in same message as content)
|
||||
if hasToolCalls {
|
||||
msgJSON, _ = sjson.Set(msgJSON, "tool_calls", toolCalls)
|
||||
}
|
||||
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
|
||||
}
|
||||
} else {
|
||||
// For non-assistant roles: emit content message if we have content
|
||||
// If the message only contains tool_results (no text/image), we still processed them above
|
||||
if hasContent {
|
||||
msgJSON := `{"role":""}`
|
||||
msgJSON, _ = sjson.Set(msgJSON, "role", role)
|
||||
|
||||
contentArrayJSON := "[]"
|
||||
for _, contentItem := range contentItems {
|
||||
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
|
||||
}
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
|
||||
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
|
||||
} else if hasToolResults && !hasContent {
|
||||
// tool_results already emitted above, no additional user message needed
|
||||
}
|
||||
}
|
||||
|
||||
} else if contentResult.Exists() && contentResult.Type == gjson.String {
|
||||
@@ -307,3 +356,43 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func convertClaudeToolResultContentToString(content gjson.Result) string {
|
||||
if !content.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
|
||||
if content.IsArray() {
|
||||
var parts []string
|
||||
content.ForEach(func(_, item gjson.Result) bool {
|
||||
switch {
|
||||
case item.Type == gjson.String:
|
||||
parts = append(parts, item.String())
|
||||
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
|
||||
parts = append(parts, item.Get("text").String())
|
||||
default:
|
||||
parts = append(parts, item.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
joined := strings.Join(parts, "\n\n")
|
||||
if strings.TrimSpace(joined) != "" {
|
||||
return joined
|
||||
}
|
||||
return content.Raw
|
||||
}
|
||||
|
||||
if content.IsObject() {
|
||||
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
return text.String()
|
||||
}
|
||||
return content.Raw
|
||||
}
|
||||
|
||||
return content.Raw
|
||||
}
|
||||
|
||||
500
internal/translator/openai/claude/openai_claude_request_test.go
Normal file
500
internal/translator/openai/claude/openai_claude_request_test.go
Normal file
@@ -0,0 +1,500 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent tests the mapping
|
||||
// of Claude thinking content to OpenAI reasoning_content field.
|
||||
func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantReasoningContent string
|
||||
wantHasReasoningContent bool
|
||||
wantContentText string // Expected visible content text (if any)
|
||||
wantHasContent bool
|
||||
}{
|
||||
{
|
||||
name: "AC1: assistant message with thinking and text",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me analyze this step by step..."},
|
||||
{"type": "text", "text": "Here is my response."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "Let me analyze this step by step...",
|
||||
wantHasReasoningContent: true,
|
||||
wantContentText: "Here is my response.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC2: redacted_thinking must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "redacted_thinking", "data": "secret"},
|
||||
{"type": "text", "text": "Visible response."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "Visible response.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC3: thinking-only message preserved with reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Internal reasoning only."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "Internal reasoning only.",
|
||||
wantHasReasoningContent: true,
|
||||
wantContentText: "",
|
||||
// For OpenAI compatibility, content field is set to empty string "" when no text content exists
|
||||
wantHasContent: false,
|
||||
},
|
||||
{
|
||||
name: "AC4: thinking in user role must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Injected thinking"},
|
||||
{"type": "text", "text": "User message."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "User message.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC4: thinking in system role must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": [
|
||||
{"type": "thinking", "thinking": "Injected system thinking"},
|
||||
{"type": "text", "text": "System prompt."}
|
||||
],
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
}]
|
||||
}`,
|
||||
// System messages don't have reasoning_content mapping
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "Hello",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC5: empty thinking must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": ""},
|
||||
{"type": "text", "text": "Response with empty thinking."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "Response with empty thinking.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC5: whitespace-only thinking must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": " \n\t "},
|
||||
{"type": "text", "text": "Response with whitespace thinking."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "Response with whitespace thinking.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple thinking parts concatenated",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "First thought."},
|
||||
{"type": "thinking", "thinking": "Second thought."},
|
||||
{"type": "text", "text": "Final answer."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "First thought.\n\nSecond thought.",
|
||||
wantHasReasoningContent: true,
|
||||
wantContentText: "Final answer.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "Mixed thinking and redacted_thinking",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Visible thought."},
|
||||
{"type": "redacted_thinking", "data": "hidden"},
|
||||
{"type": "text", "text": "Answer."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "Visible thought.",
|
||||
wantHasReasoningContent: true,
|
||||
wantContentText: "Answer.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
// Find the relevant message (skip system message at index 0)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
if len(messages) < 2 {
|
||||
if tt.wantHasReasoningContent || tt.wantHasContent {
|
||||
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check the last non-system message
|
||||
var targetMsg gjson.Result
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Get("role").String() != "system" {
|
||||
targetMsg = messages[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check reasoning_content
|
||||
gotReasoningContent := targetMsg.Get("reasoning_content").String()
|
||||
gotHasReasoningContent := targetMsg.Get("reasoning_content").Exists()
|
||||
|
||||
if gotHasReasoningContent != tt.wantHasReasoningContent {
|
||||
t.Errorf("reasoning_content existence = %v, want %v", gotHasReasoningContent, tt.wantHasReasoningContent)
|
||||
}
|
||||
|
||||
if gotReasoningContent != tt.wantReasoningContent {
|
||||
t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
|
||||
}
|
||||
|
||||
// Check content
|
||||
content := targetMsg.Get("content")
|
||||
// content has meaningful content if it's a non-empty array, or a non-empty string
|
||||
var gotHasContent bool
|
||||
switch {
|
||||
case content.IsArray():
|
||||
gotHasContent = len(content.Array()) > 0
|
||||
case content.Type == gjson.String:
|
||||
gotHasContent = content.String() != ""
|
||||
default:
|
||||
gotHasContent = false
|
||||
}
|
||||
|
||||
if gotHasContent != tt.wantHasContent {
|
||||
t.Errorf("content existence = %v, want %v", gotHasContent, tt.wantHasContent)
|
||||
}
|
||||
|
||||
if tt.wantHasContent && tt.wantContentText != "" {
|
||||
// Find text content
|
||||
var foundText string
|
||||
content.ForEach(func(_, v gjson.Result) bool {
|
||||
if v.Get("type").String() == "text" {
|
||||
foundText = v.Get("text").String()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if foundText != tt.wantContentText {
|
||||
t.Errorf("content text = %q, want %q", foundText, tt.wantContentText)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved tests AC3:
|
||||
// that a message with only thinking content is preserved (not dropped).
|
||||
func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "What is 2+2?"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "thinking", "thinking": "Let me calculate: 2+2=4"}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Thanks"}]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
// Check the assistant message (index 2) has reasoning_content
|
||||
assistantMsg := messages[2]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
if !assistantMsg.Get("reasoning_content").Exists() {
|
||||
t.Error("Expected assistant message to have reasoning_content")
|
||||
}
|
||||
|
||||
if assistantMsg.Get("reasoning_content").String() != "Let me calculate: 2+2=4" {
|
||||
t.Errorf("Unexpected reasoning_content: %s", assistantMsg.Get("reasoning_content").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "before"},
|
||||
{"type": "tool_result", "tool_use_id": "call_1", "content": [{"type":"text","text":"tool ok"}]},
|
||||
{"type": "text", "text": "after"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
|
||||
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after)
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[0].Get("role").String() != "system" {
|
||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
||||
}
|
||||
|
||||
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
|
||||
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
|
||||
}
|
||||
|
||||
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
|
||||
if messages[2].Get("role").String() != "tool" {
|
||||
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String())
|
||||
}
|
||||
if got := messages[2].Get("tool_call_id").String(); got != "call_1" {
|
||||
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
|
||||
}
|
||||
if got := messages[2].Get("content").String(); got != "tool ok" {
|
||||
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
|
||||
}
|
||||
|
||||
// User message comes after tool message
|
||||
if messages[3].Get("role").String() != "user" {
|
||||
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String())
|
||||
}
|
||||
// User message should contain both "before" and "after" text
|
||||
if got := messages[3].Get("content.0.text").String(); got != "before" {
|
||||
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
|
||||
}
|
||||
if got := messages[3].Get("content.1.text").String(); got != "after" {
|
||||
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_1", "content": {"foo": "bar"}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// system + assistant(tool_calls) + tool(result)
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[2].Get("role").String() != "tool" {
|
||||
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String())
|
||||
}
|
||||
|
||||
toolContent := messages[2].Get("content").String()
|
||||
parsed := gjson.Parse(toolContent)
|
||||
if parsed.Get("foo").String() != "bar" {
|
||||
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "pre"},
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
|
||||
{"type": "text", "text": "post"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// New behavior: content + tool_calls unified in single assistant message
|
||||
// Expect: system + assistant(content[pre,post] + tool_calls)
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[0].Get("role").String() != "system" {
|
||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
||||
}
|
||||
|
||||
assistantMsg := messages[1]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
// Should have both content and tool_calls in same message
|
||||
if !assistantMsg.Get("tool_calls").Exists() {
|
||||
t.Fatalf("Expected assistant message to have tool_calls")
|
||||
}
|
||||
if got := assistantMsg.Get("tool_calls.0.id").String(); got != "call_1" {
|
||||
t.Fatalf("Expected tool_call id %q, got %q", "call_1", got)
|
||||
}
|
||||
if got := assistantMsg.Get("tool_calls.0.function.name").String(); got != "do_work" {
|
||||
t.Fatalf("Expected tool_call name %q, got %q", "do_work", got)
|
||||
}
|
||||
|
||||
// Content should have both pre and post text
|
||||
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
|
||||
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
|
||||
}
|
||||
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
|
||||
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "t1"},
|
||||
{"type": "text", "text": "pre"},
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
|
||||
{"type": "thinking", "thinking": "t2"},
|
||||
{"type": "text", "text": "post"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// New behavior: all content, thinking, and tool_calls unified in single assistant message
|
||||
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
assistantMsg := messages[1]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
// Should have content with both pre and post
|
||||
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
|
||||
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
|
||||
}
|
||||
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
|
||||
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
|
||||
}
|
||||
|
||||
// Should have tool_calls
|
||||
if !assistantMsg.Get("tool_calls").Exists() {
|
||||
t.Fatalf("Expected assistant message to have tool_calls")
|
||||
}
|
||||
|
||||
// Should have combined reasoning_content from both thinking blocks
|
||||
if got := assistantMsg.Get("reasoning_content").String(); got != "t1\n\nt2" {
|
||||
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
|
||||
}
|
||||
}
|
||||
@@ -480,15 +480,15 @@ func collectOpenAIReasoningTexts(node gjson.Result) []string {
|
||||
|
||||
switch node.Type {
|
||||
case gjson.String:
|
||||
if text := strings.TrimSpace(node.String()); text != "" {
|
||||
if text := node.String(); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
case gjson.JSON:
|
||||
if text := node.Get("text"); text.Exists() {
|
||||
if trimmed := strings.TrimSpace(text.String()); trimmed != "" {
|
||||
texts = append(texts, trimmed)
|
||||
if textStr := text.String(); textStr != "" {
|
||||
texts = append(texts, textStr)
|
||||
}
|
||||
} else if raw := strings.TrimSpace(node.Raw); raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
|
||||
} else if raw := node.Raw; raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
|
||||
texts = append(texts, raw)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package usage
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -90,7 +91,7 @@ type modelStats struct {
|
||||
type RequestDetail struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Source string `json:"source"`
|
||||
AuthIndex uint64 `json:"auth_index"`
|
||||
AuthIndex string `json:"auth_index"`
|
||||
Tokens TokenStats `json:"tokens"`
|
||||
Failed bool `json:"failed"`
|
||||
}
|
||||
@@ -281,6 +282,118 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
|
||||
return result
|
||||
}
|
||||
|
||||
type MergeResult struct {
|
||||
Added int64 `json:"added"`
|
||||
Skipped int64 `json:"skipped"`
|
||||
}
|
||||
|
||||
// MergeSnapshot merges an exported statistics snapshot into the current store.
|
||||
// Existing data is preserved and duplicate request details are skipped.
|
||||
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
|
||||
result := MergeResult{}
|
||||
if s == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
for apiName, stats := range s.apis {
|
||||
if stats == nil {
|
||||
continue
|
||||
}
|
||||
for modelName, modelStatsValue := range stats.Models {
|
||||
if modelStatsValue == nil {
|
||||
continue
|
||||
}
|
||||
for _, detail := range modelStatsValue.Details {
|
||||
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for apiName, apiSnapshot := range snapshot.APIs {
|
||||
apiName = strings.TrimSpace(apiName)
|
||||
if apiName == "" {
|
||||
continue
|
||||
}
|
||||
stats, ok := s.apis[apiName]
|
||||
if !ok || stats == nil {
|
||||
stats = &apiStats{Models: make(map[string]*modelStats)}
|
||||
s.apis[apiName] = stats
|
||||
} else if stats.Models == nil {
|
||||
stats.Models = make(map[string]*modelStats)
|
||||
}
|
||||
for modelName, modelSnapshot := range apiSnapshot.Models {
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if modelName == "" {
|
||||
modelName = "unknown"
|
||||
}
|
||||
for _, detail := range modelSnapshot.Details {
|
||||
detail.Tokens = normaliseTokenStats(detail.Tokens)
|
||||
if detail.Timestamp.IsZero() {
|
||||
detail.Timestamp = time.Now()
|
||||
}
|
||||
key := dedupKey(apiName, modelName, detail)
|
||||
if _, exists := seen[key]; exists {
|
||||
result.Skipped++
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
s.recordImported(apiName, modelName, stats, detail)
|
||||
result.Added++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
|
||||
totalTokens := detail.Tokens.TotalTokens
|
||||
if totalTokens < 0 {
|
||||
totalTokens = 0
|
||||
}
|
||||
|
||||
s.totalRequests++
|
||||
if detail.Failed {
|
||||
s.failureCount++
|
||||
} else {
|
||||
s.successCount++
|
||||
}
|
||||
s.totalTokens += totalTokens
|
||||
|
||||
s.updateAPIStats(stats, modelName, detail)
|
||||
|
||||
dayKey := detail.Timestamp.Format("2006-01-02")
|
||||
hourKey := detail.Timestamp.Hour()
|
||||
|
||||
s.requestsByDay[dayKey]++
|
||||
s.requestsByHour[hourKey]++
|
||||
s.tokensByDay[dayKey] += totalTokens
|
||||
s.tokensByHour[hourKey] += totalTokens
|
||||
}
|
||||
|
||||
func dedupKey(apiName, modelName string, detail RequestDetail) string {
|
||||
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
|
||||
tokens := normaliseTokenStats(detail.Tokens)
|
||||
return fmt.Sprintf(
|
||||
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
|
||||
apiName,
|
||||
modelName,
|
||||
timestamp,
|
||||
detail.Source,
|
||||
detail.AuthIndex,
|
||||
detail.Failed,
|
||||
tokens.InputTokens,
|
||||
tokens.OutputTokens,
|
||||
tokens.ReasoningTokens,
|
||||
tokens.CachedTokens,
|
||||
tokens.TotalTokens,
|
||||
)
|
||||
}
|
||||
|
||||
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
|
||||
if ctx != nil {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
|
||||
@@ -340,6 +453,16 @@ func normaliseDetail(detail coreusage.Detail) TokenStats {
|
||||
return tokens
|
||||
}
|
||||
|
||||
func normaliseTokenStats(tokens TokenStats) TokenStats {
|
||||
if tokens.TotalTokens == 0 {
|
||||
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
|
||||
}
|
||||
if tokens.TotalTokens == 0 {
|
||||
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func formatHour(hour int) string {
|
||||
if hour < 0 {
|
||||
hour = 0
|
||||
|
||||
@@ -344,7 +344,7 @@ func cleanupRequiredFields(jsonStr string) string {
|
||||
}
|
||||
|
||||
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
|
||||
// Claude VALIDATED mode requires at least one property in tool schemas.
|
||||
// Claude VALIDATED mode requires at least one required property in tool schemas.
|
||||
func addEmptySchemaPlaceholder(jsonStr string) string {
|
||||
// Find all "type" fields
|
||||
paths := findPaths(jsonStr, "type")
|
||||
@@ -364,6 +364,9 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
|
||||
// Check if properties exists and is empty or missing
|
||||
propsPath := joinPath(parentPath, "properties")
|
||||
propsVal := gjson.Get(jsonStr, propsPath)
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
reqVal := gjson.Get(jsonStr, reqPath)
|
||||
hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
|
||||
|
||||
needsPlaceholder := false
|
||||
if !propsVal.Exists() {
|
||||
@@ -381,8 +384,22 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
|
||||
|
||||
// Add to required array
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
|
||||
continue
|
||||
}
|
||||
|
||||
// If schema has properties but none are required, add a minimal placeholder.
|
||||
if propsVal.IsObject() && !hasRequiredProperties {
|
||||
// DO NOT add placeholder if it's a top-level schema (parentPath is empty)
|
||||
// or if we've already added a placeholder reason above.
|
||||
if parentPath == "" {
|
||||
continue
|
||||
}
|
||||
placeholderPath := joinPath(propsPath, "_")
|
||||
if !gjson.Get(jsonStr, placeholderPath).Exists() {
|
||||
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
|
||||
}
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -127,8 +127,10 @@ func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing
|
||||
"type": "object",
|
||||
"description": "Accepts: null | object",
|
||||
"properties": {
|
||||
"_": { "type": "boolean" },
|
||||
"kind": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["_"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
@@ -614,71 +616,6 @@ func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
|
||||
// propertyNames is used to validate object property names (e.g., must match a pattern)
|
||||
// Gemini doesn't support this keyword and will reject requests containing it
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
|
||||
},
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
|
||||
// Verify propertyNames is completely removed
|
||||
if strings.Contains(result, "propertyNames") {
|
||||
t.Errorf("propertyNames keyword should be removed, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
|
||||
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"propertyNames": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if strings.Contains(result, "propertyNames") {
|
||||
t.Errorf("Nested propertyNames should be removed, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
||||
var expMap, actMap map[string]interface{}
|
||||
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
|
||||
|
||||
@@ -288,37 +288,73 @@ func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadata applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal))
|
||||
// or numeric budget suffix (e.g., model(1000)) which gets converted to a thinkingLevel.
|
||||
func ApplyGemini3ThinkingLevelFromMetadata(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
// Use the alias from metadata if available for model type detection
|
||||
lookupModel := ResolveOriginalModel(model, metadata)
|
||||
if !IsGemini3Model(lookupModel) && !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
|
||||
// Determine which model to use for validation
|
||||
checkModel := model
|
||||
if IsGemini3Model(lookupModel) {
|
||||
checkModel = lookupModel
|
||||
}
|
||||
|
||||
// First try to get effort string from metadata
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
if ok && effort != "" {
|
||||
if level, valid := ValidateGemini3ThinkingLevel(checkModel, effort); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
}
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
|
||||
// Fallback: check for numeric budget and convert to thinkingLevel
|
||||
budget, _, _, matched := ThinkingFromMetadata(metadata)
|
||||
if matched && budget != nil {
|
||||
if level, valid := ThinkingBudgetToGemini3Level(checkModel, *budget); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadataCLI applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal))
|
||||
// or numeric budget suffix (e.g., model(1000)) which gets converted to a thinkingLevel.
|
||||
func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
// Use the alias from metadata if available for model type detection
|
||||
lookupModel := ResolveOriginalModel(model, metadata)
|
||||
if !IsGemini3Model(lookupModel) && !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
|
||||
// Determine which model to use for validation
|
||||
checkModel := model
|
||||
if IsGemini3Model(lookupModel) {
|
||||
checkModel = lookupModel
|
||||
}
|
||||
|
||||
// First try to get effort string from metadata
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
if ok && effort != "" {
|
||||
if level, valid := ValidateGemini3ThinkingLevel(checkModel, effort); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
}
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
|
||||
// Fallback: check for numeric budget and convert to thinkingLevel
|
||||
budget, _, _, matched := ThinkingFromMetadata(metadata)
|
||||
if matched && budget != nil {
|
||||
if level, valid := ThinkingBudgetToGemini3Level(checkModel, *budget); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -326,15 +362,17 @@ func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
|
||||
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||
if !ModelHasDefaultThinking(model) {
|
||||
func ApplyDefaultThinkingIfNeededCLI(model string, metadata map[string]any, body []byte) []byte {
|
||||
// Use the alias from metadata if available for model property lookup
|
||||
lookupModel := ResolveOriginalModel(model, metadata)
|
||||
if !ModelHasDefaultThinking(lookupModel) && !ModelHasDefaultThinking(model) {
|
||||
return body
|
||||
}
|
||||
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
|
||||
return body
|
||||
}
|
||||
// Gemini 3 models use thinkingLevel instead of thinkingBudget
|
||||
if IsGemini3Model(model) {
|
||||
if IsGemini3Model(lookupModel) || IsGemini3Model(model) {
|
||||
// Don't set a default - let the API use its dynamic default ("high")
|
||||
// Only set includeThoughts
|
||||
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
56
internal/util/sanitize_test.go
Normal file
56
internal/util/sanitize_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeFunctionName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Normal", "valid_name", "valid_name"},
|
||||
{"With Dots", "name.with.dots", "name.with.dots"},
|
||||
{"With Colons", "name:with:colons", "name:with:colons"},
|
||||
{"With Dashes", "name-with-dashes", "name-with-dashes"},
|
||||
{"Mixed Allowed", "name.with_dots:colons-dashes", "name.with_dots:colons-dashes"},
|
||||
{"Invalid Characters", "name!with@invalid#chars", "name_with_invalid_chars"},
|
||||
{"Spaces", "name with spaces", "name_with_spaces"},
|
||||
{"Non-ASCII", "name_with_你好_chars", "name_with____chars"},
|
||||
{"Starts with digit", "123name", "_123name"},
|
||||
{"Starts with dot", ".name", "_.name"},
|
||||
{"Starts with colon", ":name", "_:name"},
|
||||
{"Starts with dash", "-name", "_-name"},
|
||||
{"Starts with invalid char", "!name", "_name"},
|
||||
{"Exactly 64 chars", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
|
||||
{"Too long (65 chars)", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charactX", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
|
||||
{"Very long", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_limit_for_function_names", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_l"},
|
||||
{"Starts with digit (64 chars total)", "1234567890123456789012345678901234567890123456789012345678901234", "_123456789012345678901234567890123456789012345678901234567890123"},
|
||||
{"Starts with invalid char (64 chars total)", "!234567890123456789012345678901234567890123456789012345678901234", "_234567890123456789012345678901234567890123456789012345678901234"},
|
||||
{"Empty", "", ""},
|
||||
{"Single character invalid", "@", "_"},
|
||||
{"Single character valid", "a", "a"},
|
||||
{"Single character digit", "1", "_1"},
|
||||
{"Single character underscore", "_", "_"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SanitizeFunctionName(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("SanitizeFunctionName(%q) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
// Verify Gemini compliance
|
||||
if len(got) > 64 {
|
||||
t.Errorf("SanitizeFunctionName(%q) result too long: %d", tt.input, len(got))
|
||||
}
|
||||
if len(got) > 0 {
|
||||
first := got[0]
|
||||
if !((first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_') {
|
||||
t.Errorf("SanitizeFunctionName(%q) result starts with invalid char: %c", tt.input, first)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,9 +12,18 @@ func ModelSupportsThinking(model string) bool {
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
// First check the global dynamic registry
|
||||
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil {
|
||||
return info.Thinking != nil
|
||||
}
|
||||
// Fallback: check static model definitions
|
||||
if info := registry.LookupStaticModelInfo(model); info != nil {
|
||||
return info.Thinking != nil
|
||||
}
|
||||
// Fallback: check Antigravity static config
|
||||
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil {
|
||||
return cfg.Thinking != nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -63,11 +72,19 @@ func thinkingRangeFromRegistry(model string) (found bool, min int, max int, zero
|
||||
if model == "" {
|
||||
return false, 0, 0, false, false
|
||||
}
|
||||
info := registry.GetGlobalRegistry().GetModelInfo(model)
|
||||
if info == nil || info.Thinking == nil {
|
||||
return false, 0, 0, false, false
|
||||
// First check global dynamic registry
|
||||
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil && info.Thinking != nil {
|
||||
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
|
||||
}
|
||||
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
|
||||
// Fallback: check static model definitions
|
||||
if info := registry.LookupStaticModelInfo(model); info != nil && info.Thinking != nil {
|
||||
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
|
||||
}
|
||||
// Fallback: check Antigravity static config
|
||||
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil && cfg.Thinking != nil {
|
||||
return true, cfg.Thinking.Min, cfg.Thinking.Max, cfg.Thinking.ZeroAllowed, cfg.Thinking.DynamicAllowed
|
||||
}
|
||||
return false, 0, 0, false, false
|
||||
}
|
||||
|
||||
// GetModelThinkingLevels returns the discrete reasoning effort levels for the model.
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ThinkingBudgetMetadataKey = "thinking_budget"
|
||||
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
|
||||
ReasoningEffortMetadataKey = "reasoning_effort"
|
||||
ThinkingOriginalModelMetadataKey = "thinking_original_model"
|
||||
ThinkingBudgetMetadataKey = "thinking_budget"
|
||||
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
|
||||
ReasoningEffortMetadataKey = "reasoning_effort"
|
||||
ThinkingOriginalModelMetadataKey = "thinking_original_model"
|
||||
ModelMappingOriginalModelMetadataKey = "model_mapping_original_model"
|
||||
)
|
||||
|
||||
// NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns
|
||||
@@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string {
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok {
|
||||
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
|
||||
if base := normalize(s); base != "" {
|
||||
return base
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok {
|
||||
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
|
||||
if base := normalize(s); base != "" {
|
||||
|
||||
@@ -8,12 +8,52 @@ import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var functionNameSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)
|
||||
|
||||
// SanitizeFunctionName ensures a function name matches the requirements for Gemini/Vertex AI.
|
||||
// It replaces invalid characters with underscores, ensures it starts with a letter or underscore,
|
||||
// and truncates it to 64 characters if necessary.
|
||||
// Regex Rule: [^a-zA-Z0-9_.:-] replaced with _.
|
||||
func SanitizeFunctionName(name string) string {
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Replace invalid characters with underscore
|
||||
sanitized := functionNameSanitizer.ReplaceAllString(name, "_")
|
||||
|
||||
// Ensure it starts with a letter or underscore
|
||||
// Re-reading requirements: Must start with a letter or an underscore.
|
||||
if len(sanitized) > 0 {
|
||||
first := sanitized[0]
|
||||
if !((first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_') {
|
||||
// If it starts with an allowed character but not allowed at the beginning (digit, dot, colon, dash),
|
||||
// we must prepend an underscore.
|
||||
|
||||
// To stay within the 64-character limit while prepending, we must truncate first.
|
||||
if len(sanitized) >= 64 {
|
||||
sanitized = sanitized[:63]
|
||||
}
|
||||
sanitized = "_" + sanitized
|
||||
}
|
||||
} else {
|
||||
sanitized = "_"
|
||||
}
|
||||
|
||||
// Truncate to 64 characters
|
||||
if len(sanitized) > 64 {
|
||||
sanitized = sanitized[:64]
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// SetLogLevel configures the logrus log level based on the configuration.
|
||||
// It sets the log level to DebugLevel if debug mode is enabled, otherwise to InfoLevel.
|
||||
func SetLogLevel(cfg *config.Config) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool {
|
||||
}
|
||||
|
||||
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
||||
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
|
||||
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelMappings, newConfig.OAuthModelMappings))
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
||||
|
||||
@@ -90,6 +90,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldModels := SummarizeGeminiModels(o.Models)
|
||||
newModels := SummarizeGeminiModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
@@ -120,6 +125,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldModels := SummarizeClaudeModels(o.Models)
|
||||
newModels := SummarizeClaudeModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
@@ -150,6 +160,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldModels := SummarizeCodexModels(o.Models)
|
||||
newModels := SummarizeCodexModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
@@ -185,10 +200,18 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
||||
}
|
||||
oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys)
|
||||
newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys)
|
||||
if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount))
|
||||
}
|
||||
|
||||
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
if entries, _ := DiffOAuthModelMappingChanges(oldCfg.OAuthModelMappings, newCfg.OAuthModelMappings); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
@@ -301,3 +324,43 @@ func formatProxyURL(raw string) string {
|
||||
}
|
||||
return scheme + "://" + host
|
||||
}
|
||||
|
||||
func equalStringSet(a, b []string) bool {
|
||||
if len(a) == 0 && len(b) == 0 {
|
||||
return true
|
||||
}
|
||||
aSet := make(map[string]struct{}, len(a))
|
||||
for _, k := range a {
|
||||
aSet[strings.TrimSpace(k)] = struct{}{}
|
||||
}
|
||||
bSet := make(map[string]struct{}, len(b))
|
||||
for _, k := range b {
|
||||
bSet[strings.TrimSpace(k)] = struct{}{}
|
||||
}
|
||||
if len(aSet) != len(bSet) {
|
||||
return false
|
||||
}
|
||||
for k := range aSet {
|
||||
if _, ok := bSet[k]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// equalUpstreamAPIKeys compares two slices of AmpUpstreamAPIKeyEntry for equality.
|
||||
// Comparison is done by count and content (upstream key and client keys).
|
||||
func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if strings.TrimSpace(a[i].UpstreamAPIKey) != strings.TrimSpace(b[i].UpstreamAPIKey) {
|
||||
return false
|
||||
}
|
||||
if !equalStringSet(a[i].APIKeys, b[i].APIKeys) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -56,6 +56,36 @@ func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeCodexModelsHash returns a stable hash for Codex model aliases.
|
||||
func ComputeCodexModelsHash(models []config.CodexModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeGeminiModelsHash returns a stable hash for Gemini model aliases.
|
||||
func ComputeGeminiModelsHash(models []config.GeminiModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
|
||||
func ComputeExcludedModelsHash(excluded []string) string {
|
||||
if len(excluded) == 0 {
|
||||
|
||||
@@ -81,6 +81,15 @@ func TestComputeClaudeModelsHash_Empty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeCodexModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeCodexModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil models, got %q", got)
|
||||
}
|
||||
if got := ComputeCodexModelsHash([]config.CodexModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
|
||||
a := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
@@ -95,6 +104,20 @@ func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeCodexModelsHash_IgnoresBlankAndDedup(t *testing.T) {
|
||||
a := []config.CodexModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.CodexModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeCodexModelsHash(a), ComputeCodexModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
|
||||
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
|
||||
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
|
||||
@@ -157,3 +180,15 @@ func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
|
||||
t.Fatalf("expected different hash when models change, got %s", h3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeCodexModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.CodexModel{{Name: "a", Alias: "A"}, {Name: "b"}}
|
||||
h1 := ComputeCodexModelsHash(models)
|
||||
h2 := ComputeCodexModelsHash(models)
|
||||
if h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected deterministic hash, got %s / %s", h1, h2)
|
||||
}
|
||||
if h3 := ComputeCodexModelsHash([]config.CodexModel{{Name: "a"}}); h3 == h1 {
|
||||
t.Fatalf("expected different hash when models change, got %s", h3)
|
||||
}
|
||||
}
|
||||
|
||||
121
internal/watcher/diff/models_summary.go
Normal file
121
internal/watcher/diff/models_summary.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
type GeminiModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
type ClaudeModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
type CodexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
type VertexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeGeminiModels hashes Gemini model aliases for change detection.
|
||||
func SummarizeGeminiModels(models []config.GeminiModel) GeminiModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return GeminiModelsSummary{}
|
||||
}
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return GeminiModelsSummary{
|
||||
hash: hashJoined(keys),
|
||||
count: len(keys),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeClaudeModels hashes Claude model aliases for change detection.
|
||||
func SummarizeClaudeModels(models []config.ClaudeModel) ClaudeModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return ClaudeModelsSummary{}
|
||||
}
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return ClaudeModelsSummary{
|
||||
hash: hashJoined(keys),
|
||||
count: len(keys),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeCodexModels hashes Codex model aliases for change detection.
|
||||
func SummarizeCodexModels(models []config.CodexModel) CodexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return CodexModelsSummary{}
|
||||
}
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return CodexModelsSummary{
|
||||
hash: hashJoined(keys),
|
||||
count: len(keys),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeVertexModels hashes Vertex-compatible model aliases for change detection.
|
||||
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
names := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" {
|
||||
name = alias
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
sort.Strings(names)
|
||||
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
|
||||
return VertexModelsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(names),
|
||||
}
|
||||
}
|
||||
@@ -116,36 +116,3 @@ func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappin
|
||||
count: len(entries),
|
||||
}
|
||||
}
|
||||
|
||||
type VertexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeVertexModels hashes vertex-compatible models for change detection.
|
||||
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
names := make([]string, 0, len(models))
|
||||
for _, m := range models {
|
||||
name := strings.TrimSpace(m.Name)
|
||||
alias := strings.TrimSpace(m.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" {
|
||||
name = alias
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
sort.Strings(names)
|
||||
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
|
||||
return VertexModelsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(names),
|
||||
}
|
||||
}
|
||||
|
||||
98
internal/watcher/diff/oauth_model_mappings.go
Normal file
98
internal/watcher/diff/oauth_model_mappings.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
type OAuthModelMappingsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeOAuthModelMappings summarizes OAuth model mappings per channel.
|
||||
func SummarizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string]OAuthModelMappingsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]OAuthModelMappingsSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = summarizeOAuthModelMappingList(v)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DiffOAuthModelMappingChanges compares OAuth model mappings maps.
|
||||
func DiffOAuthModelMappingChanges(oldMap, newMap map[string][]config.ModelNameMapping) ([]string, []string) {
|
||||
oldSummary := SummarizeOAuthModelMappings(oldMap)
|
||||
newSummary := SummarizeOAuthModelMappings(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMappingsSummary {
|
||||
if len(list) == 0 {
|
||||
return OAuthModelMappingsSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, mapping := range list {
|
||||
name := strings.ToLower(strings.TrimSpace(mapping.Name))
|
||||
alias := strings.ToLower(strings.TrimSpace(mapping.Alias))
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
key := name + "->" + alias
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
normalized = append(normalized, key)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return OAuthModelMappingsSummary{}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
sum := sha256.Sum256([]byte(strings.Join(normalized, "|")))
|
||||
return OAuthModelMappingsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
@@ -62,6 +62,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
if hash := diff.ComputeGeminiModelsHash(entry.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(entry.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
@@ -147,6 +150,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
||||
if ck.BaseURL != "" {
|
||||
attrs["base_url"] = ck.BaseURL
|
||||
}
|
||||
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(ck.Headers, attrs)
|
||||
proxyURL := strings.TrimSpace(ck.ProxyURL)
|
||||
a := &coreauth.Auth{
|
||||
|
||||
@@ -104,8 +104,8 @@ func BuildErrorResponseBody(status int, errText string) []byte {
|
||||
// Returning 0 disables keep-alives (default when unset).
|
||||
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||
seconds := defaultStreamingKeepAliveSeconds
|
||||
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil {
|
||||
seconds = *cfg.Streaming.KeepAliveSeconds
|
||||
if cfg != nil {
|
||||
seconds = cfg.Streaming.KeepAliveSeconds
|
||||
}
|
||||
if seconds <= 0 {
|
||||
return 0
|
||||
@@ -116,8 +116,8 @@ func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
||||
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||
retries := defaultStreamingBootstrapRetries
|
||||
if cfg != nil && cfg.Streaming.BootstrapRetries != nil {
|
||||
retries = *cfg.Streaming.BootstrapRetries
|
||||
if cfg != nil {
|
||||
retries = cfg.Streaming.BootstrapRetries
|
||||
}
|
||||
if retries < 0 {
|
||||
retries = 0
|
||||
@@ -618,7 +618,22 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
||||
}
|
||||
|
||||
body := BuildErrorResponseBody(status, errText)
|
||||
c.Set("API_RESPONSE", bytes.Clone(body))
|
||||
// Append first to preserve upstream response logs, then drop duplicate payloads if already recorded.
|
||||
var previous []byte
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
previous = bytes.Clone(existingBytes)
|
||||
}
|
||||
}
|
||||
appendAPIResponse(c, body)
|
||||
trimmedErrText := strings.TrimSpace(errText)
|
||||
trimmedBody := bytes.TrimSpace(body)
|
||||
if len(previous) > 0 {
|
||||
if (trimmedErrText != "" && bytes.Contains(previous, []byte(trimmedErrText))) ||
|
||||
(len(trimmedBody) > 0 && bytes.Contains(previous, trimmedBody)) {
|
||||
c.Set("API_RESPONSE", previous)
|
||||
}
|
||||
}
|
||||
|
||||
if !c.Writer.Written() {
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@@ -94,10 +94,9 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
bootstrapRetries := 1
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: &bootstrapRetries,
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
|
||||
67
sdk/api/management.go
Normal file
67
sdk/api/management.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Package api exposes helpers for embedding CLIProxyAPI.
|
||||
//
|
||||
// It wraps internal management handler types so external projects can integrate
|
||||
// management endpoints without importing internal packages.
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
internalmanagement "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
// ManagementTokenRequester exposes a limited subset of management endpoints for requesting tokens.
|
||||
type ManagementTokenRequester interface {
|
||||
RequestAnthropicToken(*gin.Context)
|
||||
RequestGeminiCLIToken(*gin.Context)
|
||||
RequestCodexToken(*gin.Context)
|
||||
RequestAntigravityToken(*gin.Context)
|
||||
RequestQwenToken(*gin.Context)
|
||||
RequestIFlowToken(*gin.Context)
|
||||
RequestIFlowCookieToken(*gin.Context)
|
||||
GetAuthStatus(c *gin.Context)
|
||||
}
|
||||
|
||||
type managementTokenRequester struct {
|
||||
handler *internalmanagement.Handler
|
||||
}
|
||||
|
||||
// NewManagementTokenRequester creates a limited management handler exposing only token request endpoints.
|
||||
func NewManagementTokenRequester(cfg *config.Config, manager *coreauth.Manager) ManagementTokenRequester {
|
||||
return &managementTokenRequester{
|
||||
handler: internalmanagement.NewHandlerWithoutConfigFilePath(cfg, manager),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestAnthropicToken(c *gin.Context) {
|
||||
m.handler.RequestAnthropicToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestGeminiCLIToken(c *gin.Context) {
|
||||
m.handler.RequestGeminiCLIToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestCodexToken(c *gin.Context) {
|
||||
m.handler.RequestCodexToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
|
||||
m.handler.RequestAntigravityToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
|
||||
m.handler.RequestQwenToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
|
||||
m.handler.RequestIFlowToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
|
||||
m.handler.RequestIFlowCookieToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) GetAuthStatus(c *gin.Context) {
|
||||
m.handler.GetAuthStatus(c)
|
||||
}
|
||||
@@ -111,6 +111,9 @@ type Manager struct {
|
||||
requestRetry atomic.Int32
|
||||
maxRetryInterval atomic.Int64
|
||||
|
||||
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
|
||||
modelNameMappings atomic.Value
|
||||
|
||||
// Optional HTTP RoundTripper provider injected by host.
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
@@ -203,10 +206,10 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, nil
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
if auth.ID == "" {
|
||||
auth.ID = uuid.NewString()
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
m.mu.Lock()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.mu.Unlock()
|
||||
@@ -221,7 +224,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
return nil, nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == 0 {
|
||||
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == "" {
|
||||
auth.Index = existing.Index
|
||||
auth.indexAssigned = existing.indexAssigned
|
||||
}
|
||||
@@ -263,7 +266,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -302,7 +304,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -341,7 +342,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -413,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -474,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -535,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
@@ -595,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
|
||||
keys := []string{
|
||||
util.ThinkingOriginalModelMetadataKey,
|
||||
util.GeminiOriginalModelMetadataKey,
|
||||
util.ModelMappingOriginalModelMetadataKey,
|
||||
}
|
||||
var out map[string]any
|
||||
for _, key := range keys {
|
||||
@@ -640,13 +644,20 @@ func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// rotateProviders returns a rotated view of the providers list starting from the
|
||||
// current offset for the model, and atomically increments the offset for the next call.
|
||||
// This ensures concurrent requests get different starting providers.
|
||||
func (m *Manager) rotateProviders(model string, providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
}
|
||||
m.mu.RLock()
|
||||
|
||||
// Atomic read-and-increment: get current offset and advance cursor in one lock
|
||||
m.mu.Lock()
|
||||
offset := m.providerOffsets[model]
|
||||
m.mu.RUnlock()
|
||||
m.providerOffsets[model] = (offset + 1) % len(providers)
|
||||
m.mu.Unlock()
|
||||
|
||||
if len(providers) > 0 {
|
||||
offset %= len(providers)
|
||||
}
|
||||
@@ -662,19 +673,6 @@ func (m *Manager) rotateProviders(model string, providers []string) []string {
|
||||
return rotated
|
||||
}
|
||||
|
||||
func (m *Manager) advanceProviderCursor(model string, providers []string) {
|
||||
if len(providers) == 0 {
|
||||
m.mu.Lock()
|
||||
delete(m.providerOffsets, model)
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
current := m.providerOffsets[model]
|
||||
m.providerOffsets[model] = (current + 1) % len(providers)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) retrySettings() (int, time.Duration) {
|
||||
if m == nil {
|
||||
return 0, 0
|
||||
|
||||
171
sdk/cliproxy/auth/model_name_mappings.go
Normal file
171
sdk/cliproxy/auth/model_name_mappings.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
type modelNameMappingTable struct {
|
||||
// reverse maps channel -> alias (lower) -> original upstream model name.
|
||||
reverse map[string]map[string]string
|
||||
}
|
||||
|
||||
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
|
||||
if len(mappings) == 0 {
|
||||
return &modelNameMappingTable{}
|
||||
}
|
||||
out := &modelNameMappingTable{
|
||||
reverse: make(map[string]map[string]string, len(mappings)),
|
||||
}
|
||||
for rawChannel, entries := range mappings {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
rev := make(map[string]string, len(entries))
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, exists := rev[aliasKey]; exists {
|
||||
continue
|
||||
}
|
||||
rev[aliasKey] = name
|
||||
}
|
||||
if len(rev) > 0 {
|
||||
out.reverse[channel] = rev
|
||||
}
|
||||
}
|
||||
if len(out.reverse) == 0 {
|
||||
out.reverse = nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SetOAuthModelMappings updates the OAuth model name mapping table used during execution.
|
||||
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
|
||||
// client-visible model name unchanged for translation/response formatting.
|
||||
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
table := compileModelNameMappingTable(mappings)
|
||||
// atomic.Value requires non-nil store values.
|
||||
if table == nil {
|
||||
table = &modelNameMappingTable{}
|
||||
}
|
||||
m.modelNameMappings.Store(table)
|
||||
}
|
||||
|
||||
// applyOAuthModelMapping resolves the upstream model from OAuth model mappings
|
||||
// and returns the resolved model along with updated metadata. If a mapping exists,
|
||||
// the returned model is the upstream model and metadata contains the original
|
||||
// requested model for response translation.
|
||||
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
|
||||
upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
|
||||
if upstreamModel == "" {
|
||||
return requestedModel, metadata
|
||||
}
|
||||
out := make(map[string]any, 1)
|
||||
if len(metadata) > 0 {
|
||||
out = make(map[string]any, len(metadata)+1)
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
// Store the requested alias (e.g., "gp") so downstream can use it to look up
|
||||
// model metadata from the global registry where it was registered under this alias.
|
||||
out[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
return upstreamModel, out
|
||||
}
|
||||
|
||||
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
|
||||
if m == nil || auth == nil {
|
||||
return ""
|
||||
}
|
||||
channel := modelMappingChannel(auth)
|
||||
if channel == "" {
|
||||
return ""
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
raw := m.modelNameMappings.Load()
|
||||
table, _ := raw.(*modelNameMappingTable)
|
||||
if table == nil || table.reverse == nil {
|
||||
return ""
|
||||
}
|
||||
rev := table.reverse[channel]
|
||||
if rev == nil {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(rev[key])
|
||||
if original == "" || strings.EqualFold(original, requestedModel) {
|
||||
return ""
|
||||
}
|
||||
return original
|
||||
}
|
||||
|
||||
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
|
||||
// It determines the provider and auth kind from the Auth's attributes and delegates
|
||||
// to OAuthModelMappingChannel for the actual channel resolution.
|
||||
func modelMappingChannel(auth *Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
authKind := ""
|
||||
if auth.Attributes != nil {
|
||||
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
|
||||
}
|
||||
if authKind == "" {
|
||||
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
return OAuthModelMappingChannel(provider, authKind)
|
||||
}
|
||||
|
||||
// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model mappings (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
func OAuthModelMappingChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
switch provider {
|
||||
case "gemini":
|
||||
// gemini provider uses gemini-api-key config, not oauth-model-mappings.
|
||||
// OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer.
|
||||
return ""
|
||||
case "vertex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "vertex"
|
||||
case "claude":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "claude"
|
||||
case "codex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
||||
@@ -15,8 +16,8 @@ import (
|
||||
type Auth struct {
|
||||
// ID uniquely identifies the auth record across restarts.
|
||||
ID string `json:"id"`
|
||||
// Index is a monotonically increasing runtime identifier used for diagnostics.
|
||||
Index uint64 `json:"-"`
|
||||
// Index is a stable runtime identifier derived from auth metadata (not persisted).
|
||||
Index string `json:"-"`
|
||||
// Provider is the upstream provider key (e.g. "gemini", "claude").
|
||||
Provider string `json:"provider"`
|
||||
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
|
||||
@@ -94,12 +95,6 @@ type ModelState struct {
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
var authIndexCounter atomic.Uint64
|
||||
|
||||
func nextAuthIndex() uint64 {
|
||||
return authIndexCounter.Add(1) - 1
|
||||
}
|
||||
|
||||
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
|
||||
func (a *Auth) Clone() *Auth {
|
||||
if a == nil {
|
||||
@@ -128,15 +123,41 @@ func (a *Auth) Clone() *Auth {
|
||||
return ©Auth
|
||||
}
|
||||
|
||||
// EnsureIndex returns the global index, assigning one if it was not set yet.
|
||||
func (a *Auth) EnsureIndex() uint64 {
|
||||
if a == nil {
|
||||
return 0
|
||||
func stableAuthIndex(seed string) string {
|
||||
seed = strings.TrimSpace(seed)
|
||||
if seed == "" {
|
||||
return ""
|
||||
}
|
||||
if a.indexAssigned {
|
||||
sum := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
// EnsureIndex returns a stable index derived from the auth file name or API key.
|
||||
func (a *Auth) EnsureIndex() string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
if a.indexAssigned && a.Index != "" {
|
||||
return a.Index
|
||||
}
|
||||
idx := nextAuthIndex()
|
||||
|
||||
seed := strings.TrimSpace(a.FileName)
|
||||
if seed != "" {
|
||||
seed = "file:" + seed
|
||||
} else if a.Attributes != nil {
|
||||
if apiKey := strings.TrimSpace(a.Attributes["api_key"]); apiKey != "" {
|
||||
seed = "api_key:" + apiKey
|
||||
}
|
||||
}
|
||||
if seed == "" {
|
||||
if id := strings.TrimSpace(a.ID); id != "" {
|
||||
seed = "id:" + id
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
idx := stableAuthIndex(seed)
|
||||
a.Index = idx
|
||||
a.indexAssigned = true
|
||||
return idx
|
||||
|
||||
@@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) {
|
||||
}
|
||||
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
||||
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
|
||||
coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings)
|
||||
|
||||
service := &Service{
|
||||
cfg: b.cfg,
|
||||
|
||||
@@ -13,6 +13,7 @@ type ModelRegistry interface {
|
||||
ClearModelQuotaExceeded(clientID, modelID string)
|
||||
ClientSupportsModel(clientID, modelID string) bool
|
||||
GetAvailableModels(handlerType string) []map[string]any
|
||||
GetAvailableModelsByProvider(provider string) []*ModelInfo
|
||||
}
|
||||
|
||||
// GlobalModelRegistry returns the shared registry instance.
|
||||
|
||||
@@ -552,6 +552,9 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.cfgMu.Lock()
|
||||
s.cfg = newCfg
|
||||
s.cfgMu.Unlock()
|
||||
if s.coreManager != nil {
|
||||
s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings)
|
||||
}
|
||||
s.rebindExecutors()
|
||||
}
|
||||
|
||||
@@ -677,6 +680,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
return
|
||||
}
|
||||
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
|
||||
if authKind == "" {
|
||||
if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -702,6 +710,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "gemini":
|
||||
models = registry.GetGeminiModels()
|
||||
if entry := s.resolveConfigGeminiKey(a); entry != nil {
|
||||
if len(entry.Models) > 0 {
|
||||
models = buildGeminiConfigModels(entry)
|
||||
}
|
||||
if authKind == "apikey" {
|
||||
excluded = entry.ExcludedModels
|
||||
}
|
||||
@@ -741,6 +752,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "codex":
|
||||
models = registry.GetOpenAIModels()
|
||||
if entry := s.resolveConfigCodexKey(a); entry != nil {
|
||||
if len(entry.Models) > 0 {
|
||||
models = buildCodexConfigModels(entry)
|
||||
}
|
||||
if authKind == "apikey" {
|
||||
excluded = entry.ExcludedModels
|
||||
}
|
||||
@@ -833,6 +847,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
}
|
||||
}
|
||||
}
|
||||
models = applyOAuthModelMappings(s.cfg, provider, authKind, models)
|
||||
if len(models) > 0 {
|
||||
key := provider
|
||||
if key == "" {
|
||||
@@ -1104,17 +1119,22 @@ func matchWildcard(pattern, value string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
type modelEntry interface {
|
||||
GetName() string
|
||||
GetAlias() string
|
||||
}
|
||||
|
||||
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
out := make([]*ModelInfo, 0, len(entry.Models))
|
||||
seen := make(map[string]struct{}, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
out := make([]*ModelInfo, 0, len(models))
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for i := range models {
|
||||
model := models[i]
|
||||
name := strings.TrimSpace(model.GetName())
|
||||
alias := strings.TrimSpace(model.GetAlias())
|
||||
if alias == "" {
|
||||
alias = name
|
||||
}
|
||||
@@ -1130,52 +1150,135 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
||||
if display == "" {
|
||||
display = alias
|
||||
}
|
||||
out = append(out, &ModelInfo{
|
||||
info := &ModelInfo{
|
||||
ID: alias,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "vertex",
|
||||
Type: "vertex",
|
||||
OwnedBy: ownedBy,
|
||||
Type: modelType,
|
||||
DisplayName: display,
|
||||
})
|
||||
}
|
||||
if name != "" {
|
||||
if upstream := registry.LookupStaticModelInfo(name); upstream != nil && upstream.Thinking != nil {
|
||||
info.Thinking = upstream.Thinking
|
||||
}
|
||||
}
|
||||
out = append(out, info)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
out := make([]*ModelInfo, 0, len(entry.Models))
|
||||
seen := make(map[string]struct{}, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = name
|
||||
}
|
||||
if alias == "" {
|
||||
return buildConfigModels(entry.Models, "google", "vertex")
|
||||
}
|
||||
|
||||
func buildGeminiConfigModels(entry *config.GeminiKey) []*ModelInfo {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return buildConfigModels(entry.Models, "google", "gemini")
|
||||
}
|
||||
|
||||
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return buildConfigModels(entry.Models, "anthropic", "claude")
|
||||
}
|
||||
|
||||
func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return buildConfigModels(entry.Models, "openai", "openai")
|
||||
}
|
||||
|
||||
func rewriteModelInfoName(name, oldID, newID string) string {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return name
|
||||
}
|
||||
oldID = strings.TrimSpace(oldID)
|
||||
newID = strings.TrimSpace(newID)
|
||||
if oldID == "" || newID == "" {
|
||||
return name
|
||||
}
|
||||
if strings.EqualFold(oldID, newID) {
|
||||
return name
|
||||
}
|
||||
if strings.HasSuffix(trimmed, "/"+oldID) {
|
||||
prefix := strings.TrimSuffix(trimmed, oldID)
|
||||
return prefix + newID
|
||||
}
|
||||
if trimmed == "models/"+oldID {
|
||||
return "models/" + newID
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
|
||||
if cfg == nil || len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
channel := coreauth.OAuthModelMappingChannel(provider, authKind)
|
||||
if channel == "" || len(cfg.OAuthModelMappings) == 0 {
|
||||
return models
|
||||
}
|
||||
mappings := cfg.OAuthModelMappings[channel]
|
||||
if len(mappings) == 0 {
|
||||
return models
|
||||
}
|
||||
forward := make(map[string]string, len(mappings))
|
||||
for i := range mappings {
|
||||
name := strings.TrimSpace(mappings[i].Name)
|
||||
alias := strings.TrimSpace(mappings[i].Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(alias)
|
||||
if _, exists := seen[key]; exists {
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
display := name
|
||||
if display == "" {
|
||||
display = alias
|
||||
key := strings.ToLower(name)
|
||||
if _, exists := forward[key]; exists {
|
||||
continue
|
||||
}
|
||||
out = append(out, &ModelInfo{
|
||||
ID: alias,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "claude",
|
||||
Type: "claude",
|
||||
DisplayName: display,
|
||||
})
|
||||
forward[key] = alias
|
||||
}
|
||||
if len(forward) == 0 {
|
||||
return models
|
||||
}
|
||||
out := make([]*ModelInfo, 0, len(models))
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
id := strings.TrimSpace(model.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
mappedID := id
|
||||
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
|
||||
mappedID = strings.TrimSpace(to)
|
||||
}
|
||||
uniqueKey := strings.ToLower(mappedID)
|
||||
if _, exists := seen[uniqueKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[uniqueKey] = struct{}{}
|
||||
if mappedID == id {
|
||||
out = append(out, model)
|
||||
continue
|
||||
}
|
||||
clone := *model
|
||||
clone.ID = mappedID
|
||||
if clone.Name != "" {
|
||||
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
|
||||
}
|
||||
out = append(out, &clone)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ type Record struct {
|
||||
Model string
|
||||
APIKey string
|
||||
AuthID string
|
||||
AuthIndex uint64
|
||||
AuthIndex string
|
||||
Source string
|
||||
RequestedAt time.Time
|
||||
Failed bool
|
||||
|
||||
@@ -16,6 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig
|
||||
type TLSConfig = internalconfig.TLSConfig
|
||||
type RemoteManagement = internalconfig.RemoteManagement
|
||||
type AmpCode = internalconfig.AmpCode
|
||||
type ModelNameMapping = internalconfig.ModelNameMapping
|
||||
type PayloadConfig = internalconfig.PayloadConfig
|
||||
type PayloadRule = internalconfig.PayloadRule
|
||||
type PayloadModelRule = internalconfig.PayloadModelRule
|
||||
|
||||
@@ -56,6 +56,10 @@ func setupAmpRouter(h *management.Handler) *gin.Engine {
|
||||
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
|
||||
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
|
||||
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
|
||||
mgmt.GET("/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys)
|
||||
mgmt.PUT("/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys)
|
||||
mgmt.PATCH("/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys)
|
||||
mgmt.DELETE("/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys)
|
||||
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
|
||||
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
|
||||
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
|
||||
@@ -188,6 +192,90 @@ func TestPutAmpUpstreamAPIKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) {
|
||||
h, configPath := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
body := `{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify it was persisted to disk
|
||||
loaded, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config from disk: %v", err)
|
||||
}
|
||||
if len(loaded.AmpCode.UpstreamAPIKeys) != 1 {
|
||||
t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(loaded.AmpCode.UpstreamAPIKeys))
|
||||
}
|
||||
entry := loaded.AmpCode.UpstreamAPIKeys[0]
|
||||
if entry.UpstreamAPIKey != "u1" {
|
||||
t.Fatalf("expected upstream-api-key u1, got %q", entry.UpstreamAPIKey)
|
||||
}
|
||||
if len(entry.APIKeys) != 2 || entry.APIKeys[0] != "k1" || entry.APIKeys[1] != "k2" {
|
||||
t.Fatalf("expected api-keys [k1 k2], got %#v", entry.APIKeys)
|
||||
}
|
||||
|
||||
// Verify it is returned by GET /ampcode
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
var resp map[string]config.AmpCode
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if got := resp["ampcode"].UpstreamAPIKeys; len(got) != 1 || got[0].UpstreamAPIKey != "u1" {
|
||||
t.Fatalf("expected upstream-api-keys to be present after update, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
r := setupAmpRouter(h)
|
||||
|
||||
// Seed with one entry
|
||||
putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
deleteBody := `{"value":[]}`
|
||||
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil)
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
var resp map[string][]config.AmpUpstreamAPIKeyEntry
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if resp["upstream-api-keys"] != nil && len(resp["upstream-api-keys"]) != 0 {
|
||||
t.Fatalf("expected cleared list, got %#v", resp["upstream-api-keys"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
|
||||
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
|
||||
h, _ := newAmpTestHandler(t)
|
||||
|
||||
211
test/model_alias_thinking_suffix_test.go
Normal file
211
test/model_alias_thinking_suffix_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TestModelAliasThinkingSuffix tests the 32 test cases defined in docs/thinking_suffix_test_cases.md
|
||||
// These tests verify the thinking suffix parsing and application logic across different providers.
|
||||
func TestModelAliasThinkingSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
id int
|
||||
name string
|
||||
provider string
|
||||
requestModel string
|
||||
suffixType string
|
||||
expectedField string // "thinkingBudget", "thinkingLevel", "budget_tokens", "reasoning_effort", "enable_thinking"
|
||||
expectedValue any
|
||||
upstreamModel string // The upstream model after alias resolution
|
||||
isAlias bool
|
||||
}{
|
||||
// === 1. Antigravity Provider ===
|
||||
// 1.1 Budget-only models (Gemini 2.5)
|
||||
{1, "antigravity_original_numeric", "antigravity", "gemini-2.5-computer-use-preview-10-2025(1000)", "numeric", "thinkingBudget", 1000, "gemini-2.5-computer-use-preview-10-2025", false},
|
||||
{2, "antigravity_alias_numeric", "antigravity", "gp(1000)", "numeric", "thinkingBudget", 1000, "gemini-2.5-computer-use-preview-10-2025", true},
|
||||
// 1.2 Budget+Levels models (Gemini 3)
|
||||
{3, "antigravity_original_numeric_to_level", "antigravity", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{4, "antigravity_original_level", "antigravity", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{5, "antigravity_alias_numeric_to_level", "antigravity", "gf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
{6, "antigravity_alias_level", "antigravity", "gf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
|
||||
// === 2. Gemini CLI Provider ===
|
||||
// 2.1 Budget-only models
|
||||
{7, "gemini_cli_original_numeric", "gemini-cli", "gemini-2.5-pro(8192)", "numeric", "thinkingBudget", 8192, "gemini-2.5-pro", false},
|
||||
{8, "gemini_cli_alias_numeric", "gemini-cli", "g25p(8192)", "numeric", "thinkingBudget", 8192, "gemini-2.5-pro", true},
|
||||
// 2.2 Budget+Levels models
|
||||
{9, "gemini_cli_original_numeric_to_level", "gemini-cli", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{10, "gemini_cli_original_level", "gemini-cli", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{11, "gemini_cli_alias_numeric_to_level", "gemini-cli", "gf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
{12, "gemini_cli_alias_level", "gemini-cli", "gf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
|
||||
// === 3. Vertex Provider ===
|
||||
// 3.1 Budget-only models
|
||||
{13, "vertex_original_numeric", "vertex", "gemini-2.5-pro(16384)", "numeric", "thinkingBudget", 16384, "gemini-2.5-pro", false},
|
||||
{14, "vertex_alias_numeric", "vertex", "vg25p(16384)", "numeric", "thinkingBudget", 16384, "gemini-2.5-pro", true},
|
||||
// 3.2 Budget+Levels models
|
||||
{15, "vertex_original_numeric_to_level", "vertex", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{16, "vertex_original_level", "vertex", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{17, "vertex_alias_numeric_to_level", "vertex", "vgf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
{18, "vertex_alias_level", "vertex", "vgf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
|
||||
// === 4. AI Studio Provider ===
|
||||
// 4.1 Budget-only models
|
||||
{19, "aistudio_original_numeric", "aistudio", "gemini-2.5-pro(12000)", "numeric", "thinkingBudget", 12000, "gemini-2.5-pro", false},
|
||||
{20, "aistudio_alias_numeric", "aistudio", "ag25p(12000)", "numeric", "thinkingBudget", 12000, "gemini-2.5-pro", true},
|
||||
// 4.2 Budget+Levels models
|
||||
{21, "aistudio_original_numeric_to_level", "aistudio", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{22, "aistudio_original_level", "aistudio", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{23, "aistudio_alias_numeric_to_level", "aistudio", "agf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
{24, "aistudio_alias_level", "aistudio", "agf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
|
||||
// === 5. Claude Provider ===
|
||||
{25, "claude_original_numeric", "claude", "claude-sonnet-4-5-20250929(16384)", "numeric", "budget_tokens", 16384, "claude-sonnet-4-5-20250929", false},
|
||||
{26, "claude_alias_numeric", "claude", "cs45(16384)", "numeric", "budget_tokens", 16384, "claude-sonnet-4-5-20250929", true},
|
||||
|
||||
// === 6. Codex Provider ===
|
||||
{27, "codex_original_level", "codex", "gpt-5(high)", "level", "reasoning_effort", "high", "gpt-5", false},
|
||||
{28, "codex_alias_level", "codex", "g5(high)", "level", "reasoning_effort", "high", "gpt-5", true},
|
||||
|
||||
// === 7. Qwen Provider ===
|
||||
{29, "qwen_original_level", "qwen", "qwen3-coder-plus(high)", "level", "enable_thinking", true, "qwen3-coder-plus", false},
|
||||
{30, "qwen_alias_level", "qwen", "qcp(high)", "level", "enable_thinking", true, "qwen3-coder-plus", true},
|
||||
|
||||
// === 8. iFlow Provider ===
|
||||
{31, "iflow_original_level", "iflow", "glm-4.7(high)", "level", "reasoning_effort", "high", "glm-4.7", false},
|
||||
{32, "iflow_alias_level", "iflow", "glm(high)", "level", "reasoning_effort", "high", "glm-4.7", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Step 1: Parse model suffix (simulates SDK layer normalization)
|
||||
// For "gp(1000)" -> requestedModel="gp", metadata={thinking_budget: 1000}
|
||||
requestedModel, metadata := util.NormalizeThinkingModel(tt.requestModel)
|
||||
|
||||
// Verify suffix was parsed
|
||||
if metadata == nil && (tt.suffixType == "numeric" || tt.suffixType == "level") {
|
||||
t.Errorf("Case #%d: NormalizeThinkingModel(%q) metadata is nil", tt.id, tt.requestModel)
|
||||
return
|
||||
}
|
||||
|
||||
// Step 2: Simulate OAuth model mapping
|
||||
// Real flow: applyOAuthModelMapping stores requestedModel (the alias) in metadata
|
||||
if tt.isAlias {
|
||||
if metadata == nil {
|
||||
metadata = make(map[string]any)
|
||||
}
|
||||
metadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
}
|
||||
|
||||
// Step 3: Verify metadata extraction
|
||||
switch tt.suffixType {
|
||||
case "numeric":
|
||||
budget, _, _, matched := util.ThinkingFromMetadata(metadata)
|
||||
if !matched {
|
||||
t.Errorf("Case #%d: ThinkingFromMetadata did not match", tt.id)
|
||||
return
|
||||
}
|
||||
if budget == nil {
|
||||
t.Errorf("Case #%d: expected budget in metadata", tt.id)
|
||||
return
|
||||
}
|
||||
// For thinkingBudget/budget_tokens, verify the parsed budget value
|
||||
if tt.expectedField == "thinkingBudget" || tt.expectedField == "budget_tokens" {
|
||||
expectedBudget := tt.expectedValue.(int)
|
||||
if *budget != expectedBudget {
|
||||
t.Errorf("Case #%d: budget = %d, want %d", tt.id, *budget, expectedBudget)
|
||||
}
|
||||
}
|
||||
// For thinkingLevel (Gemini 3), verify conversion from budget to level
|
||||
if tt.expectedField == "thinkingLevel" {
|
||||
level, ok := util.ThinkingBudgetToGemini3Level(tt.upstreamModel, *budget)
|
||||
if !ok {
|
||||
t.Errorf("Case #%d: ThinkingBudgetToGemini3Level failed", tt.id)
|
||||
return
|
||||
}
|
||||
expectedLevel := tt.expectedValue.(string)
|
||||
if level != expectedLevel {
|
||||
t.Errorf("Case #%d: converted level = %q, want %q", tt.id, level, expectedLevel)
|
||||
}
|
||||
}
|
||||
|
||||
case "level":
|
||||
_, _, effort, matched := util.ThinkingFromMetadata(metadata)
|
||||
if !matched {
|
||||
t.Errorf("Case #%d: ThinkingFromMetadata did not match", tt.id)
|
||||
return
|
||||
}
|
||||
if effort == nil {
|
||||
t.Errorf("Case #%d: expected effort in metadata", tt.id)
|
||||
return
|
||||
}
|
||||
if tt.expectedField == "thinkingLevel" || tt.expectedField == "reasoning_effort" {
|
||||
expectedEffort := tt.expectedValue.(string)
|
||||
if *effort != expectedEffort {
|
||||
t.Errorf("Case #%d: effort = %q, want %q", tt.id, *effort, expectedEffort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Test Gemini-specific thinkingLevel conversion for Gemini 3 models
|
||||
if tt.expectedField == "thinkingLevel" && util.IsGemini3Model(tt.upstreamModel) {
|
||||
body := []byte(`{"request":{"contents":[]}}`)
|
||||
|
||||
// Build metadata simulating real OAuth flow:
|
||||
// - requestedModel (alias like "gf") is stored in model_mapping_original_model
|
||||
// - upstreamModel is passed as the model parameter
|
||||
testMetadata := make(map[string]any)
|
||||
if tt.isAlias {
|
||||
// Real flow: applyOAuthModelMapping stores requestedModel (the alias)
|
||||
testMetadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
}
|
||||
// Copy parsed metadata (thinking_budget, reasoning_effort, etc.)
|
||||
for k, v := range metadata {
|
||||
testMetadata[k] = v
|
||||
}
|
||||
|
||||
result := util.ApplyGemini3ThinkingLevelFromMetadataCLI(tt.upstreamModel, testMetadata, body)
|
||||
levelVal := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
|
||||
expectedLevel := tt.expectedValue.(string)
|
||||
if !levelVal.Exists() {
|
||||
t.Errorf("Case #%d: expected thinkingLevel in result", tt.id)
|
||||
} else if levelVal.String() != expectedLevel {
|
||||
t.Errorf("Case #%d: thinkingLevel = %q, want %q", tt.id, levelVal.String(), expectedLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Test Gemini 2.5 thinkingBudget application using real ApplyThinkingMetadataCLI flow
|
||||
if tt.expectedField == "thinkingBudget" && util.IsGemini25Model(tt.upstreamModel) {
|
||||
body := []byte(`{"request":{"contents":[]}}`)
|
||||
|
||||
// Build metadata simulating real OAuth flow:
|
||||
// - requestedModel (alias like "gp") is stored in model_mapping_original_model
|
||||
// - upstreamModel is passed as the model parameter
|
||||
testMetadata := make(map[string]any)
|
||||
if tt.isAlias {
|
||||
// Real flow: applyOAuthModelMapping stores requestedModel (the alias)
|
||||
testMetadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
}
|
||||
// Copy parsed metadata (thinking_budget, reasoning_effort, etc.)
|
||||
for k, v := range metadata {
|
||||
testMetadata[k] = v
|
||||
}
|
||||
|
||||
// Use the exported ApplyThinkingMetadataCLI which includes the fallback logic
|
||||
result := executor.ApplyThinkingMetadataCLI(body, testMetadata, tt.upstreamModel)
|
||||
budgetVal := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
|
||||
expectedBudget := tt.expectedValue.(int)
|
||||
if !budgetVal.Exists() {
|
||||
t.Errorf("Case #%d: expected thinkingBudget in result", tt.id)
|
||||
} else if int(budgetVal.Int()) != expectedBudget {
|
||||
t.Errorf("Case #%d: thinkingBudget = %d, want %d", tt.id, int(budgetVal.Int()), expectedBudget)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user