mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 12:30:50 +08:00
Compare commits
112 Commits
v6.6.72
...
v6.6.109-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7248f65c36 | ||
|
|
086eb3df7a | ||
|
|
5a7e5bd870 | ||
|
|
6f8a8f8136 | ||
|
|
b163f8ed9e | ||
|
|
a1da6ff5ac | ||
|
|
43652d044c | ||
|
|
b1b379ea18 | ||
|
|
21ac161b21 | ||
|
|
94e979865e | ||
|
|
6c324f2c8b | ||
|
|
543dfd67e0 | ||
|
|
28bd1323a2 | ||
|
|
220ca45f74 | ||
|
|
70a82d80ac | ||
|
|
ac626111ac | ||
|
|
8cfe26f10c | ||
|
|
80db2dc254 | ||
|
|
e8e3bc8616 | ||
|
|
bc3195c8d8 | ||
|
|
6494330c6b | ||
|
|
4d7f389b69 | ||
|
|
95f87d5669 | ||
|
|
c83365a349 | ||
|
|
6b3604cf2b | ||
|
|
af6bdca14f | ||
|
|
1c773c428f | ||
|
|
e785bfcd12 | ||
|
|
47dacce6ea | ||
|
|
dcac3407ab | ||
|
|
7004295e1d | ||
|
|
ee62ef4745 | ||
|
|
ef6bafbf7e | ||
|
|
ed28b71e87 | ||
|
|
d47b7dc79a | ||
|
|
49b9709ce5 | ||
|
|
a2eba2cdf5 | ||
|
|
3d01b3cfe8 | ||
|
|
af2efa6f7e | ||
|
|
d73b61d367 | ||
|
|
59a448b645 | ||
|
|
4adb9eed77 | ||
|
|
b6a0f7a07f | ||
|
|
1b2f907671 | ||
|
|
bda04eed8a | ||
|
|
67985d8226 | ||
|
|
cbcb061812 | ||
|
|
9fc2e1b3c8 | ||
|
|
3b484aea9e | ||
|
|
963a0950fa | ||
|
|
f4ba1ab910 | ||
|
|
2662f91082 | ||
|
|
c1db2c7d7c | ||
|
|
5e5d8142f9 | ||
|
|
b01619b441 | ||
|
|
f861bd6a94 | ||
|
|
6dbfdd140d | ||
|
|
386ccffed4 | ||
|
|
ffddd1c90a | ||
|
|
8f8dfd081b | ||
|
|
9f1b445c7c | ||
|
|
ae933dfe14 | ||
|
|
e124db723b | ||
|
|
05444cf32d | ||
|
|
8edbda57cf | ||
|
|
821249a5ed | ||
|
|
ee33863b47 | ||
|
|
cd22c849e2 | ||
|
|
f0e73efda2 | ||
|
|
3156109c71 | ||
|
|
6762e081f3 | ||
|
|
7815ee338d | ||
|
|
44b6c872e2 | ||
|
|
7a77b23f2d | ||
|
|
672e8549c0 | ||
|
|
66f5269a23 | ||
|
|
ebec293497 | ||
|
|
e02ceecd35 | ||
|
|
c8b33a8cc3 | ||
|
|
dca8d5ded8 | ||
|
|
2a7fd1e897 | ||
|
|
b9d1e70ac2 | ||
|
|
fdf5720217 | ||
|
|
f40bd0cd51 | ||
|
|
e33676bb87 | ||
|
|
2a663d5cba | ||
|
|
750b930679 | ||
|
|
3902fd7501 | ||
|
|
4fc3d5e935 | ||
|
|
2d2f4572a7 | ||
|
|
8f4c46f38d | ||
|
|
b6ba51bc2a | ||
|
|
6a66d32d37 | ||
|
|
8d15723195 | ||
|
|
736e0aae86 | ||
|
|
8bf3305b2b | ||
|
|
d00e3ea973 | ||
|
|
89db4e9481 | ||
|
|
e332419081 | ||
|
|
e998b1229a | ||
|
|
47b9503112 | ||
|
|
3b9253c2be | ||
|
|
d241359153 | ||
|
|
f4d4249ba5 | ||
|
|
414db44c00 | ||
|
|
cb3bdffb43 | ||
|
|
48f19aab51 | ||
|
|
48f6d7abdf | ||
|
|
79fbcb3ec4 | ||
|
|
0e4148b229 | ||
|
|
31bd90c748 | ||
|
|
0b834fcb54 |
23
README.md
23
README.md
@@ -118,9 +118,32 @@ Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings,
|
||||
|
||||
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
|
||||
|
||||
### [CodMate](https://github.com/loocor/CodMate)
|
||||
|
||||
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
|
||||
|
||||
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
||||
|
||||
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
|
||||
|
||||
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
||||
|
||||
VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
## More choices
|
||||
|
||||
Those projects are ports of CLIProxyAPI or inspired by it:
|
||||
|
||||
### [9Router](https://github.com/decolua/9router)
|
||||
|
||||
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
|
||||
|
||||
> [!NOTE]
|
||||
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
23
README_CN.md
23
README_CN.md
@@ -117,9 +117,32 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
|
||||
|
||||
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
|
||||
|
||||
### [CodMate](https://github.com/loocor/CodMate)
|
||||
|
||||
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话(Claude Code、Codex、Gemini CLI),提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
|
||||
|
||||
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
||||
|
||||
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
|
||||
|
||||
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
||||
|
||||
一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
## 更多选择
|
||||
|
||||
以下项目是 CLIProxyAPI 的移植版或受其启发:
|
||||
|
||||
### [9Router](https://github.com/decolua/9router)
|
||||
|
||||
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
|
||||
|
||||
## 许可证
|
||||
|
||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||
|
||||
@@ -61,6 +61,7 @@ func main() {
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
@@ -75,6 +76,7 @@ func main() {
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
@@ -425,7 +427,8 @@ func main() {
|
||||
|
||||
// Create login options to be used in authentication flows.
|
||||
options := &cmd.LoginOptions{
|
||||
NoBrowser: noBrowser,
|
||||
NoBrowser: noBrowser,
|
||||
CallbackPort: oauthCallbackPort,
|
||||
}
|
||||
|
||||
// Register the shared token store once so all components use the same persistence backend.
|
||||
|
||||
@@ -77,6 +77,9 @@ routing:
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
|
||||
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
||||
nonstream-keepalive-interval: 0
|
||||
|
||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||
# streaming:
|
||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||
@@ -202,10 +205,12 @@ ws-auth: false
|
||||
# These mappings rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
# oauth-model-mappings:
|
||||
# gemini-cli:
|
||||
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||
# alias: "g2.5p" # client-visible alias
|
||||
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
|
||||
# vertex:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
|
||||
128
docker-build.sh
128
docker-build.sh
@@ -5,9 +5,115 @@
|
||||
# This script automates the process of building and running the Docker container
|
||||
# with version information dynamically injected at build time.
|
||||
|
||||
# Exit immediately if a command exits with a non-zero status.
|
||||
# Hidden feature: Preserve usage statistics across rebuilds
|
||||
# Usage: ./docker-build.sh --with-usage
|
||||
# First run prompts for management API key, saved to temp/stats/.api_secret
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATS_DIR="temp/stats"
|
||||
STATS_FILE="${STATS_DIR}/.usage_backup.json"
|
||||
SECRET_FILE="${STATS_DIR}/.api_secret"
|
||||
WITH_USAGE=false
|
||||
|
||||
get_port() {
|
||||
if [[ -f "config.yaml" ]]; then
|
||||
grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
|
||||
else
|
||||
echo "8317"
|
||||
fi
|
||||
}
|
||||
|
||||
export_stats_api_secret() {
|
||||
if [[ -f "${SECRET_FILE}" ]]; then
|
||||
API_SECRET=$(cat "${SECRET_FILE}")
|
||||
else
|
||||
if [[ ! -d "${STATS_DIR}" ]]; then
|
||||
mkdir -p "${STATS_DIR}"
|
||||
fi
|
||||
echo "First time using --with-usage. Management API key required."
|
||||
read -r -p "Enter management key: " -s API_SECRET
|
||||
echo
|
||||
echo "${API_SECRET}" > "${SECRET_FILE}"
|
||||
chmod 600 "${SECRET_FILE}"
|
||||
fi
|
||||
}
|
||||
|
||||
check_container_running() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
|
||||
echo "Error: cli-proxy-api service is not responding at localhost:${port}"
|
||||
echo "Please start the container first or use without --with-usage flag."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
export_stats() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
if [[ ! -d "${STATS_DIR}" ]]; then
|
||||
mkdir -p "${STATS_DIR}"
|
||||
fi
|
||||
check_container_running
|
||||
echo "Exporting usage statistics..."
|
||||
EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
|
||||
"http://localhost:${port}/v0/management/usage/export")
|
||||
HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
|
||||
RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
|
||||
|
||||
if [[ "${HTTP_CODE}" != "200" ]]; then
|
||||
echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "${RESPONSE_BODY}" > "${STATS_FILE}"
|
||||
echo "Statistics exported to ${STATS_FILE}"
|
||||
}
|
||||
|
||||
import_stats() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
echo "Importing usage statistics..."
|
||||
IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
|
||||
-H "X-Management-Key: ${API_SECRET}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d @"${STATS_FILE}" \
|
||||
"http://localhost:${port}/v0/management/usage/import")
|
||||
IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
|
||||
IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
|
||||
|
||||
if [[ "${IMPORT_CODE}" == "200" ]]; then
|
||||
echo "Statistics imported successfully"
|
||||
else
|
||||
echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
|
||||
fi
|
||||
|
||||
rm -f "${STATS_FILE}"
|
||||
}
|
||||
|
||||
wait_for_service() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
echo "Waiting for service to be ready..."
|
||||
for i in {1..30}; do
|
||||
if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
sleep 2
|
||||
}
|
||||
|
||||
if [[ "${1:-}" == "--with-usage" ]]; then
|
||||
WITH_USAGE=true
|
||||
export_stats_api_secret
|
||||
fi
|
||||
|
||||
# --- Step 1: Choose Environment ---
|
||||
echo "Please select an option:"
|
||||
echo "1) Run using Pre-built Image (Recommended)"
|
||||
@@ -18,7 +124,14 @@ read -r -p "Enter choice [1-2]: " choice
|
||||
case "$choice" in
|
||||
1)
|
||||
echo "--- Running with Pre-built Image ---"
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
export_stats
|
||||
fi
|
||||
docker compose up -d --remove-orphans --no-build
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
wait_for_service
|
||||
import_stats
|
||||
fi
|
||||
echo "Services are starting from remote image."
|
||||
echo "Run 'docker compose logs -f' to see the logs."
|
||||
;;
|
||||
@@ -38,16 +151,25 @@ case "$choice" in
|
||||
|
||||
# Build and start the services with a local-only image tag
|
||||
export CLI_PROXY_IMAGE="cli-proxy-api:local"
|
||||
|
||||
|
||||
echo "Building the Docker image..."
|
||||
docker compose build \
|
||||
--build-arg VERSION="${VERSION}" \
|
||||
--build-arg COMMIT="${COMMIT}" \
|
||||
--build-arg BUILD_DATE="${BUILD_DATE}"
|
||||
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
export_stats
|
||||
fi
|
||||
|
||||
echo "Starting the services..."
|
||||
docker compose up -d --remove-orphans --pull never
|
||||
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
wait_for_service
|
||||
import_stats
|
||||
fi
|
||||
|
||||
echo "Build complete. Services are starting."
|
||||
echo "Run 'docker compose logs -f' to see the logs."
|
||||
;;
|
||||
@@ -55,4 +177,4 @@ case "$choice" in
|
||||
echo "Invalid choice. Please enter 1 or 2."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
esac
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -122,7 +123,9 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Inject credentials via PrepareRequest hook.
|
||||
_ = (MyExecutor{}).PrepareRequest(httpReq, a)
|
||||
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
|
||||
return clipexec.Response{}, errPrep
|
||||
}
|
||||
|
||||
resp, errDo := client.Do(httpReq)
|
||||
if errDo != nil {
|
||||
@@ -130,13 +133,28 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
// Best-effort close; log if needed in real projects.
|
||||
fmt.Fprintf(os.Stderr, "close response body error: %v\n", errClose)
|
||||
}
|
||||
}()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return clipexec.Response{Payload: body}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) HttpRequest(ctx context.Context, a *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("myprov executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
|
||||
return nil, errPrep
|
||||
}
|
||||
client := buildHTTPClient(a)
|
||||
return client.Do(httpReq)
|
||||
}
|
||||
|
||||
func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
@@ -199,8 +217,8 @@ func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
panic(err)
|
||||
if errRun := svc.Run(ctx); errRun != nil && !errors.Is(errRun, context.Canceled) {
|
||||
panic(errRun)
|
||||
}
|
||||
_ = os.Stderr // keep os import used (demo only)
|
||||
_ = time.Second
|
||||
|
||||
140
examples/http-request/main.go
Normal file
140
examples/http-request/main.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Package main demonstrates how to use coreauth.Manager.HttpRequest/NewHttpRequest
|
||||
// to execute arbitrary HTTP requests with provider credentials injected.
|
||||
//
|
||||
// This example registers a minimal custom executor that injects an Authorization
|
||||
// header from auth.Attributes["api_key"], then performs two requests against
|
||||
// httpbin.org to show the injected headers.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const providerKey = "echo"
|
||||
|
||||
// EchoExecutor is a minimal provider implementation for demonstration purposes.
|
||||
type EchoExecutor struct{}
|
||||
|
||||
func (EchoExecutor) Identifier() string { return providerKey }
|
||||
|
||||
func (EchoExecutor) PrepareRequest(req *http.Request, auth *coreauth.Auth) error {
|
||||
if req == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if apiKey := strings.TrimSpace(auth.Attributes["api_key"]); apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (EchoExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("echo executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrep := (EchoExecutor{}).PrepareRequest(httpReq, auth); errPrep != nil {
|
||||
return nil, errPrep
|
||||
}
|
||||
return http.DefaultClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) Refresh(context.Context, *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return nil, errors.New("echo executor: Refresh not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("echo executor: CountTokens not implemented")
|
||||
}
|
||||
|
||||
func main() {
|
||||
log.SetLevel(log.InfoLevel)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
core := coreauth.NewManager(nil, nil, nil)
|
||||
core.RegisterExecutor(EchoExecutor{})
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "demo-echo",
|
||||
Provider: providerKey,
|
||||
Attributes: map[string]string{
|
||||
"api_key": "demo-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
// Example 1: Build a prepared request and execute it using your own http.Client.
|
||||
reqPrepared, errReqPrepared := core.NewHttpRequest(
|
||||
ctx,
|
||||
auth,
|
||||
http.MethodGet,
|
||||
"https://httpbin.org/anything",
|
||||
nil,
|
||||
http.Header{"X-Example": []string{"prepared"}},
|
||||
)
|
||||
if errReqPrepared != nil {
|
||||
panic(errReqPrepared)
|
||||
}
|
||||
respPrepared, errDoPrepared := http.DefaultClient.Do(reqPrepared)
|
||||
if errDoPrepared != nil {
|
||||
panic(errDoPrepared)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := respPrepared.Body.Close(); errClose != nil {
|
||||
log.Errorf("close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
bodyPrepared, errReadPrepared := io.ReadAll(respPrepared.Body)
|
||||
if errReadPrepared != nil {
|
||||
panic(errReadPrepared)
|
||||
}
|
||||
fmt.Printf("Prepared request status: %d\n%s\n\n", respPrepared.StatusCode, bodyPrepared)
|
||||
|
||||
// Example 2: Execute a raw request via core.HttpRequest (auto inject + do).
|
||||
rawBody := []byte(`{"hello":"world"}`)
|
||||
rawReq, errRawReq := http.NewRequestWithContext(ctx, http.MethodPost, "https://httpbin.org/anything", bytes.NewReader(rawBody))
|
||||
if errRawReq != nil {
|
||||
panic(errRawReq)
|
||||
}
|
||||
rawReq.Header.Set("Content-Type", "application/json")
|
||||
rawReq.Header.Set("X-Example", "executed")
|
||||
|
||||
respExec, errDoExec := core.HttpRequest(ctx, auth, rawReq)
|
||||
if errDoExec != nil {
|
||||
panic(errDoExec)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := respExec.Body.Close(); errClose != nil {
|
||||
log.Errorf("close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
bodyExec, errReadExec := io.ReadAll(respExec.Body)
|
||||
if errReadExec != nil {
|
||||
panic(errReadExec)
|
||||
}
|
||||
fmt.Printf("Manager HttpRequest status: %d\n%s\n", respExec.StatusCode, bodyExec)
|
||||
}
|
||||
@@ -33,6 +33,13 @@ var geminiOAuthScopes = []string{
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
const (
|
||||
antigravityOAuthClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityOAuthClientSecret = "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"
|
||||
)
|
||||
|
||||
var antigravityOAuthTokenURL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
type apiCallRequest struct {
|
||||
AuthIndexSnake *string `json:"auth_index"`
|
||||
AuthIndexCamel *string `json:"authIndex"`
|
||||
@@ -251,6 +258,10 @@ func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth)
|
||||
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
|
||||
return token, errToken
|
||||
}
|
||||
if provider == "antigravity" {
|
||||
token, errToken := h.refreshAntigravityOAuthAccessToken(ctx, auth)
|
||||
return token, errToken
|
||||
}
|
||||
|
||||
return tokenValueForAuth(auth), nil
|
||||
}
|
||||
@@ -325,6 +336,161 @@ func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *corea
|
||||
return strings.TrimSpace(currentToken.AccessToken), nil
|
||||
}
|
||||
|
||||
func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
metadata := auth.Metadata
|
||||
if len(metadata) == 0 {
|
||||
return "", fmt.Errorf("antigravity oauth metadata missing")
|
||||
}
|
||||
|
||||
current := strings.TrimSpace(tokenValueFromMetadata(metadata))
|
||||
if current != "" && !antigravityTokenNeedsRefresh(metadata) {
|
||||
return current, nil
|
||||
}
|
||||
|
||||
refreshToken := stringValue(metadata, "refresh_token")
|
||||
if refreshToken == "" {
|
||||
return "", fmt.Errorf("antigravity refresh token missing")
|
||||
}
|
||||
|
||||
tokenURL := strings.TrimSpace(antigravityOAuthTokenURL)
|
||||
if tokenURL == "" {
|
||||
tokenURL = "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
form := url.Values{}
|
||||
form.Set("client_id", antigravityOAuthClientID)
|
||||
form.Set("client_secret", antigravityOAuthClientSecret)
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("refresh_token", refreshToken)
|
||||
|
||||
req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||
if errReq != nil {
|
||||
return "", errReq
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
Transport: h.apiCallTransport(auth),
|
||||
}
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return "", errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errRead != nil {
|
||||
return "", errRead
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return "", fmt.Errorf("antigravity oauth token refresh failed: status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
if strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||
return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token")
|
||||
}
|
||||
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
now := time.Now()
|
||||
auth.Metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
|
||||
if strings.TrimSpace(tokenResp.RefreshToken) != "" {
|
||||
auth.Metadata["refresh_token"] = strings.TrimSpace(tokenResp.RefreshToken)
|
||||
}
|
||||
if tokenResp.ExpiresIn > 0 {
|
||||
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
|
||||
auth.Metadata["timestamp"] = now.UnixMilli()
|
||||
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
}
|
||||
auth.Metadata["type"] = "antigravity"
|
||||
|
||||
if h != nil && h.authManager != nil {
|
||||
auth.LastRefreshedAt = now
|
||||
auth.UpdatedAt = now
|
||||
_, _ = h.authManager.Update(ctx, auth)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(tokenResp.AccessToken), nil
|
||||
}
|
||||
|
||||
func antigravityTokenNeedsRefresh(metadata map[string]any) bool {
|
||||
// Refresh a bit early to avoid requests racing token expiry.
|
||||
const skew = 30 * time.Second
|
||||
|
||||
if metadata == nil {
|
||||
return true
|
||||
}
|
||||
if expStr, ok := metadata["expired"].(string); ok {
|
||||
if ts, errParse := time.Parse(time.RFC3339, strings.TrimSpace(expStr)); errParse == nil {
|
||||
return !ts.After(time.Now().Add(skew))
|
||||
}
|
||||
}
|
||||
expiresIn := int64Value(metadata["expires_in"])
|
||||
timestampMs := int64Value(metadata["timestamp"])
|
||||
if expiresIn > 0 && timestampMs > 0 {
|
||||
exp := time.UnixMilli(timestampMs).Add(time.Duration(expiresIn) * time.Second)
|
||||
return !exp.After(time.Now().Add(skew))
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func int64Value(raw any) int64 {
|
||||
switch typed := raw.(type) {
|
||||
case int:
|
||||
return int64(typed)
|
||||
case int32:
|
||||
return int64(typed)
|
||||
case int64:
|
||||
return typed
|
||||
case uint:
|
||||
return int64(typed)
|
||||
case uint32:
|
||||
return int64(typed)
|
||||
case uint64:
|
||||
if typed > uint64(^uint64(0)>>1) {
|
||||
return 0
|
||||
}
|
||||
return int64(typed)
|
||||
case float32:
|
||||
return int64(typed)
|
||||
case float64:
|
||||
return int64(typed)
|
||||
case json.Number:
|
||||
if i, errParse := typed.Int64(); errParse == nil {
|
||||
return i
|
||||
}
|
||||
case string:
|
||||
if s := strings.TrimSpace(typed); s != "" {
|
||||
if i, errParse := json.Number(s).Int64(); errParse == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
|
||||
if auth == nil {
|
||||
return nil, nil
|
||||
|
||||
173
internal/api/handlers/management/api_tools_test.go
Normal file
173
internal/api/handlers/management/api_tools_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
type memoryAuthStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]*coreauth.Auth
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||
for _, a := range s.items {
|
||||
out = append(out, a.Clone())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
_ = ctx
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.items == nil {
|
||||
s.items = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
s.items[auth.ID] = auth.Clone()
|
||||
s.mu.Unlock()
|
||||
return auth.ID, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
delete(s.items, id)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
||||
t.Fatalf("unexpected content-type: %s", ct)
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
values, err := url.ParseQuery(string(bodyBytes))
|
||||
if err != nil {
|
||||
t.Fatalf("parse form: %v", err)
|
||||
}
|
||||
if values.Get("grant_type") != "refresh_token" {
|
||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
||||
}
|
||||
if values.Get("refresh_token") != "rt" {
|
||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
||||
}
|
||||
if values.Get("client_id") != antigravityOAuthClientID {
|
||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
||||
}
|
||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
||||
t.Fatalf("unexpected client_secret")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-token",
|
||||
"refresh_token": "rt2",
|
||||
"expires_in": int64(3600),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-test.json",
|
||||
FileName: "antigravity-test.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "rt",
|
||||
"expires_in": int64(3600),
|
||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
h := &Handler{authManager: manager}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token, got %q", token)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID(auth.ID)
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth in manager after update")
|
||||
}
|
||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-valid.json",
|
||||
FileName: "antigravity-valid.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "ok-token",
|
||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
h := &Handler{}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
}
|
||||
if token != "ok-token" {
|
||||
t.Fatalf("expected existing token, got %q", token)
|
||||
}
|
||||
if callCount != 0 {
|
||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
@@ -460,6 +460,12 @@ func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
|
||||
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
|
||||
result["plan_type"] = v
|
||||
}
|
||||
if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveStart; v != nil {
|
||||
result["chatgpt_subscription_active_start"] = v
|
||||
}
|
||||
if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveUntil; v != nil {
|
||||
result["chatgpt_subscription_active_until"] = v
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -202,6 +202,26 @@ func (h *Handler) PutLoggingToFile(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
|
||||
}
|
||||
|
||||
// LogsMaxTotalSizeMB
|
||||
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB})
|
||||
}
|
||||
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
|
||||
var body struct {
|
||||
Value *int `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
if value < 0 {
|
||||
value = 0
|
||||
}
|
||||
h.cfg.LogsMaxTotalSizeMB = value
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// Request log
|
||||
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
||||
func (h *Handler) PutRequestLog(c *gin.Context) {
|
||||
@@ -232,6 +252,52 @@ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
|
||||
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
|
||||
}
|
||||
|
||||
// ForceModelPrefix
|
||||
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix})
|
||||
}
|
||||
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.ForceModelPrefix = v })
|
||||
}
|
||||
|
||||
func normalizeRoutingStrategy(strategy string) (string, bool) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(strategy))
|
||||
switch normalized {
|
||||
case "", "round-robin", "roundrobin", "rr":
|
||||
return "round-robin", true
|
||||
case "fill-first", "fillfirst", "ff":
|
||||
return "fill-first", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// RoutingStrategy
|
||||
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
|
||||
strategy, ok := normalizeRoutingStrategy(h.cfg.Routing.Strategy)
|
||||
if !ok {
|
||||
c.JSON(200, gin.H{"strategy": strings.TrimSpace(h.cfg.Routing.Strategy)})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"strategy": strategy})
|
||||
}
|
||||
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
|
||||
var body struct {
|
||||
Value *string `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
normalized, ok := normalizeRoutingStrategy(*body.Value)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
|
||||
return
|
||||
}
|
||||
h.cfg.Routing.Strategy = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// Proxy URL
|
||||
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
||||
func (h *Handler) PutProxyURL(c *gin.Context) {
|
||||
|
||||
@@ -487,6 +487,137 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
||||
c.JSON(400, gin.H{"error": "missing name or index"})
|
||||
}
|
||||
|
||||
// vertex-api-key: []VertexCompatKey
|
||||
func (h *Handler) GetVertexCompatKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey})
|
||||
}
|
||||
func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var arr []config.VertexCompatKey
|
||||
if err = json.Unmarshal(data, &arr); err != nil {
|
||||
var obj struct {
|
||||
Items []config.VertexCompatKey `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
arr = obj.Items
|
||||
}
|
||||
for i := range arr {
|
||||
normalizeVertexCompatKey(&arr[i])
|
||||
}
|
||||
h.cfg.VertexCompatAPIKey = arr
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
||||
type vertexCompatPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
Models *[]config.VertexCompatModel `json:"models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *vertexCompatPatch `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) {
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
if match != "" {
|
||||
for i := range h.cfg.VertexCompatAPIKey {
|
||||
if h.cfg.VertexCompatAPIKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.VertexCompatAPIKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.APIKey)
|
||||
if trimmed == "" {
|
||||
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.APIKey = trimmed
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||
if trimmed == "" {
|
||||
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.BaseURL = trimmed
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
normalizeVertexCompatKey(&entry)
|
||||
h.cfg.VertexCompatAPIKey[targetIndex] = entry
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
||||
for _, v := range h.cfg.VertexCompatAPIKey {
|
||||
if v.APIKey != val {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
h.cfg.VertexCompatAPIKey = out
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if idxStr := c.Query("index"); idxStr != "" {
|
||||
var idx int
|
||||
_, errScan := fmt.Sscanf(idxStr, "%d", &idx)
|
||||
if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) {
|
||||
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...)
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
||||
}
|
||||
|
||||
// oauth-excluded-models: map[string][]string
|
||||
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
|
||||
@@ -572,6 +703,103 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// oauth-model-mappings: map[string][]ModelNameMapping
|
||||
func (h *Handler) GetOAuthModelMappings(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"oauth-model-mappings": sanitizedOAuthModelMappings(h.cfg.OAuthModelMappings)})
|
||||
}
|
||||
|
||||
func (h *Handler) PutOAuthModelMappings(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var entries map[string][]config.ModelNameMapping
|
||||
if err = json.Unmarshal(data, &entries); err != nil {
|
||||
var wrapper struct {
|
||||
Items map[string][]config.ModelNameMapping `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
entries = wrapper.Items
|
||||
}
|
||||
h.cfg.OAuthModelMappings = sanitizedOAuthModelMappings(entries)
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) PatchOAuthModelMappings(c *gin.Context) {
|
||||
var body struct {
|
||||
Provider *string `json:"provider"`
|
||||
Channel *string `json:"channel"`
|
||||
Mappings []config.ModelNameMapping `json:"mappings"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
channelRaw := ""
|
||||
if body.Channel != nil {
|
||||
channelRaw = *body.Channel
|
||||
} else if body.Provider != nil {
|
||||
channelRaw = *body.Provider
|
||||
}
|
||||
channel := strings.ToLower(strings.TrimSpace(channelRaw))
|
||||
if channel == "" {
|
||||
c.JSON(400, gin.H{"error": "invalid channel"})
|
||||
return
|
||||
}
|
||||
|
||||
normalizedMap := sanitizedOAuthModelMappings(map[string][]config.ModelNameMapping{channel: body.Mappings})
|
||||
normalized := normalizedMap[channel]
|
||||
if len(normalized) == 0 {
|
||||
if h.cfg.OAuthModelMappings == nil {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
if _, ok := h.cfg.OAuthModelMappings[channel]; !ok {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthModelMappings, channel)
|
||||
if len(h.cfg.OAuthModelMappings) == 0 {
|
||||
h.cfg.OAuthModelMappings = nil
|
||||
}
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if h.cfg.OAuthModelMappings == nil {
|
||||
h.cfg.OAuthModelMappings = make(map[string][]config.ModelNameMapping)
|
||||
}
|
||||
h.cfg.OAuthModelMappings[channel] = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteOAuthModelMappings(c *gin.Context) {
|
||||
channel := strings.ToLower(strings.TrimSpace(c.Query("channel")))
|
||||
if channel == "" {
|
||||
channel = strings.ToLower(strings.TrimSpace(c.Query("provider")))
|
||||
}
|
||||
if channel == "" {
|
||||
c.JSON(400, gin.H{"error": "missing channel"})
|
||||
return
|
||||
}
|
||||
if h.cfg.OAuthModelMappings == nil {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
if _, ok := h.cfg.OAuthModelMappings[channel]; !ok {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthModelMappings, channel)
|
||||
if len(h.cfg.OAuthModelMappings) == 0 {
|
||||
h.cfg.OAuthModelMappings = nil
|
||||
}
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// codex-api-key: []CodexKey
|
||||
func (h *Handler) GetCodexKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
|
||||
@@ -789,6 +1017,53 @@ func normalizeCodexKey(entry *config.CodexKey) {
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.Prefix = strings.TrimSpace(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
normalized := make([]config.VertexCompatModel, 0, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
model.Name = strings.TrimSpace(model.Name)
|
||||
model.Alias = strings.TrimSpace(model.Alias)
|
||||
if model.Name == "" || model.Alias == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, model)
|
||||
}
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
func sanitizedOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
copied := make(map[string][]config.ModelNameMapping, len(entries))
|
||||
for channel, mappings := range entries {
|
||||
if len(mappings) == 0 {
|
||||
continue
|
||||
}
|
||||
copied[channel] = append([]config.ModelNameMapping(nil), mappings...)
|
||||
}
|
||||
if len(copied) == 0 {
|
||||
return nil
|
||||
}
|
||||
cfg := config.Config{OAuthModelMappings: copied}
|
||||
cfg.SanitizeOAuthModelMappings()
|
||||
if len(cfg.OAuthModelMappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
return cfg.OAuthModelMappings
|
||||
}
|
||||
|
||||
// GetAmpCode returns the complete ampcode configuration.
|
||||
func (h *Handler) GetAmpCode(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
|
||||
@@ -24,8 +24,15 @@ import (
|
||||
type attemptInfo struct {
|
||||
count int
|
||||
blockedUntil time.Time
|
||||
lastActivity time.Time // track last activity for cleanup
|
||||
}
|
||||
|
||||
// attemptCleanupInterval controls how often stale IP entries are purged
|
||||
const attemptCleanupInterval = 1 * time.Hour
|
||||
|
||||
// attemptMaxIdleTime controls how long an IP can be idle before cleanup
|
||||
const attemptMaxIdleTime = 2 * time.Hour
|
||||
|
||||
// Handler aggregates config reference, persistence path and helpers.
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
@@ -47,7 +54,7 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
|
||||
envSecret = strings.TrimSpace(envSecret)
|
||||
|
||||
return &Handler{
|
||||
h := &Handler{
|
||||
cfg: cfg,
|
||||
configFilePath: configFilePath,
|
||||
failedAttempts: make(map[string]*attemptInfo),
|
||||
@@ -57,6 +64,38 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
allowRemoteOverride: envSecret != "",
|
||||
envSecret: envSecret,
|
||||
}
|
||||
h.startAttemptCleanup()
|
||||
return h
|
||||
}
|
||||
|
||||
// startAttemptCleanup launches a background goroutine that periodically
|
||||
// removes stale IP entries from failedAttempts to prevent memory leaks.
|
||||
func (h *Handler) startAttemptCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(attemptCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
h.purgeStaleAttempts()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeStaleAttempts removes IP entries that have been idle beyond attemptMaxIdleTime
|
||||
// and whose ban (if any) has expired.
|
||||
func (h *Handler) purgeStaleAttempts() {
|
||||
now := time.Now()
|
||||
h.attemptsMu.Lock()
|
||||
defer h.attemptsMu.Unlock()
|
||||
for ip, ai := range h.failedAttempts {
|
||||
// Skip if still banned
|
||||
if !ai.blockedUntil.IsZero() && now.Before(ai.blockedUntil) {
|
||||
continue
|
||||
}
|
||||
// Remove if idle too long
|
||||
if now.Sub(ai.lastActivity) > attemptMaxIdleTime {
|
||||
delete(h.failedAttempts, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
@@ -149,6 +188,7 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
h.failedAttempts[clientIP] = aip
|
||||
}
|
||||
aip.count++
|
||||
aip.lastActivity = time.Now()
|
||||
if aip.count >= maxFailures {
|
||||
aip.blockedUntil = time.Now().Add(banDuration)
|
||||
aip.count = 0
|
||||
|
||||
@@ -69,7 +69,30 @@ func (rw *ResponseRewriter) Flush() {
|
||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
||||
// The Amp client struggles when both thinking and tool_use blocks are present
|
||||
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||
if filtered.Exists() {
|
||||
originalCount := gjson.GetBytes(data, "content.#").Int()
|
||||
filteredCount := filtered.Get("#").Int()
|
||||
|
||||
if originalCount > filteredCount {
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||
} else {
|
||||
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
||||
// Log the result for verification
|
||||
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rw.originalModel == "" {
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -491,6 +492,10 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
|
||||
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
|
||||
|
||||
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
|
||||
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||
|
||||
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
||||
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||
@@ -563,6 +568,14 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
|
||||
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
|
||||
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
||||
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
||||
|
||||
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
|
||||
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
||||
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
||||
|
||||
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
||||
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
||||
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
|
||||
@@ -578,11 +591,21 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
|
||||
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
|
||||
|
||||
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
|
||||
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
|
||||
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
|
||||
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
|
||||
|
||||
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
|
||||
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
|
||||
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
|
||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||
|
||||
mgmt.GET("/oauth-model-mappings", s.mgmt.GetOAuthModelMappings)
|
||||
mgmt.PUT("/oauth-model-mappings", s.mgmt.PutOAuthModelMappings)
|
||||
mgmt.PATCH("/oauth-model-mappings", s.mgmt.PatchOAuthModelMappings)
|
||||
mgmt.DELETE("/oauth-model-mappings", s.mgmt.DeleteOAuthModelMappings)
|
||||
|
||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
@@ -966,8 +989,12 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
log.Warnf("amp module is nil, skipping config update")
|
||||
}
|
||||
|
||||
// Count client sources from configuration and auth directory
|
||||
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
||||
// Count client sources from configuration and auth store.
|
||||
tokenStore := sdkAuth.GetTokenStore()
|
||||
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
|
||||
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||
}
|
||||
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
|
||||
geminiAPIKeyCount := len(cfg.GeminiKey)
|
||||
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
||||
codexAPIKeyCount := len(cfg.CodexKey)
|
||||
@@ -978,10 +1005,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
openAICompatCount += len(entry.APIKeyEntries)
|
||||
}
|
||||
|
||||
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||
total,
|
||||
authFiles,
|
||||
authEntries,
|
||||
geminiAPIKeyCount,
|
||||
claudeAPIKeyCount,
|
||||
codexAPIKeyCount,
|
||||
|
||||
@@ -29,8 +29,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
geminiDefaultCallbackPort = 8085
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -49,8 +50,9 @@ type GeminiAuth struct {
|
||||
|
||||
// WebLoginOptions customizes the interactive OAuth flow.
|
||||
type WebLoginOptions struct {
|
||||
NoBrowser bool
|
||||
Prompt func(string) (string, error)
|
||||
NoBrowser bool
|
||||
CallbackPort int
|
||||
Prompt func(string) (string, error)
|
||||
}
|
||||
|
||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||
@@ -72,6 +74,12 @@ func NewGeminiAuth() *GeminiAuth {
|
||||
// - *http.Client: An HTTP client configured with authentication
|
||||
// - error: An error if the client configuration fails, nil otherwise
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||
callbackPort := geminiDefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
@@ -106,7 +114,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
conf := &oauth2.Config{
|
||||
ClientID: geminiOauthClientID,
|
||||
ClientSecret: geminiOauthClientSecret,
|
||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
||||
RedirectURL: callbackURL, // This will be used by the local server.
|
||||
Scopes: geminiOauthScopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
@@ -218,14 +226,20 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||
// - error: An error if the token acquisition fails, nil otherwise
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||
callbackPort := geminiDefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
codeChan := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// Create a new HTTP server with its own multiplexer.
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{Addr: ":8085", Handler: mux}
|
||||
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
|
||||
config.RedirectURL = callbackURL
|
||||
|
||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.URL.Query().Get("error"); err != "" {
|
||||
@@ -277,13 +291,13 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
// Check if browser is available
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available on this system")
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
} else {
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
|
||||
// Log platform info for debugging
|
||||
@@ -294,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||
}
|
||||
|
||||
|
||||
103
internal/cache/signature_cache.go
vendored
103
internal/cache/signature_cache.go
vendored
@@ -3,7 +3,6 @@ package cache
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -16,21 +15,24 @@ type SignatureEntry struct {
|
||||
|
||||
const (
|
||||
// SignatureCacheTTL is how long signatures are valid
|
||||
SignatureCacheTTL = 1 * time.Hour
|
||||
|
||||
// MaxEntriesPerSession limits memory usage per session
|
||||
MaxEntriesPerSession = 100
|
||||
SignatureCacheTTL = 3 * time.Hour
|
||||
|
||||
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
|
||||
SignatureTextHashLen = 16
|
||||
|
||||
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
||||
MinValidSignatureLen = 50
|
||||
|
||||
// SessionCleanupInterval controls how often stale sessions are purged
|
||||
SessionCleanupInterval = 10 * time.Minute
|
||||
)
|
||||
|
||||
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
|
||||
var signatureCache sync.Map
|
||||
|
||||
// sessionCleanupOnce ensures the background cleanup goroutine starts only once
|
||||
var sessionCleanupOnce sync.Once
|
||||
|
||||
// sessionCache is the inner map type
|
||||
type sessionCache struct {
|
||||
mu sync.RWMutex
|
||||
@@ -45,6 +47,9 @@ func hashText(text string) string {
|
||||
|
||||
// getOrCreateSession gets or creates a session cache
|
||||
func getOrCreateSession(sessionID string) *sessionCache {
|
||||
// Start background cleanup on first access
|
||||
sessionCleanupOnce.Do(startSessionCleanup)
|
||||
|
||||
if val, ok := signatureCache.Load(sessionID); ok {
|
||||
return val.(*sessionCache)
|
||||
}
|
||||
@@ -53,6 +58,40 @@ func getOrCreateSession(sessionID string) *sessionCache {
|
||||
return actual.(*sessionCache)
|
||||
}
|
||||
|
||||
// startSessionCleanup launches a background goroutine that periodically
|
||||
// removes sessions where all entries have expired.
|
||||
func startSessionCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(SessionCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredSessions()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeExpiredSessions removes sessions with no valid (non-expired) entries.
|
||||
func purgeExpiredSessions() {
|
||||
now := time.Now()
|
||||
signatureCache.Range(func(key, value any) bool {
|
||||
sc := value.(*sessionCache)
|
||||
sc.mu.Lock()
|
||||
// Remove expired entries
|
||||
for k, entry := range sc.entries {
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, k)
|
||||
}
|
||||
}
|
||||
isEmpty := len(sc.entries) == 0
|
||||
sc.mu.Unlock()
|
||||
// Remove session if empty
|
||||
if isEmpty {
|
||||
signatureCache.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// CacheSignature stores a thinking signature for a given session and text.
|
||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||
func CacheSignature(sessionID, text, signature string) {
|
||||
@@ -69,43 +108,6 @@ func CacheSignature(sessionID, text, signature string) {
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
// Evict expired entries if at capacity
|
||||
if len(sc.entries) >= MaxEntriesPerSession {
|
||||
now := time.Now()
|
||||
for key, entry := range sc.entries {
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, key)
|
||||
}
|
||||
}
|
||||
// If still at capacity, remove oldest entries
|
||||
if len(sc.entries) >= MaxEntriesPerSession {
|
||||
// Find and remove oldest quarter
|
||||
oldest := make([]struct {
|
||||
key string
|
||||
ts time.Time
|
||||
}, 0, len(sc.entries))
|
||||
for key, entry := range sc.entries {
|
||||
oldest = append(oldest, struct {
|
||||
key string
|
||||
ts time.Time
|
||||
}{key, entry.Timestamp})
|
||||
}
|
||||
// Sort by timestamp (oldest first) using sort.Slice
|
||||
sort.Slice(oldest, func(i, j int) bool {
|
||||
return oldest[i].ts.Before(oldest[j].ts)
|
||||
})
|
||||
|
||||
toRemove := len(oldest) / 4
|
||||
if toRemove < 1 {
|
||||
toRemove = 1
|
||||
}
|
||||
|
||||
for i := 0; i < toRemove; i++ {
|
||||
delete(sc.entries, oldest[i].key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sc.entries[textHash] = SignatureEntry{
|
||||
Signature: signature,
|
||||
Timestamp: time.Now(),
|
||||
@@ -127,22 +129,25 @@ func GetCachedSignature(sessionID, text string) string {
|
||||
|
||||
textHash := hashText(text)
|
||||
|
||||
sc.mu.RLock()
|
||||
entry, exists := sc.entries[textHash]
|
||||
sc.mu.RUnlock()
|
||||
now := time.Now()
|
||||
|
||||
sc.mu.Lock()
|
||||
entry, exists := sc.entries[textHash]
|
||||
if !exists {
|
||||
sc.mu.Unlock()
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Since(entry.Timestamp) > SignatureCacheTTL {
|
||||
sc.mu.Lock()
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, textHash)
|
||||
sc.mu.Unlock()
|
||||
return ""
|
||||
}
|
||||
|
||||
// Refresh TTL on access (sliding expiration).
|
||||
entry.Timestamp = now
|
||||
sc.entries[textHash] = entry
|
||||
sc.mu.Unlock()
|
||||
|
||||
return entry.Signature
|
||||
}
|
||||
|
||||
|
||||
@@ -32,9 +32,10 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
|
||||
@@ -22,9 +22,10 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
||||
|
||||
@@ -24,9 +24,10 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||
|
||||
@@ -67,10 +67,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
}
|
||||
|
||||
loginOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
ProjectID: trimmedProjectID,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: callbackPrompt,
|
||||
NoBrowser: options.NoBrowser,
|
||||
ProjectID: trimmedProjectID,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: callbackPrompt,
|
||||
}
|
||||
|
||||
authenticator := sdkAuth.NewGeminiAuthenticator()
|
||||
@@ -88,8 +89,9 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
geminiAuth := gemini.NewGeminiAuth()
|
||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Prompt: callbackPrompt,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Prompt: callbackPrompt,
|
||||
})
|
||||
if errClient != nil {
|
||||
log.Errorf("Gemini authentication failed: %v", errClient)
|
||||
|
||||
@@ -19,6 +19,9 @@ type LoginOptions struct {
|
||||
// NoBrowser indicates whether to skip opening the browser automatically.
|
||||
NoBrowser bool
|
||||
|
||||
// CallbackPort overrides the local OAuth callback port when set (>0).
|
||||
CallbackPort int
|
||||
|
||||
// Prompt allows the caller to provide interactive input when needed.
|
||||
Prompt func(prompt string) (string, error)
|
||||
}
|
||||
@@ -43,9 +46,10 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
|
||||
@@ -36,9 +36,10 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
|
||||
@@ -145,11 +145,14 @@ type RoutingConfig struct {
|
||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||
}
|
||||
|
||||
// ModelNameMapping defines a model ID rename mapping for a specific channel.
|
||||
// It maps the original model name (Name) to the client-visible alias (Alias).
|
||||
// ModelNameMapping defines a model ID mapping for a specific channel.
|
||||
// It maps the upstream model name (Name) to the client-visible alias (Alias).
|
||||
// When Fork is true, the alias is added as an additional model in listings while
|
||||
// keeping the original model ID available.
|
||||
type ModelNameMapping struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
|
||||
}
|
||||
|
||||
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||
@@ -239,6 +242,10 @@ type ClaudeKey struct {
|
||||
// APIKey is the authentication key for accessing Claude API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
@@ -277,6 +284,10 @@ type CodexKey struct {
|
||||
// APIKey is the authentication key for accessing Codex API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
@@ -315,6 +326,10 @@ type GeminiKey struct {
|
||||
// APIKey is the authentication key for accessing Gemini API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
@@ -352,6 +367,10 @@ type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Priority controls selection preference when multiple providers or credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
@@ -518,7 +537,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
|
||||
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// and ensures (From, To) pairs are unique within each channel.
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
func (cfg *Config) SanitizeOAuthModelMappings() {
|
||||
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
|
||||
return
|
||||
@@ -529,7 +548,6 @@ func (cfg *Config) SanitizeOAuthModelMappings() {
|
||||
if channel == "" || len(mappings) == 0 {
|
||||
continue
|
||||
}
|
||||
seenName := make(map[string]struct{}, len(mappings))
|
||||
seenAlias := make(map[string]struct{}, len(mappings))
|
||||
clean := make([]ModelNameMapping, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
@@ -541,17 +559,12 @@ func (cfg *Config) SanitizeOAuthModelMappings() {
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
nameKey := strings.ToLower(name)
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, ok := seenName[nameKey]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := seenAlias[aliasKey]; ok {
|
||||
continue
|
||||
}
|
||||
seenName[nameKey] = struct{}{}
|
||||
seenAlias[aliasKey] = struct{}{}
|
||||
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
|
||||
clean = append(clean, ModelNameMapping{Name: name, Alias: alias, Fork: mapping.Fork})
|
||||
}
|
||||
if len(clean) > 0 {
|
||||
out[channel] = clean
|
||||
|
||||
56
internal/config/oauth_model_mappings_test.go
Normal file
56
internal/config/oauth_model_mappings_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeOAuthModelMappings_PreservesForkFlag(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelMappings: map[string][]ModelNameMapping{
|
||||
" CoDeX ": {
|
||||
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
|
||||
{Name: "gpt-6", Alias: "g6"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelMappings()
|
||||
|
||||
mappings := cfg.OAuthModelMappings["codex"]
|
||||
if len(mappings) != 2 {
|
||||
t.Fatalf("expected 2 sanitized mappings, got %d", len(mappings))
|
||||
}
|
||||
if mappings[0].Name != "gpt-5" || mappings[0].Alias != "g5" || !mappings[0].Fork {
|
||||
t.Fatalf("expected first mapping to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", mappings[0].Name, mappings[0].Alias, mappings[0].Fork)
|
||||
}
|
||||
if mappings[1].Name != "gpt-6" || mappings[1].Alias != "g6" || mappings[1].Fork {
|
||||
t.Fatalf("expected second mapping to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", mappings[1].Name, mappings[1].Alias, mappings[1].Fork)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelMappings_AllowsMultipleAliasesForSameName(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelMappings: map[string][]ModelNameMapping{
|
||||
"antigravity": {
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelMappings()
|
||||
|
||||
mappings := cfg.OAuthModelMappings["antigravity"]
|
||||
expected := []ModelNameMapping{
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
|
||||
}
|
||||
if len(mappings) != len(expected) {
|
||||
t.Fatalf("expected %d sanitized mappings, got %d", len(expected), len(mappings))
|
||||
}
|
||||
for i, exp := range expected {
|
||||
if mappings[i].Name != exp.Name || mappings[i].Alias != exp.Alias || mappings[i].Fork != exp.Fork {
|
||||
t.Fatalf("expected mapping %d to be name=%q alias=%q fork=%v, got name=%q alias=%q fork=%v", i, exp.Name, exp.Alias, exp.Fork, mappings[i].Name, mappings[i].Alias, mappings[i].Fork)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,10 @@ type SDKConfig struct {
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
|
||||
// NonStreamKeepAliveInterval controls how often blank lines are emitted for non-streaming responses.
|
||||
// <= 0 disables keep-alives. Value is in seconds.
|
||||
NonStreamKeepAliveInterval int `yaml:"nonstream-keepalive-interval,omitempty" json:"nonstream-keepalive-interval,omitempty"`
|
||||
}
|
||||
|
||||
// StreamingConfig holds server streaming behavior configuration.
|
||||
|
||||
@@ -13,6 +13,10 @@ type VertexCompatKey struct {
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
|
||||
@@ -24,10 +24,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -198,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
return
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
@@ -210,15 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if err := os.MkdirAll(staticDir, 0o755); err != nil {
|
||||
log.WithError(err).Warn("failed to prepare static directory for management asset")
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
@@ -229,6 +239,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
@@ -240,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
@@ -256,6 +280,22 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
}
|
||||
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to download fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
|
||||
return true
|
||||
}
|
||||
|
||||
func resolveReleaseURL(repo string) string {
|
||||
repo = strings.TrimSpace(repo)
|
||||
if repo == "" {
|
||||
|
||||
@@ -7,12 +7,77 @@ import (
|
||||
"embed"
|
||||
_ "embed"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
//go:embed codex_instructions
|
||||
var codexInstructionsDir embed.FS
|
||||
|
||||
func CodexInstructionsForModel(modelName, systemInstructions string) (bool, string) {
|
||||
//go:embed opencode_codex_instructions.txt
|
||||
var opencodeCodexInstructions string
|
||||
|
||||
const (
|
||||
codexUserAgentKey = "__cpa_user_agent"
|
||||
userAgentOpenAISDK = "ai-sdk/openai/"
|
||||
)
|
||||
|
||||
func InjectCodexUserAgent(raw []byte, userAgent string) []byte {
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
trimmed := strings.TrimSpace(userAgent)
|
||||
if trimmed == "" {
|
||||
return raw
|
||||
}
|
||||
updated, err := sjson.SetBytes(raw, codexUserAgentKey, trimmed)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func ExtractCodexUserAgent(raw []byte) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.GetBytes(raw, codexUserAgentKey).String())
|
||||
}
|
||||
|
||||
func StripCodexUserAgent(raw []byte) []byte {
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
if !gjson.GetBytes(raw, codexUserAgentKey).Exists() {
|
||||
return raw
|
||||
}
|
||||
updated, err := sjson.DeleteBytes(raw, codexUserAgentKey)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func codexInstructionsForOpenCode(systemInstructions string) (bool, string) {
|
||||
if opencodeCodexInstructions == "" {
|
||||
return false, ""
|
||||
}
|
||||
if strings.HasPrefix(systemInstructions, opencodeCodexInstructions) {
|
||||
return true, ""
|
||||
}
|
||||
return false, opencodeCodexInstructions
|
||||
}
|
||||
|
||||
func useOpenCodeInstructions(userAgent string) bool {
|
||||
return strings.Contains(strings.ToLower(userAgent), userAgentOpenAISDK)
|
||||
}
|
||||
|
||||
func IsOpenCodeUserAgent(userAgent string) bool {
|
||||
return useOpenCodeInstructions(userAgent)
|
||||
}
|
||||
|
||||
func codexInstructionsForCodex(modelName, systemInstructions string) (bool, string) {
|
||||
entries, _ := codexInstructionsDir.ReadDir("codex_instructions")
|
||||
|
||||
lastPrompt := ""
|
||||
@@ -57,3 +122,10 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
return false, lastPrompt
|
||||
}
|
||||
}
|
||||
|
||||
func CodexInstructionsForModel(modelName, systemInstructions, userAgent string) (bool, string) {
|
||||
if IsOpenCodeUserAgent(userAgent) {
|
||||
return codexInstructionsForOpenCode(systemInstructions)
|
||||
}
|
||||
return codexInstructionsForCodex(modelName, systemInstructions)
|
||||
}
|
||||
|
||||
318
internal/misc/opencode_codex_instructions.txt
Normal file
318
internal/misc/opencode_codex_instructions.txt
Normal file
@@ -0,0 +1,318 @@
|
||||
You are a coding agent running in the opencode, a terminal-based coding assistant. opencode is an open source project. You are expected to be precise, safe, and helpful.
|
||||
|
||||
Your capabilities:
|
||||
|
||||
- Receive user prompts and other context provided by the harness, such as files in the workspace.
|
||||
- Communicate with the user by streaming thinking & responses, and by making & updating plans.
|
||||
- Emit function calls to run terminal commands and apply edits. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section.
|
||||
|
||||
Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI).
|
||||
|
||||
# How you work
|
||||
|
||||
## Personality
|
||||
|
||||
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
|
||||
|
||||
# AGENTS.md spec
|
||||
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
|
||||
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
|
||||
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
|
||||
- Instructions in AGENTS.md files:
|
||||
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
|
||||
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
|
||||
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
|
||||
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
|
||||
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
|
||||
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
|
||||
|
||||
## Responsiveness
|
||||
|
||||
### Preamble messages
|
||||
|
||||
Before making tool calls, send a brief preamble to the user explaining what you’re about to do. When sending preamble messages, follow these principles and examples:
|
||||
|
||||
- **Logically group related actions**: if you’re about to run several related commands, describe them together in one preamble rather than sending a separate note for each.
|
||||
- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (8–12 words for quick updates).
|
||||
- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with what’s been done so far and create a sense of momentum and clarity for the user to understand your next actions.
|
||||
- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging.
|
||||
- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless it’s part of a larger grouped action.
|
||||
|
||||
**Examples:**
|
||||
|
||||
- “I’ve explored the repo; now checking the API route definitions.”
|
||||
- “Next, I’ll patch the config and update the related tests.”
|
||||
- “I’m about to scaffold the CLI commands and helper functions.”
|
||||
- “Ok cool, so I’ve wrapped my head around the repo. Now digging into the API routes.”
|
||||
- “Config’s looking tidy. Next up is editing helpers to keep things in sync.”
|
||||
- “Finished poking at the DB gateway. I will now chase down error handling.”
|
||||
- “Alright, build pipeline order is interesting. Checking how it reports failures.”
|
||||
- “Spotted a clever caching util; now hunting where it gets used.”
|
||||
|
||||
## Planning
|
||||
|
||||
You have access to an `todowrite` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go.
|
||||
|
||||
Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately.
|
||||
|
||||
Do not repeat the full contents of the plan after an `todowrite` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step.
|
||||
|
||||
Before running a command, consider whether or not you have completed the
|
||||
previous step, and make sure to mark it as completed before moving on to the
|
||||
next step. It may be the case that you complete all steps in your plan after a
|
||||
single pass of implementation. If this is the case, you can simply mark all the
|
||||
planned steps as completed. Sometimes, you may need to change plans in the
|
||||
middle of a task: call `todowrite` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
|
||||
|
||||
Use a plan when:
|
||||
|
||||
- The task is non-trivial and will require multiple actions over a long time horizon.
|
||||
- There are logical phases or dependencies where sequencing matters.
|
||||
- The work has ambiguity that benefits from outlining high-level goals.
|
||||
- You want intermediate checkpoints for feedback and validation.
|
||||
- When the user asked you to do more than one thing in a single prompt
|
||||
- The user has asked you to use the plan tool (aka "TODOs")
|
||||
- You generate additional steps while working, and plan to do them before yielding to the user
|
||||
|
||||
### Examples
|
||||
|
||||
**High-quality plans**
|
||||
|
||||
Example 1:
|
||||
|
||||
1. Add CLI entry with file args
|
||||
2. Parse Markdown via CommonMark library
|
||||
3. Apply semantic HTML template
|
||||
4. Handle code blocks, images, links
|
||||
5. Add error handling for invalid files
|
||||
|
||||
Example 2:
|
||||
|
||||
1. Define CSS variables for colors
|
||||
2. Add toggle with localStorage state
|
||||
3. Refactor components to use variables
|
||||
4. Verify all views for readability
|
||||
5. Add smooth theme-change transition
|
||||
|
||||
Example 3:
|
||||
|
||||
1. Set up Node.js + WebSocket server
|
||||
2. Add join/leave broadcast events
|
||||
3. Implement messaging with timestamps
|
||||
4. Add usernames + mention highlighting
|
||||
5. Persist messages in lightweight DB
|
||||
6. Add typing indicators + unread count
|
||||
|
||||
**Low-quality plans**
|
||||
|
||||
Example 1:
|
||||
|
||||
1. Create CLI tool
|
||||
2. Add Markdown parser
|
||||
3. Convert to HTML
|
||||
|
||||
Example 2:
|
||||
|
||||
1. Add dark mode toggle
|
||||
2. Save preference
|
||||
3. Make styles look good
|
||||
|
||||
Example 3:
|
||||
|
||||
1. Create single-file HTML game
|
||||
2. Run quick sanity check
|
||||
3. Summarize usage instructions
|
||||
|
||||
If you need to write a plan, only write high quality plans, not low quality ones.
|
||||
|
||||
## Task execution
|
||||
|
||||
You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer.
|
||||
|
||||
You MUST adhere to the following criteria when solving queries:
|
||||
|
||||
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
|
||||
- Analyzing code for vulnerabilities is allowed.
|
||||
- Showing user code and tool call details is allowed.
|
||||
- Use the `edit` tool to edit files
|
||||
|
||||
If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines:
|
||||
|
||||
- Fix the problem at the root cause rather than applying surface-level patches, when possible.
|
||||
- Avoid unneeded complexity in your solution.
|
||||
- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
|
||||
- Update documentation as necessary.
|
||||
- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
|
||||
- Use `git log` and `git blame` to search the history of the codebase if additional context is required.
|
||||
- NEVER add copyright or license headers unless specifically requested.
|
||||
- Do not waste tokens by re-reading files after calling `edit` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc.
|
||||
- Do not `git commit` your changes or create new git branches unless explicitly requested.
|
||||
- Do not add inline comments within code unless explicitly requested.
|
||||
- Do not use one-letter variable names unless explicitly requested.
|
||||
- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor.
|
||||
|
||||
## Sandbox and approvals
|
||||
|
||||
The Codex CLI harness supports several different sandboxing, and approval configurations that the user can choose from.
|
||||
|
||||
Filesystem sandboxing prevents you from editing files without user approval. The options are:
|
||||
|
||||
- **read-only**: You can only read files.
|
||||
- **workspace-write**: You can read files. You can write to files in your workspace folder, but not outside it.
|
||||
- **danger-full-access**: No filesystem sandboxing.
|
||||
|
||||
Network sandboxing prevents you from accessing network without approval. Options are
|
||||
|
||||
- **restricted**
|
||||
- **enabled**
|
||||
|
||||
Approvals are your mechanism to get user consent to perform more privileged actions. Although they introduce friction to the user because your work is paused until the user responds, you should leverage them to accomplish your important work. Do not let these settings or the sandbox deter you from attempting to accomplish the user's task. Approval options are
|
||||
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is pared with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with approvals `on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /tmp)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (For all of these, you should weigh alternative paths that do not require approval.)
|
||||
|
||||
Note that when sandboxing is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing ON, and approval on-failure.
|
||||
|
||||
## Validating your work
|
||||
|
||||
If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete.
|
||||
|
||||
When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests.
|
||||
|
||||
Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one.
|
||||
|
||||
For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
|
||||
|
||||
Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance:
|
||||
|
||||
- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task.
|
||||
- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first.
|
||||
- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task.
|
||||
|
||||
## Ambition vs. precision
|
||||
|
||||
For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation.
|
||||
|
||||
If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature.
|
||||
|
||||
You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified.
|
||||
|
||||
## Sharing progress updates
|
||||
|
||||
For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next.
|
||||
|
||||
Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why.
|
||||
|
||||
The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the user’s style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges.
|
||||
|
||||
You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multisection structured responses for results that need grouping or explanation.
|
||||
|
||||
The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `edit`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path.
|
||||
|
||||
If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If there’s something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly.
|
||||
|
||||
Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
**Section Headers**
|
||||
|
||||
- Use only when they improve clarity — they are not mandatory for every answer.
|
||||
- Choose descriptive names that fit the content
|
||||
- Keep headers short (1–3 words) and in `**Title Case**`. Always start headers with `**` and end with `**`
|
||||
- Leave no blank line before the first bullet under a header.
|
||||
- Section headers should only be used where they genuinely improve scannability; avoid fragmenting the answer.
|
||||
|
||||
**Bullets**
|
||||
|
||||
- Use `-` followed by a space for every bullet.
|
||||
- Merge related points when possible; avoid a bullet for every trivial detail.
|
||||
- Keep bullets to one line unless breaking for clarity is unavoidable.
|
||||
- Group into short lists (4–6 bullets) ordered by importance.
|
||||
- Use consistent keyword phrasing and formatting across sections.
|
||||
|
||||
**Monospace**
|
||||
|
||||
- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``).
|
||||
- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.
|
||||
- Never mix monospace and bold markers; choose one based on whether it’s a keyword (`**`) or inline code/path (`` ` ``).
|
||||
|
||||
**File References**
|
||||
When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a standalone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Line/column (1‑based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
|
||||
**Structure**
|
||||
|
||||
- Place related bullets together; don’t mix unrelated concepts in the same section.
|
||||
- Order sections from general → specific → supporting info.
|
||||
- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it.
|
||||
- Match structure to complexity:
|
||||
- Multi-part or detailed results → use clear headers and grouped bullets.
|
||||
- Simple results → minimal headers, possibly just a short list or paragraph.
|
||||
|
||||
**Tone**
|
||||
|
||||
- Keep the voice collaborative and natural, like a coding partner handing off work.
|
||||
- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition
|
||||
- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”).
|
||||
- Keep descriptions self-contained; don’t refer to “above” or “below”.
|
||||
- Use parallel structure in lists for consistency.
|
||||
|
||||
**Don’t**
|
||||
|
||||
- Don’t use literal words “bold” or “monospace” in the content.
|
||||
- Don’t nest bullets or create deep hierarchies.
|
||||
- Don’t output ANSI escape codes directly — the CLI renderer applies them.
|
||||
- Don’t cram unrelated keywords into a single bullet; split for clarity.
|
||||
- Don’t let keyword lists run long — wrap or reformat for scannability.
|
||||
|
||||
Generally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with what’s needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable.
|
||||
|
||||
For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting.
|
||||
|
||||
# Tool Guidelines
|
||||
|
||||
## Shell commands
|
||||
|
||||
When using the shell, you must adhere to the following guidelines:
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
- Read files in chunks with a max chunk size of 250 lines. Do not use python scripts to attempt to output larger chunks of a file. Command line output will be truncated after 10 kilobytes or 256 lines of output, regardless of the command used.
|
||||
|
||||
## `todowrite`
|
||||
|
||||
A tool named `todowrite` is available to you. You can use it to keep an up‑to‑date, step‑by‑step plan for the task.
|
||||
|
||||
To create a new plan, call `todowrite` with a short list of 1‑sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`).
|
||||
|
||||
When steps have been completed, use `todowrite` to mark each finished step as
|
||||
`completed` and the next step you are working on as `in_progress`. There should
|
||||
always be exactly one `in_progress` step until everything is done. You can mark
|
||||
multiple items as complete in a single `todowrite` call.
|
||||
|
||||
If all steps are complete, ensure you call `todowrite` to mark all steps as `completed`.
|
||||
@@ -740,7 +740,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
@@ -773,7 +773,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
return map[string]*AntigravityModelConfig{
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -84,6 +85,13 @@ type ModelRegistration struct {
|
||||
SuspendedClients map[string]string
|
||||
}
|
||||
|
||||
// ModelRegistryHook provides optional callbacks for external integrations to track model list changes.
|
||||
// Hook implementations must be non-blocking and resilient; calls are executed asynchronously and panics are recovered.
|
||||
type ModelRegistryHook interface {
|
||||
OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo)
|
||||
OnModelsUnregistered(ctx context.Context, provider, clientID string)
|
||||
}
|
||||
|
||||
// ModelRegistry manages the global registry of available models
|
||||
type ModelRegistry struct {
|
||||
// models maps model ID to registration information
|
||||
@@ -97,6 +105,8 @@ type ModelRegistry struct {
|
||||
clientProviders map[string]string
|
||||
// mutex ensures thread-safe access to the registry
|
||||
mutex *sync.RWMutex
|
||||
// hook is an optional callback sink for model registration changes
|
||||
hook ModelRegistryHook
|
||||
}
|
||||
|
||||
// Global model registry instance
|
||||
@@ -117,6 +127,53 @@ func GetGlobalRegistry() *ModelRegistry {
|
||||
return globalRegistry
|
||||
}
|
||||
|
||||
// SetHook sets an optional hook for observing model registration changes.
|
||||
func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.hook = hook
|
||||
}
|
||||
|
||||
const defaultModelRegistryHookTimeout = 5 * time.Second
|
||||
|
||||
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
||||
hook := r.hook
|
||||
if hook == nil {
|
||||
return
|
||||
}
|
||||
modelsCopy := cloneModelInfosUnique(models)
|
||||
go func() {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
log.Errorf("model registry hook OnModelsRegistered panic: %v", recovered)
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
|
||||
defer cancel()
|
||||
hook.OnModelsRegistered(ctx, provider, clientID, modelsCopy)
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
|
||||
hook := r.hook
|
||||
if hook == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
log.Errorf("model registry hook OnModelsUnregistered panic: %v", recovered)
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
|
||||
defer cancel()
|
||||
hook.OnModelsUnregistered(ctx, provider, clientID)
|
||||
}()
|
||||
}
|
||||
|
||||
// RegisterClient registers a client and its supported models
|
||||
// Parameters:
|
||||
// - clientID: Unique identifier for the client
|
||||
@@ -177,6 +234,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
} else {
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
r.triggerModelsRegistered(provider, clientID, models)
|
||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||
misc.LogCredentialSeparator()
|
||||
return
|
||||
@@ -310,6 +368,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
|
||||
r.triggerModelsRegistered(provider, clientID, models)
|
||||
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||
return
|
||||
@@ -400,6 +459,25 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
||||
return ©Model
|
||||
}
|
||||
|
||||
func cloneModelInfosUnique(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make([]*ModelInfo, 0, len(models))
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil || model.ID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[model.ID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[model.ID] = struct{}{}
|
||||
cloned = append(cloned, cloneModelInfo(model))
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
// UnregisterClient removes a client and decrements counts for its models
|
||||
// Parameters:
|
||||
// - clientID: Unique identifier for the client to remove
|
||||
@@ -460,6 +538,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
log.Debugf("Unregistered client %s", clientID)
|
||||
// Separator line after completing client unregistration (after the summary line)
|
||||
misc.LogCredentialSeparator()
|
||||
r.triggerModelsUnregistered(provider, clientID)
|
||||
}
|
||||
|
||||
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
|
||||
@@ -625,6 +704,131 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
||||
return models
|
||||
}
|
||||
|
||||
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
||||
// Parameters:
|
||||
// - provider: Provider identifier (e.g., "codex", "gemini", "antigravity")
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: List of available models for the provider
|
||||
func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
type providerModel struct {
|
||||
count int
|
||||
info *ModelInfo
|
||||
}
|
||||
|
||||
providerModels := make(map[string]*providerModel)
|
||||
|
||||
for clientID, clientProvider := range r.clientProviders {
|
||||
if clientProvider != provider {
|
||||
continue
|
||||
}
|
||||
modelIDs := r.clientModels[clientID]
|
||||
if len(modelIDs) == 0 {
|
||||
continue
|
||||
}
|
||||
clientInfos := r.clientModelInfos[clientID]
|
||||
for _, modelID := range modelIDs {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
entry := providerModels[modelID]
|
||||
if entry == nil {
|
||||
entry = &providerModel{}
|
||||
providerModels[modelID] = entry
|
||||
}
|
||||
entry.count++
|
||||
if entry.info == nil {
|
||||
if clientInfos != nil {
|
||||
if info := clientInfos[modelID]; info != nil {
|
||||
entry.info = info
|
||||
}
|
||||
}
|
||||
if entry.info == nil {
|
||||
if reg, ok := r.models[modelID]; ok && reg != nil && reg.Info != nil {
|
||||
entry.info = reg.Info
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(providerModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
quotaExpiredDuration := 5 * time.Minute
|
||||
now := time.Now()
|
||||
result := make([]*ModelInfo, 0, len(providerModels))
|
||||
|
||||
for modelID, entry := range providerModels {
|
||||
if entry == nil || entry.count <= 0 {
|
||||
continue
|
||||
}
|
||||
registration, ok := r.models[modelID]
|
||||
|
||||
expiredClients := 0
|
||||
cooldownSuspended := 0
|
||||
otherSuspended := 0
|
||||
if ok && registration != nil {
|
||||
if registration.QuotaExceededClients != nil {
|
||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||
continue
|
||||
}
|
||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
||||
expiredClients++
|
||||
}
|
||||
}
|
||||
}
|
||||
if registration.SuspendedClients != nil {
|
||||
for clientID, reason := range registration.SuspendedClients {
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(reason, "quota") {
|
||||
cooldownSuspended++
|
||||
continue
|
||||
}
|
||||
otherSuspended++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
availableClients := entry.count
|
||||
effectiveClients := availableClients - expiredClients - otherSuspended
|
||||
if effectiveClients < 0 {
|
||||
effectiveClients = 0
|
||||
}
|
||||
|
||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||
if entry.info != nil {
|
||||
result = append(result, entry.info)
|
||||
continue
|
||||
}
|
||||
if ok && registration != nil && registration.Info != nil {
|
||||
result = append(result, registration.Info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetModelCount returns the number of available clients for a specific model
|
||||
// Parameters:
|
||||
// - modelID: The model ID to check
|
||||
|
||||
204
internal/registry/model_registry_hook_test.go
Normal file
204
internal/registry/model_registry_hook_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestModelRegistry() *ModelRegistry {
|
||||
return &ModelRegistry{
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
type registeredCall struct {
|
||||
provider string
|
||||
clientID string
|
||||
models []*ModelInfo
|
||||
}
|
||||
|
||||
type unregisteredCall struct {
|
||||
provider string
|
||||
clientID string
|
||||
}
|
||||
|
||||
type capturingHook struct {
|
||||
registeredCh chan registeredCall
|
||||
unregisteredCh chan unregisteredCall
|
||||
}
|
||||
|
||||
func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
||||
h.registeredCh <- registeredCall{provider: provider, clientID: clientID, models: models}
|
||||
}
|
||||
|
||||
func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
|
||||
h.unregisteredCh <- unregisteredCall{provider: provider, clientID: clientID}
|
||||
}
|
||||
|
||||
func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
hook := &capturingHook{
|
||||
registeredCh: make(chan registeredCall, 1),
|
||||
unregisteredCh: make(chan unregisteredCall, 1),
|
||||
}
|
||||
r.SetHook(hook)
|
||||
|
||||
inputModels := []*ModelInfo{
|
||||
{ID: "m1", DisplayName: "Model One"},
|
||||
{ID: "m2", DisplayName: "Model Two"},
|
||||
}
|
||||
r.RegisterClient("client-1", "OpenAI", inputModels)
|
||||
|
||||
select {
|
||||
case call := <-hook.registeredCh:
|
||||
if call.provider != "openai" {
|
||||
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
|
||||
}
|
||||
if call.clientID != "client-1" {
|
||||
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
|
||||
}
|
||||
if len(call.models) != 2 {
|
||||
t.Fatalf("models length mismatch: got %d, want %d", len(call.models), 2)
|
||||
}
|
||||
if call.models[0] == nil || call.models[0].ID != "m1" {
|
||||
t.Fatalf("models[0] mismatch: got %#v, want ID=%q", call.models[0], "m1")
|
||||
}
|
||||
if call.models[1] == nil || call.models[1].ID != "m2" {
|
||||
t.Fatalf("models[1] mismatch: got %#v, want ID=%q", call.models[1], "m2")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
hook := &capturingHook{
|
||||
registeredCh: make(chan registeredCall, 1),
|
||||
unregisteredCh: make(chan unregisteredCall, 1),
|
||||
}
|
||||
r.SetHook(hook)
|
||||
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
||||
select {
|
||||
case <-hook.registeredCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
||||
}
|
||||
|
||||
r.UnregisterClient("client-1")
|
||||
|
||||
select {
|
||||
case call := <-hook.unregisteredCh:
|
||||
if call.provider != "openai" {
|
||||
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
|
||||
}
|
||||
if call.clientID != "client-1" {
|
||||
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
|
||||
}
|
||||
}
|
||||
|
||||
type blockingHook struct {
|
||||
started chan struct{}
|
||||
unblock chan struct{}
|
||||
}
|
||||
|
||||
func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
||||
select {
|
||||
case <-h.started:
|
||||
default:
|
||||
close(h.started)
|
||||
}
|
||||
<-h.unblock
|
||||
}
|
||||
|
||||
func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {}
|
||||
|
||||
func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
hook := &blockingHook{
|
||||
started: make(chan struct{}),
|
||||
unblock: make(chan struct{}),
|
||||
}
|
||||
r.SetHook(hook)
|
||||
defer close(hook.unblock)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-hook.started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for hook to start")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("RegisterClient appears to be blocked by hook")
|
||||
}
|
||||
|
||||
if !r.ClientSupportsModel("client-1", "m1") {
|
||||
t.Fatal("model registration failed; expected client to support model")
|
||||
}
|
||||
}
|
||||
|
||||
type panicHook struct {
|
||||
registeredCalled chan struct{}
|
||||
unregisteredCalled chan struct{}
|
||||
}
|
||||
|
||||
func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
||||
if h.registeredCalled != nil {
|
||||
h.registeredCalled <- struct{}{}
|
||||
}
|
||||
panic("boom")
|
||||
}
|
||||
|
||||
func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
|
||||
if h.unregisteredCalled != nil {
|
||||
h.unregisteredCalled <- struct{}{}
|
||||
}
|
||||
panic("boom")
|
||||
}
|
||||
|
||||
func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) {
|
||||
r := newTestModelRegistry()
|
||||
hook := &panicHook{
|
||||
registeredCalled: make(chan struct{}, 1),
|
||||
unregisteredCalled: make(chan struct{}, 1),
|
||||
}
|
||||
r.SetHook(hook)
|
||||
|
||||
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
||||
|
||||
select {
|
||||
case <-hook.registeredCalled:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
||||
}
|
||||
|
||||
if !r.ClientSupportsModel("client-1", "m1") {
|
||||
t.Fatal("model registration failed; expected client to support model")
|
||||
}
|
||||
|
||||
r.UnregisterClient("client-1")
|
||||
|
||||
select {
|
||||
case <-hook.unregisteredCalled:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -50,6 +51,64 @@ func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest forwards an arbitrary HTTP request through the websocket relay.
|
||||
func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("aistudio executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
if e.relay == nil {
|
||||
return nil, fmt.Errorf("aistudio executor: ws relay is nil")
|
||||
}
|
||||
if auth == nil || auth.ID == "" {
|
||||
return nil, fmt.Errorf("aistudio executor: missing auth")
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
|
||||
return nil, fmt.Errorf("aistudio executor: request URL is empty")
|
||||
}
|
||||
|
||||
var body []byte
|
||||
if httpReq.Body != nil {
|
||||
b, errRead := io.ReadAll(httpReq.Body)
|
||||
if errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
body = b
|
||||
httpReq.Body = io.NopCloser(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: httpReq.Method,
|
||||
URL: httpReq.URL.String(),
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
}
|
||||
wsResp, errRelay := e.relay.NonStream(ctx, auth.ID, wsReq)
|
||||
if errRelay != nil {
|
||||
return nil, errRelay
|
||||
}
|
||||
if wsResp == nil {
|
||||
return nil, fmt.Errorf("aistudio executor: ws response is nil")
|
||||
}
|
||||
|
||||
statusText := http.StatusText(wsResp.Status)
|
||||
if statusText == "" {
|
||||
statusText = "Unknown"
|
||||
}
|
||||
resp := &http.Response{
|
||||
StatusCode: wsResp.Status,
|
||||
Status: fmt.Sprintf("%d %s", wsResp.Status, statusText),
|
||||
Header: wsResp.Headers.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(wsResp.Body)),
|
||||
ContentLength: int64(len(wsResp.Body)),
|
||||
Request: httpReq,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request to the AI Studio API.
|
||||
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
@@ -323,6 +382,11 @@ type translatedPayload struct {
|
||||
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
|
||||
@@ -331,7 +395,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
||||
payload = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", payload, originalTranslated)
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -45,6 +47,7 @@ const (
|
||||
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -71,13 +74,42 @@ func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
|
||||
|
||||
// PrepareRequest prepares the HTTP request for execution (no-op for Antigravity).
|
||||
func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||
// PrepareRequest injects Antigravity credentials into the outgoing HTTP request.
|
||||
func (e *AntigravityExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token, _, errToken := e.ensureAccessToken(req.Context(), auth)
|
||||
if errToken != nil {
|
||||
return errToken
|
||||
}
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Antigravity credentials into the request and executes it.
|
||||
func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("antigravity executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
if isClaude {
|
||||
if isClaude || strings.Contains(req.Model, "gemini-3-pro") {
|
||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
@@ -94,13 +126,18 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -119,6 +156,9 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return resp, errDo
|
||||
}
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
@@ -151,7 +191,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -165,7 +211,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
if lastStatus == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
@@ -189,13 +241,18 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, true)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -214,6 +271,9 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return resp, errDo
|
||||
}
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
@@ -232,6 +292,14 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
if errCtx := ctx.Err(); errCtx != nil {
|
||||
err = errCtx
|
||||
return resp, err
|
||||
}
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errRead
|
||||
@@ -250,7 +318,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -315,7 +389,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
if lastStatus == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
@@ -525,13 +605,18 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -550,6 +635,9 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return nil, errDo
|
||||
}
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
@@ -568,6 +656,14 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||
err = errRead
|
||||
return nil, err
|
||||
}
|
||||
if errCtx := ctx.Err(); errCtx != nil {
|
||||
err = errCtx
|
||||
return nil, err
|
||||
}
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errRead
|
||||
@@ -586,7 +682,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -641,7 +743,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
if lastStatus == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
err = sErr
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
@@ -697,8 +805,8 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
||||
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
@@ -744,6 +852,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
@@ -778,12 +889,24 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
return cliproxyexecutor.Response{}, sErr
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
if lastStatus == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
return cliproxyexecutor.Response{}, sErr
|
||||
case lastErr != nil:
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
default:
|
||||
@@ -820,6 +943,9 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
@@ -899,7 +1025,13 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr
|
||||
if accessToken != "" && expiry.After(time.Now().Add(refreshSkew)) {
|
||||
return accessToken, nil, nil
|
||||
}
|
||||
updated, errRefresh := e.refreshToken(ctx, auth.Clone())
|
||||
refreshCtx := context.Background()
|
||||
if ctx != nil {
|
||||
if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil {
|
||||
refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
}
|
||||
updated, errRefresh := e.refreshToken(refreshCtx, auth.Clone())
|
||||
if errRefresh != nil {
|
||||
return "", nil, errRefresh
|
||||
}
|
||||
@@ -946,7 +1078,13 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
|
||||
}
|
||||
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
return auth, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||
sErr.retryAfter = retryAfter
|
||||
}
|
||||
}
|
||||
return auth, sErr
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
@@ -967,12 +1105,49 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
|
||||
auth.Metadata["refresh_token"] = tokenResp.RefreshToken
|
||||
}
|
||||
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
|
||||
auth.Metadata["timestamp"] = time.Now().UnixMilli()
|
||||
auth.Metadata["expired"] = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
now := time.Now()
|
||||
auth.Metadata["timestamp"] = now.UnixMilli()
|
||||
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
auth.Metadata["type"] = antigravityAuthType
|
||||
if errProject := e.ensureAntigravityProjectID(ctx, auth, tokenResp.AccessToken); errProject != nil {
|
||||
log.Warnf("antigravity executor: ensure project id failed: %v", errProject)
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) error {
|
||||
if auth == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if auth.Metadata["project_id"] != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(accessToken)
|
||||
if token == "" {
|
||||
token = metaStringValue(auth.Metadata, "access_token")
|
||||
}
|
||||
if token == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
|
||||
if errFetch != nil {
|
||||
return errFetch
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return nil
|
||||
}
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["project_id"] = strings.TrimSpace(projectID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
|
||||
if token == "" {
|
||||
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
@@ -1026,6 +1201,19 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-preview") {
|
||||
systemInstructionPartsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
|
||||
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(partResult.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
@@ -1160,8 +1348,8 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
||||
return []string{base}
|
||||
}
|
||||
return []string{
|
||||
antigravityBaseURLDaily,
|
||||
antigravitySandboxBaseURLDaily,
|
||||
antigravityBaseURLDaily,
|
||||
antigravityBaseURLProd,
|
||||
}
|
||||
}
|
||||
@@ -1189,6 +1377,7 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
||||
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||
template, _ = sjson.Set(template, "requestType", "agent")
|
||||
|
||||
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||
if projectID != "" {
|
||||
|
||||
@@ -1,10 +1,68 @@
|
||||
package executor
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type codexCache struct {
|
||||
ID string
|
||||
Expire time.Time
|
||||
}
|
||||
|
||||
var codexCacheMap = map[string]codexCache{}
|
||||
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
|
||||
// Protected by codexCacheMu. Entries expire after 1 hour.
|
||||
var (
|
||||
codexCacheMap = make(map[string]codexCache)
|
||||
codexCacheMu sync.RWMutex
|
||||
)
|
||||
|
||||
// codexCacheCleanupInterval controls how often expired entries are purged.
|
||||
const codexCacheCleanupInterval = 15 * time.Minute
|
||||
|
||||
// codexCacheCleanupOnce ensures the background cleanup goroutine starts only once.
|
||||
var codexCacheCleanupOnce sync.Once
|
||||
|
||||
// startCodexCacheCleanup launches a background goroutine that periodically
|
||||
// removes expired entries from codexCacheMap to prevent memory leaks.
|
||||
func startCodexCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(codexCacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredCodexCache()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeExpiredCodexCache removes entries that have expired.
|
||||
func purgeExpiredCodexCache() {
|
||||
now := time.Now()
|
||||
codexCacheMu.Lock()
|
||||
defer codexCacheMu.Unlock()
|
||||
for key, cache := range codexCacheMap {
|
||||
if cache.Expire.Before(now) {
|
||||
delete(codexCacheMap, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired.
|
||||
func getCodexCache(key string) (codexCache, bool) {
|
||||
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||
codexCacheMu.RLock()
|
||||
cache, ok := codexCacheMap[key]
|
||||
codexCacheMu.RUnlock()
|
||||
if !ok || cache.Expire.Before(time.Now()) {
|
||||
return codexCache{}, false
|
||||
}
|
||||
return cache, true
|
||||
}
|
||||
|
||||
// setCodexCache stores a cache entry.
|
||||
func setCodexCache(key string, cache codexCache) {
|
||||
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||
codexCacheMu.Lock()
|
||||
codexCacheMap[key] = cache
|
||||
codexCacheMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -35,11 +35,53 @@ type ClaudeExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
const claudeToolPrefix = "proxy_"
|
||||
|
||||
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
||||
|
||||
func (e *ClaudeExecutor) Identifier() string { return "claude" }
|
||||
|
||||
func (e *ClaudeExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||
// PrepareRequest injects Claude credentials into the outgoing HTTP request.
|
||||
func (e *ClaudeExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
apiKey, _ := claudeCreds(auth)
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
return nil
|
||||
}
|
||||
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||
isAnthropicBase := req.URL != nil && strings.EqualFold(req.URL.Scheme, "https") && strings.EqualFold(req.URL.Host, "api.anthropic.com")
|
||||
if isAnthropicBase && useAPIKey {
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
} else {
|
||||
req.Header.Del("x-api-key")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Claude credentials into the request and executes it.
|
||||
func (e *ClaudeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("claude executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
apiKey, baseURL := claudeCreds(auth)
|
||||
@@ -57,6 +99,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
// Inject thinking config based on model metadata for thinking variants
|
||||
@@ -65,7 +112,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
@@ -76,9 +123,14 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -93,7 +145,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Body: bodyForUpstream,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -147,8 +199,20 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
} else {
|
||||
reporter.publish(ctx, parseClaudeUsage(data))
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||
}
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(
|
||||
ctx,
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
bodyForTranslation,
|
||||
data,
|
||||
¶m,
|
||||
)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -167,12 +231,17 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
// Inject thinking config based on model metadata for thinking variants
|
||||
body = e.injectThinkingConfig(model, req.Metadata, body)
|
||||
body = checkSystemInstructions(body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
@@ -183,9 +252,14 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -200,7 +274,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Body: bodyForUpstream,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -253,6 +327,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
// Forward the line as-is to preserve SSE format
|
||||
cloned := make([]byte, len(line)+1)
|
||||
copy(cloned, line)
|
||||
@@ -277,7 +354,19 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseClaudeStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(
|
||||
ctx,
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
bodyForTranslation,
|
||||
bytes.Clone(line),
|
||||
¶m,
|
||||
)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -316,6 +405,9 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -760,3 +852,107 @@ func checkSystemInstructions(payload []byte) []byte {
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func isClaudeOAuthToken(apiKey string) bool {
|
||||
return strings.Contains(apiKey, "sk-ant-oat")
|
||||
}
|
||||
|
||||
func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
if prefix == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||
name := tool.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("tools.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
|
||||
name := gjson.GetBytes(body, "tool_choice.name").String()
|
||||
if name != "" && !strings.HasPrefix(name, prefix) {
|
||||
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
|
||||
}
|
||||
}
|
||||
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
messages.ForEach(func(msgIndex, msg gjson.Result) bool {
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
return true
|
||||
}
|
||||
content.ForEach(func(contentIndex, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, prefix+name)
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
|
||||
if prefix == "" {
|
||||
return body
|
||||
}
|
||||
content := gjson.GetBytes(body, "content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
return body
|
||||
}
|
||||
content.ForEach(func(index, part gjson.Result) bool {
|
||||
if part.Get("type").String() != "tool_use" {
|
||||
return true
|
||||
}
|
||||
name := part.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
}
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
|
||||
return true
|
||||
})
|
||||
return body
|
||||
}
|
||||
|
||||
func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
||||
if prefix == "" {
|
||||
return line
|
||||
}
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return line
|
||||
}
|
||||
contentBlock := gjson.GetBytes(payload, "content_block")
|
||||
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
|
||||
return line
|
||||
}
|
||||
name := contentBlock.Get("name").String()
|
||||
if !strings.HasPrefix(name, prefix) {
|
||||
return line
|
||||
}
|
||||
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
return append([]byte("data: "), updated...)
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
51
internal/runtime/executor/claude_executor_test.go
Normal file
51
internal/runtime/executor/claude_executor_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_alpha" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_alpha")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_bravo" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_bravo")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "proxy_charlie" {
|
||||
t.Fatalf("tool_choice.name = %q, want %q", got, "proxy_charlie")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_delta" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_delta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "content.0.name").String(); got != "alpha" {
|
||||
t.Fatalf("content.0.name = %q, want %q", got, "alpha")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "content.1.name").String(); got != "bravo" {
|
||||
t.Fatalf("content.1.name = %q, want %q", got, "bravo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
|
||||
payload := bytes.TrimSpace(out)
|
||||
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "content_block.name").String(); got != "alpha" {
|
||||
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,38 @@ func NewCodexExecutor(cfg *config.Config) *CodexExecutor { return &CodexExecutor
|
||||
|
||||
func (e *CodexExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *CodexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||
// PrepareRequest injects Codex credentials into the outgoing HTTP request.
|
||||
func (e *CodexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
apiKey, _ := codexCreds(auth)
|
||||
if strings.TrimSpace(apiKey) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Codex credentials into the request and executes it.
|
||||
func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("codex executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
apiKey, baseURL := codexCreds(auth)
|
||||
@@ -56,16 +87,26 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
userAgent := codexUserAgent(ctx)
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload = misc.InjectCodexUserAgent(originalPayload, userAgent)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
|
||||
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
|
||||
body = sdktranslator.TranslateRequest(from, to, model, body, false)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -132,7 +173,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, line, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -156,15 +197,25 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
userAgent := codexUserAgent(ctx)
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload = misc.InjectCodexUserAgent(originalPayload, userAgent)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
|
||||
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
|
||||
body = sdktranslator.TranslateRequest(from, to, model, body, true)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
@@ -237,7 +288,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -259,11 +310,15 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
userAgent := codexUserAgent(ctx)
|
||||
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
|
||||
body = sdktranslator.TranslateRequest(from, to, model, body, false)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
enc, err := tokenizerForCodexModel(model)
|
||||
@@ -447,14 +502,14 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
if from == "claude" {
|
||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||
if userIDResult.Exists() {
|
||||
var hasKey bool
|
||||
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
||||
if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) {
|
||||
var ok bool
|
||||
if cache, ok = getCodexCache(key); !ok {
|
||||
cache = codexCache{
|
||||
ID: uuid.New().String(),
|
||||
Expire: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
codexCacheMap[key] = cache
|
||||
setCodexCache(key, cache)
|
||||
}
|
||||
}
|
||||
} else if from == "openai-response" {
|
||||
@@ -512,6 +567,16 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string) {
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
}
|
||||
|
||||
func codexUserAgent(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
return strings.TrimSpace(ginCtx.Request.UserAgent())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
@@ -63,8 +63,42 @@ func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor {
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini CLI).
|
||||
func (e *GeminiCLIExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||
// PrepareRequest injects Gemini CLI credentials into the outgoing HTTP request.
|
||||
func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
tokenSource, _, errSource := prepareGeminiCLITokenSource(req.Context(), e.cfg, auth)
|
||||
if errSource != nil {
|
||||
return errSource
|
||||
}
|
||||
tok, errTok := tokenSource.Token()
|
||||
if errTok != nil {
|
||||
return errTok
|
||||
}
|
||||
if strings.TrimSpace(tok.AccessToken) == "" {
|
||||
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||
applyGeminiCLIHeaders(req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Gemini CLI credentials into the request and executes it.
|
||||
func (e *GeminiCLIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("gemini-cli executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request to the Gemini CLI API.
|
||||
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
@@ -77,14 +111,19 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -216,14 +255,19 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
|
||||
|
||||
projectID := resolveGeminiProjectID(auth)
|
||||
|
||||
@@ -421,7 +465,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
// Gemini CLI endpoint when iterating fallback variants.
|
||||
for _, attemptModel := range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
|
||||
@@ -55,8 +55,38 @@ func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor {
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *GeminiExecutor) Identifier() string { return "gemini" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini).
|
||||
func (e *GeminiExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||
// PrepareRequest injects Gemini credentials into the outgoing HTTP request.
|
||||
func (e *GeminiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
apiKey, bearer := geminiCreds(auth)
|
||||
if apiKey != "" {
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
req.Header.Del("Authorization")
|
||||
} else if bearer != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+bearer)
|
||||
req.Header.Del("x-goog-api-key")
|
||||
}
|
||||
applyGeminiHeaders(req, auth)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Gemini credentials into the request and executes it.
|
||||
func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("gemini executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request to the Gemini API.
|
||||
// It translates the request to Gemini format, sends it to the API, and translates
|
||||
@@ -85,13 +115,18 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
action := "generateContent"
|
||||
@@ -183,13 +218,18 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
|
||||
@@ -50,11 +50,49 @@ func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request for execution (no-op for Vertex).
|
||||
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||
// PrepareRequest injects Vertex credentials into the outgoing HTTP request.
|
||||
func (e *GeminiVertexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
apiKey, _ := vertexAPICreds(auth)
|
||||
if strings.TrimSpace(apiKey) != "" {
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
req.Header.Del("Authorization")
|
||||
return nil
|
||||
}
|
||||
_, _, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return errCreds
|
||||
}
|
||||
token, errToken := vertexAccessToken(req.Context(), e.cfg, auth, saJSON)
|
||||
if errToken != nil {
|
||||
return errToken
|
||||
}
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Del("x-goog-api-key")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Vertex credentials into the request and executes it.
|
||||
func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("vertex executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request to the Vertex AI API.
|
||||
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
// Try API key authentication first
|
||||
@@ -122,6 +160,11 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
@@ -134,7 +177,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
|
||||
action := "generateContent"
|
||||
@@ -225,6 +268,11 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
@@ -237,7 +285,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
action := "generateContent"
|
||||
@@ -324,6 +372,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
@@ -336,7 +389,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
@@ -444,6 +497,11 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
@@ -456,7 +514,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
|
||||
@@ -37,8 +37,33 @@ func NewIFlowExecutor(cfg *config.Config) *IFlowExecutor { return &IFlowExecutor
|
||||
// Identifier returns the provider key.
|
||||
func (e *IFlowExecutor) Identifier() string { return "iflow" }
|
||||
|
||||
// PrepareRequest implements ProviderExecutor but requires no preprocessing.
|
||||
func (e *IFlowExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||
// PrepareRequest injects iFlow credentials into the outgoing HTTP request.
|
||||
func (e *IFlowExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
apiKey, _ := iflowCreds(auth)
|
||||
if strings.TrimSpace(apiKey) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects iFlow credentials into the request and executes it.
|
||||
func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("iflow executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming chat completion request.
|
||||
func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
@@ -56,6 +81,11 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
@@ -65,7 +95,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
|
||||
@@ -145,6 +175,11 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
@@ -160,7 +195,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
body = ensureToolsArray(body)
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
|
||||
@@ -441,21 +476,18 @@ func ensureToolsArray(body []byte) []byte {
|
||||
return updated
|
||||
}
|
||||
|
||||
// preserveReasoningContentInMessages ensures reasoning_content from assistant messages in the
|
||||
// conversation history is preserved when sending to iFlow models that support thinking.
|
||||
// This is critical for multi-turn conversations where the model needs to see its previous
|
||||
// reasoning to maintain coherent thought chains across tool calls and conversation turns.
|
||||
// preserveReasoningContentInMessages checks if reasoning_content from assistant messages
|
||||
// is preserved in conversation history for iFlow models that support thinking.
|
||||
// This is helpful for multi-turn conversations where the model may benefit from seeing
|
||||
// its previous reasoning to maintain coherent thought chains.
|
||||
//
|
||||
// For GLM-4.7 and MiniMax-M2.1, the full assistant response (including reasoning) must be
|
||||
// appended back into message history before the next call.
|
||||
// For GLM-4.6/4.7 and MiniMax M2/M2.1, it is recommended to include the full assistant
|
||||
// response (including reasoning_content) in message history for better context continuity.
|
||||
func preserveReasoningContentInMessages(body []byte) []byte {
|
||||
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||
|
||||
// Only apply to models that support thinking with history preservation
|
||||
needsPreservation := strings.HasPrefix(model, "glm-4.7") ||
|
||||
strings.HasPrefix(model, "glm-4-7") ||
|
||||
strings.HasPrefix(model, "minimax-m2.1") ||
|
||||
strings.HasPrefix(model, "minimax-m2-1")
|
||||
needsPreservation := strings.HasPrefix(model, "glm-4") || strings.HasPrefix(model, "minimax-m2")
|
||||
|
||||
if !needsPreservation {
|
||||
return body
|
||||
@@ -493,45 +525,35 @@ func preserveReasoningContentInMessages(body []byte) []byte {
|
||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
||||
//
|
||||
// Model-specific handling:
|
||||
// - GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
|
||||
// - MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
|
||||
// - Other iFlow models: Uses chat_template_kwargs.enable_thinking (boolean)
|
||||
// - GLM-4.6/4.7: Uses chat_template_kwargs.enable_thinking (boolean) and chat_template_kwargs.clear_thinking=false
|
||||
// - MiniMax M2/M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
|
||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
||||
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||
|
||||
// Check if thinking should be enabled
|
||||
val := ""
|
||||
if effort.Exists() {
|
||||
val = strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
if !effort.Exists() {
|
||||
return body
|
||||
}
|
||||
enableThinking := effort.Exists() && val != "none" && val != ""
|
||||
|
||||
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||
val := strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
enableThinking := val != "none" && val != ""
|
||||
|
||||
// Remove reasoning_effort as we'll convert to model-specific format
|
||||
if effort.Exists() {
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
body, _ = sjson.DeleteBytes(body, "thinking")
|
||||
|
||||
// GLM-4.7: Use extra_body with thinking config and clear_thinking: false
|
||||
if strings.HasPrefix(model, "glm-4.7") || strings.HasPrefix(model, "glm-4-7") {
|
||||
if enableThinking {
|
||||
body, _ = sjson.SetBytes(body, "extra_body.thinking.type", "enabled")
|
||||
body, _ = sjson.SetBytes(body, "extra_body.clear_thinking", false)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// MiniMax-M2.1: Use reasoning_split=true for interleaved thinking
|
||||
if strings.HasPrefix(model, "minimax-m2.1") || strings.HasPrefix(model, "minimax-m2-1") {
|
||||
if enableThinking {
|
||||
body, _ = sjson.SetBytes(body, "reasoning_split", true)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// Other iFlow models (including GLM-4.6): Use chat_template_kwargs.enable_thinking
|
||||
if effort.Exists() {
|
||||
// GLM-4.6/4.7: Use chat_template_kwargs
|
||||
if strings.HasPrefix(model, "glm-4") {
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||
if enableThinking {
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.clear_thinking", false)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// MiniMax M2/M2.1: Use reasoning_split
|
||||
if strings.HasPrefix(model, "minimax-m2") {
|
||||
body, _ = sjson.SetBytes(body, "reasoning_split", enableThinking)
|
||||
return body
|
||||
}
|
||||
|
||||
return body
|
||||
|
||||
@@ -304,11 +304,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
|
||||
parts = append(parts, "type=api_key")
|
||||
}
|
||||
case "oauth":
|
||||
if authValue != "" {
|
||||
parts = append(parts, fmt.Sprintf("type=oauth account=%s", authValue))
|
||||
} else {
|
||||
parts = append(parts, "type=oauth")
|
||||
}
|
||||
parts = append(parts, "type=oauth")
|
||||
default:
|
||||
if authType != "" {
|
||||
if authValue != "" {
|
||||
|
||||
@@ -35,11 +35,39 @@ func NewOpenAICompatExecutor(provider string, cfg *config.Config) *OpenAICompatE
|
||||
// Identifier implements cliproxyauth.ProviderExecutor.
|
||||
func (e *OpenAICompatExecutor) Identifier() string { return e.provider }
|
||||
|
||||
// PrepareRequest is a no-op for now (credentials are added via headers at execution time).
|
||||
func (e *OpenAICompatExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||
// PrepareRequest injects OpenAI-compatible credentials into the outgoing HTTP request.
|
||||
func (e *OpenAICompatExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
_, apiKey := e.resolveCredentials(auth)
|
||||
if strings.TrimSpace(apiKey) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects OpenAI-compatible credentials into the request and executes it.
|
||||
func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("openai compat executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
@@ -53,12 +81,17 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
// Translate inbound request to OpenAI format
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||
if modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
||||
@@ -145,12 +178,17 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
}
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||
if modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
||||
@@ -231,6 +269,11 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
|
||||
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
|
||||
// Pass through translator; it yields one or more chunks for the target schema.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m)
|
||||
|
||||
@@ -14,32 +14,54 @@ import (
|
||||
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
|
||||
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
|
||||
// Use the alias from metadata if available, as it's registered in the global registry
|
||||
// with thinking metadata; the upstream model name may not be registered.
|
||||
lookupModel := util.ResolveOriginalModel(model, metadata)
|
||||
|
||||
// Determine which model to use for thinking support check.
|
||||
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
|
||||
thinkingModel := lookupModel
|
||||
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
|
||||
thinkingModel = model
|
||||
}
|
||||
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
|
||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
||||
return payload
|
||||
}
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
if !util.ModelSupportsThinking(thinkingModel) {
|
||||
return payload
|
||||
}
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
|
||||
// applyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// ApplyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
|
||||
func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
|
||||
func ApplyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
|
||||
// Use the alias from metadata if available, as it's registered in the global registry
|
||||
// with thinking metadata; the upstream model name may not be registered.
|
||||
lookupModel := util.ResolveOriginalModel(model, metadata)
|
||||
|
||||
// Determine which model to use for thinking support check.
|
||||
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
|
||||
thinkingModel := lookupModel
|
||||
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
|
||||
thinkingModel = model
|
||||
}
|
||||
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
|
||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
||||
return payload
|
||||
}
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
if !util.ModelSupportsThinking(thinkingModel) {
|
||||
return payload
|
||||
}
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
@@ -82,17 +104,11 @@ func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model
|
||||
return payload
|
||||
}
|
||||
|
||||
// applyPayloadConfig applies payload default and override rules from configuration
|
||||
// to the given JSON payload for the specified model.
|
||||
// Defaults only fill missing fields, while overrides always overwrite existing values.
|
||||
func applyPayloadConfig(cfg *config.Config, model string, payload []byte) []byte {
|
||||
return applyPayloadConfigWithRoot(cfg, model, "", "", payload)
|
||||
}
|
||||
|
||||
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
||||
// and restricts matches to the given protocol when supplied.
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload []byte) []byte {
|
||||
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||
// against the original payload when provided.
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte) []byte {
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
@@ -105,6 +121,11 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
return payload
|
||||
}
|
||||
out := payload
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
source = payload
|
||||
}
|
||||
appliedDefaults := make(map[string]struct{})
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
@@ -116,7 +137,10 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(out, fullPath).Exists() {
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
@@ -124,6 +148,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
|
||||
@@ -36,7 +36,33 @@ func NewQwenExecutor(cfg *config.Config) *QwenExecutor { return &QwenExecutor{cf
|
||||
|
||||
func (e *QwenExecutor) Identifier() string { return "qwen" }
|
||||
|
||||
func (e *QwenExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||
// PrepareRequest injects Qwen credentials into the outgoing HTTP request.
|
||||
func (e *QwenExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token, _ := qwenCreds(auth)
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Qwen credentials into the request and executes it.
|
||||
func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("qwen executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
token, baseURL := qwenCreds(auth)
|
||||
@@ -49,6 +75,11 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
@@ -56,7 +87,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -125,6 +156,11 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
@@ -140,7 +176,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -136,14 +135,14 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if sessionID != "" && thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
|
||||
signature = cachedSig
|
||||
log.Debugf("Using cached signature for thinking block")
|
||||
// log.Debugf("Using cached signature for thinking block")
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to client signature only if cache miss and client signature is valid
|
||||
if signature == "" && cache.HasValidSignature(clientSignature) {
|
||||
signature = clientSignature
|
||||
log.Debugf("Using client-provided signature for thinking block")
|
||||
// log.Debugf("Using client-provided signature for thinking block")
|
||||
}
|
||||
|
||||
// Store for subsequent tool_use in the same message
|
||||
@@ -158,8 +157,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||
// Converting to text would break this requirement
|
||||
if isUnsigned {
|
||||
// TypeScript plugin approach: drop unsigned thinking blocks entirely
|
||||
log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -183,7 +181,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||
// Antigravity API validates signatures, so dummy values are rejected.
|
||||
// The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies.
|
||||
|
||||
functionName := contentResult.Get("name").String()
|
||||
argsResult := contentResult.Get("input")
|
||||
|
||||
@@ -136,11 +136,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Process thinking content (internal reasoning)
|
||||
if partResult.Get("thought").Bool() {
|
||||
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||
log.Debug("Branch: signature_delta")
|
||||
// log.Debug("Branch: signature_delta")
|
||||
|
||||
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
|
||||
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||
log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
||||
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
||||
params.CurrentThinkingText.Reset()
|
||||
}
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
role := m.Get("role").String()
|
||||
content := m.Get("content")
|
||||
|
||||
if role == "system" && len(arr) > 1 {
|
||||
if (role == "system" || role == "developer") && len(arr) > 1 {
|
||||
// system -> request.systemInstruction as a user message style
|
||||
if content.Type == gjson.String {
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
|
||||
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if role == "user" || (role == "system" && len(arr) == 1) {
|
||||
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
|
||||
// Build single user content node to avoid splitting into multiple contents
|
||||
node := []byte(`{"role":"user","parts":[]}`)
|
||||
if content.Type == gjson.String {
|
||||
@@ -223,6 +223,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
}
|
||||
}
|
||||
@@ -266,6 +267,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,16 @@ type claudeToResponsesState struct {
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
|
||||
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
|
||||
if len(originalRequestRawJSON) > 0 && gjson.ValidBytes(originalRequestRawJSON) {
|
||||
return originalRequestRawJSON
|
||||
}
|
||||
if len(requestRawJSON) > 0 && gjson.ValidBytes(requestRawJSON) {
|
||||
return requestRawJSON
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func emitEvent(event string, payload string) string {
|
||||
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
||||
}
|
||||
@@ -241,6 +251,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
st.InFuncBlock = false
|
||||
} else if st.ReasoningActive {
|
||||
@@ -279,8 +290,9 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
|
||||
// Inject original request fields into response as per docs/response.completed.json
|
||||
|
||||
if requestRawJSON != nil {
|
||||
req := gjson.ParseBytes(requestRawJSON)
|
||||
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
|
||||
if len(reqBytes) > 0 {
|
||||
req := gjson.ParseBytes(reqBytes)
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.instructions", v.String())
|
||||
}
|
||||
@@ -549,8 +561,9 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
out, _ = sjson.Set(out, "created_at", createdAt)
|
||||
|
||||
// Inject request echo fields as top-level (similar to streaming variant)
|
||||
if requestRawJSON != nil {
|
||||
req := gjson.ParseBytes(requestRawJSON)
|
||||
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
|
||||
if len(reqBytes) > 0 {
|
||||
req := gjson.ParseBytes(reqBytes)
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "instructions", v.String())
|
||||
}
|
||||
|
||||
@@ -37,10 +37,11 @@ import (
|
||||
// - []byte: The transformed request data in internal client format
|
||||
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
userAgent := misc.ExtractCodexUserAgent(rawJSON)
|
||||
|
||||
template := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
_, instructions := misc.CodexInstructionsForModel(modelName, "")
|
||||
_, instructions := misc.CodexInstructionsForModel(modelName, "", userAgent)
|
||||
template, _ = sjson.Set(template, "instructions", instructions)
|
||||
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -20,6 +20,12 @@ var (
|
||||
dataTag = []byte("data:")
|
||||
)
|
||||
|
||||
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
||||
type ConvertCodexResponseToClaudeParams struct {
|
||||
HasToolCall bool
|
||||
BlockIndex int
|
||||
}
|
||||
|
||||
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
||||
// This function implements a complex state machine that translates Codex API responses
|
||||
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
|
||||
@@ -38,8 +44,10 @@ var (
|
||||
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
|
||||
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
if *param == nil {
|
||||
hasToolCall := false
|
||||
*param = &hasToolCall
|
||||
*param = &ConvertCodexResponseToClaudeParams{
|
||||
HasToolCall: false,
|
||||
BlockIndex: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// log.Debugf("rawJSON: %s", string(rawJSON))
|
||||
@@ -62,46 +70,49 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
} else if typeStr == "response.reasoning_summary_part.added" {
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output = "event: content_block_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
} else if typeStr == "response.reasoning_summary_text.delta" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
|
||||
|
||||
output = "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
} else if typeStr == "response.reasoning_summary_part.done" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
|
||||
} else if typeStr == "response.content_part.added" {
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output = "event: content_block_start\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
} else if typeStr == "response.output_text.delta" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
|
||||
|
||||
output = "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
} else if typeStr == "response.content_part.done" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
} else if typeStr == "response.completed" {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
p := (*param).(*bool)
|
||||
if *p {
|
||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
||||
if p {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||
@@ -118,10 +129,9 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
itemResult := rootResult.Get("item")
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
p := true
|
||||
*param = &p
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
||||
{
|
||||
// Restore original tool name if shortened
|
||||
@@ -137,7 +147,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
@@ -147,14 +157,15 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
itemType := itemResult.Get("type").String()
|
||||
if itemType == "function_call" {
|
||||
template = `{"type":"content_block_stop","index":0}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
|
||||
|
||||
output = "event: content_block_stop\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
}
|
||||
} else if typeStr == "response.function_call_arguments.delta" {
|
||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
|
||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||
|
||||
output += "event: content_block_delta\n"
|
||||
|
||||
@@ -38,11 +38,12 @@ import (
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
userAgent := misc.ExtractCodexUserAgent(rawJSON)
|
||||
// Base template
|
||||
out := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
// Inject standard Codex instructions
|
||||
_, instructions := misc.CodexInstructionsForModel(modelName, "")
|
||||
_, instructions := misc.CodexInstructionsForModel(modelName, "", userAgent)
|
||||
out, _ = sjson.Set(out, "instructions", instructions)
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
// - []byte: The transformed request data in OpenAI Responses API format
|
||||
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
userAgent := misc.ExtractCodexUserAgent(rawJSON)
|
||||
// Start with empty JSON object
|
||||
out := `{}`
|
||||
|
||||
@@ -96,7 +97,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
|
||||
// Extract system instructions from first system message (string or text object)
|
||||
messages := gjson.GetBytes(rawJSON, "messages")
|
||||
_, instructions := misc.CodexInstructionsForModel(modelName, "")
|
||||
_, instructions := misc.CodexInstructionsForModel(modelName, "", userAgent)
|
||||
out, _ = sjson.Set(out, "instructions", instructions)
|
||||
// if messages.IsArray() {
|
||||
// arr := messages.Array()
|
||||
@@ -275,7 +276,15 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
arr := tools.Array()
|
||||
for i := 0; i < len(arr); i++ {
|
||||
t := arr[i]
|
||||
if t.Get("type").String() == "function" {
|
||||
toolType := t.Get("type").String()
|
||||
// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
|
||||
// Only "function" needs structural conversion because Chat Completions nests details under "function".
|
||||
if toolType != "" && toolType != "function" && t.IsObject() {
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
|
||||
continue
|
||||
}
|
||||
|
||||
if toolType == "function" {
|
||||
item := `{}`
|
||||
item, _ = sjson.Set(item, "type", "function")
|
||||
fn := t.Get("function")
|
||||
@@ -304,6 +313,37 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
}
|
||||
}
|
||||
|
||||
// Map tool_choice when present.
|
||||
// Chat Completions: "tool_choice" can be a string ("auto"/"none") or an object (e.g. {"type":"function","function":{"name":"..."}}).
|
||||
// Responses API: keep built-in tool choices as-is; flatten function choice to {"type":"function","name":"..."}.
|
||||
if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
|
||||
switch {
|
||||
case tc.Type == gjson.String:
|
||||
out, _ = sjson.Set(out, "tool_choice", tc.String())
|
||||
case tc.IsObject():
|
||||
tcType := tc.Get("type").String()
|
||||
if tcType == "function" {
|
||||
name := tc.Get("function.name").String()
|
||||
if name != "" {
|
||||
if short, ok := originalToolNameMap[name]; ok {
|
||||
name = short
|
||||
} else {
|
||||
name = shortenNameIfNeeded(name)
|
||||
}
|
||||
}
|
||||
choice := `{}`
|
||||
choice, _ = sjson.Set(choice, "type", "function")
|
||||
if name != "" {
|
||||
choice, _ = sjson.Set(choice, "name", name)
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", choice)
|
||||
} else if tcType != "" {
|
||||
// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
|
||||
func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
userAgent := misc.ExtractCodexUserAgent(rawJSON)
|
||||
rawJSON = misc.StripCodexUserAgent(rawJSON)
|
||||
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "store", false)
|
||||
@@ -32,7 +34,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
originalInstructionsText = originalInstructionsResult.String()
|
||||
}
|
||||
|
||||
hasOfficialInstructions, instructions := misc.CodexInstructionsForModel(modelName, originalInstructionsResult.String())
|
||||
hasOfficialInstructions, instructions := misc.CodexInstructionsForModel(modelName, originalInstructionsResult.String(), userAgent)
|
||||
|
||||
inputResult := gjson.GetBytes(rawJSON, "input")
|
||||
var inputResults []gjson.Result
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -18,7 +19,10 @@ func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string
|
||||
if typeResult := gjson.GetBytes(rawJSON, "type"); typeResult.Exists() {
|
||||
typeStr := typeResult.String()
|
||||
if typeStr == "response.created" || typeStr == "response.in_progress" || typeStr == "response.completed" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", gjson.GetBytes(originalRequestRawJSON, "instructions").String())
|
||||
if gjson.GetBytes(rawJSON, "response.instructions").Exists() {
|
||||
instructions := selectInstructions(originalRequestRawJSON, requestRawJSON)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", instructions)
|
||||
}
|
||||
}
|
||||
}
|
||||
out := fmt.Sprintf("data: %s", string(rawJSON))
|
||||
@@ -37,6 +41,16 @@ func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, modelName
|
||||
}
|
||||
responseResult := rootResult.Get("response")
|
||||
template := responseResult.Raw
|
||||
template, _ = sjson.Set(template, "instructions", gjson.GetBytes(originalRequestRawJSON, "instructions").String())
|
||||
if responseResult.Get("instructions").Exists() {
|
||||
template, _ = sjson.Set(template, "instructions", selectInstructions(originalRequestRawJSON, requestRawJSON))
|
||||
}
|
||||
return template
|
||||
}
|
||||
|
||||
func selectInstructions(originalRequestRawJSON, requestRawJSON []byte) string {
|
||||
userAgent := misc.ExtractCodexUserAgent(originalRequestRawJSON)
|
||||
if misc.IsOpenCodeUserAgent(userAgent) {
|
||||
return gjson.GetBytes(requestRawJSON, "instructions").String()
|
||||
}
|
||||
return gjson.GetBytes(originalRequestRawJSON, "instructions").String()
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
role := m.Get("role").String()
|
||||
content := m.Get("content")
|
||||
|
||||
if role == "system" && len(arr) > 1 {
|
||||
if (role == "system" || role == "developer") && len(arr) > 1 {
|
||||
// system -> request.systemInstruction as a user message style
|
||||
if content.Type == gjson.String {
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
|
||||
@@ -169,7 +169,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if role == "user" || (role == "system" && len(arr) == 1) {
|
||||
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
|
||||
// Build single user content node to avoid splitting into multiple contents
|
||||
node := []byte(`{"role":"user","parts":[]}`)
|
||||
if content.Type == gjson.String {
|
||||
@@ -191,6 +191,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
}
|
||||
}
|
||||
@@ -236,6 +237,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +170,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
role := m.Get("role").String()
|
||||
content := m.Get("content")
|
||||
|
||||
if role == "system" && len(arr) > 1 {
|
||||
if (role == "system" || role == "developer") && len(arr) > 1 {
|
||||
// system -> system_instruction as a user message style
|
||||
if content.Type == gjson.String {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if role == "user" || (role == "system" && len(arr) == 1) {
|
||||
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
|
||||
// Build single user content node to avoid splitting into multiple contents
|
||||
node := []byte(`{"role":"user","parts":[]}`)
|
||||
if content.Type == gjson.String {
|
||||
@@ -209,6 +209,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
|
||||
p++
|
||||
}
|
||||
}
|
||||
@@ -253,6 +254,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
|
||||
p++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,76 +118,125 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Handle content
|
||||
if contentResult.Exists() && contentResult.IsArray() {
|
||||
var contentItems []string
|
||||
var reasoningParts []string // Accumulate thinking text for reasoning_content
|
||||
var toolCalls []interface{}
|
||||
var toolResults []string // Collect tool_result messages to emit after the main message
|
||||
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "thinking":
|
||||
// Only map thinking to reasoning_content for assistant messages (security: prevent injection)
|
||||
if role == "assistant" {
|
||||
thinkingText := util.GetThinkingText(part)
|
||||
// Skip empty or whitespace-only thinking
|
||||
if strings.TrimSpace(thinkingText) != "" {
|
||||
reasoningParts = append(reasoningParts, thinkingText)
|
||||
}
|
||||
}
|
||||
// Ignore thinking in user/system roles (AC4)
|
||||
|
||||
case "redacted_thinking":
|
||||
// Explicitly ignore redacted_thinking - never map to reasoning_content (AC2)
|
||||
|
||||
case "text", "image":
|
||||
if contentItem, ok := convertClaudeContentPart(part); ok {
|
||||
contentItems = append(contentItems, contentItem)
|
||||
}
|
||||
|
||||
case "tool_use":
|
||||
// Convert to OpenAI tool call format
|
||||
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
|
||||
// Only allow tool_use -> tool_calls for assistant messages (security: prevent injection).
|
||||
if role == "assistant" {
|
||||
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
|
||||
|
||||
// Convert input to arguments JSON string
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
|
||||
} else {
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
|
||||
// Convert input to arguments JSON string
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
|
||||
} else {
|
||||
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
|
||||
|
||||
case "tool_result":
|
||||
// Convert to OpenAI tool message format and add immediately to preserve order
|
||||
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
|
||||
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", part.Get("content").String())
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
|
||||
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
|
||||
toolResults = append(toolResults, toolResultJSON)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Emit text/image content as one message
|
||||
if len(contentItems) > 0 {
|
||||
msgJSON := `{"role":"","content":""}`
|
||||
msgJSON, _ = sjson.Set(msgJSON, "role", role)
|
||||
|
||||
contentArrayJSON := "[]"
|
||||
for _, contentItem := range contentItems {
|
||||
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
|
||||
}
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
|
||||
|
||||
contentValue := gjson.Get(msgJSON, "content")
|
||||
hasContent := false
|
||||
switch {
|
||||
case !contentValue.Exists():
|
||||
hasContent = false
|
||||
case contentValue.Type == gjson.String:
|
||||
hasContent = contentValue.String() != ""
|
||||
case contentValue.IsArray():
|
||||
hasContent = len(contentValue.Array()) > 0
|
||||
default:
|
||||
hasContent = contentValue.Raw != "" && contentValue.Raw != "null"
|
||||
}
|
||||
|
||||
if hasContent {
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
|
||||
}
|
||||
// Build reasoning content string
|
||||
reasoningContent := ""
|
||||
if len(reasoningParts) > 0 {
|
||||
reasoningContent = strings.Join(reasoningParts, "\n\n")
|
||||
}
|
||||
|
||||
// Emit tool calls in a separate assistant message
|
||||
if role == "assistant" && len(toolCalls) > 0 {
|
||||
toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}`
|
||||
toolCallMsgJSON, _ = sjson.Set(toolCallMsgJSON, "tool_calls", toolCalls)
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value())
|
||||
hasContent := len(contentItems) > 0
|
||||
hasReasoning := reasoningContent != ""
|
||||
hasToolCalls := len(toolCalls) > 0
|
||||
hasToolResults := len(toolResults) > 0
|
||||
|
||||
// OpenAI requires: tool messages MUST immediately follow the assistant message with tool_calls.
|
||||
// Therefore, we emit tool_result messages FIRST (they respond to the previous assistant's tool_calls),
|
||||
// then emit the current message's content.
|
||||
for _, toolResultJSON := range toolResults {
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
|
||||
}
|
||||
|
||||
// For assistant messages: emit a single unified message with content, tool_calls, and reasoning_content
|
||||
// This avoids splitting into multiple assistant messages which breaks OpenAI tool-call adjacency
|
||||
if role == "assistant" {
|
||||
if hasContent || hasReasoning || hasToolCalls {
|
||||
msgJSON := `{"role":"assistant"}`
|
||||
|
||||
// Add content (as array if we have items, empty string if reasoning-only)
|
||||
if hasContent {
|
||||
contentArrayJSON := "[]"
|
||||
for _, contentItem := range contentItems {
|
||||
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
|
||||
}
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
|
||||
} else {
|
||||
// Ensure content field exists for OpenAI compatibility
|
||||
msgJSON, _ = sjson.Set(msgJSON, "content", "")
|
||||
}
|
||||
|
||||
// Add reasoning_content if present
|
||||
if hasReasoning {
|
||||
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
|
||||
}
|
||||
|
||||
// Add tool_calls if present (in same message as content)
|
||||
if hasToolCalls {
|
||||
msgJSON, _ = sjson.Set(msgJSON, "tool_calls", toolCalls)
|
||||
}
|
||||
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
|
||||
}
|
||||
} else {
|
||||
// For non-assistant roles: emit content message if we have content
|
||||
// If the message only contains tool_results (no text/image), we still processed them above
|
||||
if hasContent {
|
||||
msgJSON := `{"role":""}`
|
||||
msgJSON, _ = sjson.Set(msgJSON, "role", role)
|
||||
|
||||
contentArrayJSON := "[]"
|
||||
for _, contentItem := range contentItems {
|
||||
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
|
||||
}
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
|
||||
|
||||
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
|
||||
} else if hasToolResults && !hasContent {
|
||||
// tool_results already emitted above, no additional user message needed
|
||||
}
|
||||
}
|
||||
|
||||
} else if contentResult.Exists() && contentResult.Type == gjson.String {
|
||||
@@ -307,3 +356,43 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func convertClaudeToolResultContentToString(content gjson.Result) string {
|
||||
if !content.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
|
||||
if content.IsArray() {
|
||||
var parts []string
|
||||
content.ForEach(func(_, item gjson.Result) bool {
|
||||
switch {
|
||||
case item.Type == gjson.String:
|
||||
parts = append(parts, item.String())
|
||||
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
|
||||
parts = append(parts, item.Get("text").String())
|
||||
default:
|
||||
parts = append(parts, item.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
joined := strings.Join(parts, "\n\n")
|
||||
if strings.TrimSpace(joined) != "" {
|
||||
return joined
|
||||
}
|
||||
return content.Raw
|
||||
}
|
||||
|
||||
if content.IsObject() {
|
||||
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
return text.String()
|
||||
}
|
||||
return content.Raw
|
||||
}
|
||||
|
||||
return content.Raw
|
||||
}
|
||||
|
||||
500
internal/translator/openai/claude/openai_claude_request_test.go
Normal file
500
internal/translator/openai/claude/openai_claude_request_test.go
Normal file
@@ -0,0 +1,500 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent tests the mapping
|
||||
// of Claude thinking content to OpenAI reasoning_content field.
|
||||
func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantReasoningContent string
|
||||
wantHasReasoningContent bool
|
||||
wantContentText string // Expected visible content text (if any)
|
||||
wantHasContent bool
|
||||
}{
|
||||
{
|
||||
name: "AC1: assistant message with thinking and text",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me analyze this step by step..."},
|
||||
{"type": "text", "text": "Here is my response."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "Let me analyze this step by step...",
|
||||
wantHasReasoningContent: true,
|
||||
wantContentText: "Here is my response.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC2: redacted_thinking must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "redacted_thinking", "data": "secret"},
|
||||
{"type": "text", "text": "Visible response."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "Visible response.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC3: thinking-only message preserved with reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Internal reasoning only."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "Internal reasoning only.",
|
||||
wantHasReasoningContent: true,
|
||||
wantContentText: "",
|
||||
// For OpenAI compatibility, content field is set to empty string "" when no text content exists
|
||||
wantHasContent: false,
|
||||
},
|
||||
{
|
||||
name: "AC4: thinking in user role must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Injected thinking"},
|
||||
{"type": "text", "text": "User message."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "User message.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC4: thinking in system role must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": [
|
||||
{"type": "thinking", "thinking": "Injected system thinking"},
|
||||
{"type": "text", "text": "System prompt."}
|
||||
],
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
}]
|
||||
}`,
|
||||
// System messages don't have reasoning_content mapping
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "Hello",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC5: empty thinking must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": ""},
|
||||
{"type": "text", "text": "Response with empty thinking."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "Response with empty thinking.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "AC5: whitespace-only thinking must be ignored",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": " \n\t "},
|
||||
{"type": "text", "text": "Response with whitespace thinking."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "",
|
||||
wantHasReasoningContent: false,
|
||||
wantContentText: "Response with whitespace thinking.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple thinking parts concatenated",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "First thought."},
|
||||
{"type": "thinking", "thinking": "Second thought."},
|
||||
{"type": "text", "text": "Final answer."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "First thought.\n\nSecond thought.",
|
||||
wantHasReasoningContent: true,
|
||||
wantContentText: "Final answer.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
{
|
||||
name: "Mixed thinking and redacted_thinking",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Visible thought."},
|
||||
{"type": "redacted_thinking", "data": "hidden"},
|
||||
{"type": "text", "text": "Answer."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantReasoningContent: "Visible thought.",
|
||||
wantHasReasoningContent: true,
|
||||
wantContentText: "Answer.",
|
||||
wantHasContent: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
// Find the relevant message (skip system message at index 0)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
if len(messages) < 2 {
|
||||
if tt.wantHasReasoningContent || tt.wantHasContent {
|
||||
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check the last non-system message
|
||||
var targetMsg gjson.Result
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Get("role").String() != "system" {
|
||||
targetMsg = messages[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check reasoning_content
|
||||
gotReasoningContent := targetMsg.Get("reasoning_content").String()
|
||||
gotHasReasoningContent := targetMsg.Get("reasoning_content").Exists()
|
||||
|
||||
if gotHasReasoningContent != tt.wantHasReasoningContent {
|
||||
t.Errorf("reasoning_content existence = %v, want %v", gotHasReasoningContent, tt.wantHasReasoningContent)
|
||||
}
|
||||
|
||||
if gotReasoningContent != tt.wantReasoningContent {
|
||||
t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
|
||||
}
|
||||
|
||||
// Check content
|
||||
content := targetMsg.Get("content")
|
||||
// content has meaningful content if it's a non-empty array, or a non-empty string
|
||||
var gotHasContent bool
|
||||
switch {
|
||||
case content.IsArray():
|
||||
gotHasContent = len(content.Array()) > 0
|
||||
case content.Type == gjson.String:
|
||||
gotHasContent = content.String() != ""
|
||||
default:
|
||||
gotHasContent = false
|
||||
}
|
||||
|
||||
if gotHasContent != tt.wantHasContent {
|
||||
t.Errorf("content existence = %v, want %v", gotHasContent, tt.wantHasContent)
|
||||
}
|
||||
|
||||
if tt.wantHasContent && tt.wantContentText != "" {
|
||||
// Find text content
|
||||
var foundText string
|
||||
content.ForEach(func(_, v gjson.Result) bool {
|
||||
if v.Get("type").String() == "text" {
|
||||
foundText = v.Get("text").String()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if foundText != tt.wantContentText {
|
||||
t.Errorf("content text = %q, want %q", foundText, tt.wantContentText)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved tests AC3:
|
||||
// that a message with only thinking content is preserved (not dropped).
|
||||
func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "What is 2+2?"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "thinking", "thinking": "Let me calculate: 2+2=4"}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Thanks"}]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
// Check the assistant message (index 2) has reasoning_content
|
||||
assistantMsg := messages[2]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
if !assistantMsg.Get("reasoning_content").Exists() {
|
||||
t.Error("Expected assistant message to have reasoning_content")
|
||||
}
|
||||
|
||||
if assistantMsg.Get("reasoning_content").String() != "Let me calculate: 2+2=4" {
|
||||
t.Errorf("Unexpected reasoning_content: %s", assistantMsg.Get("reasoning_content").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "before"},
|
||||
{"type": "tool_result", "tool_use_id": "call_1", "content": [{"type":"text","text":"tool ok"}]},
|
||||
{"type": "text", "text": "after"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
|
||||
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after)
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[0].Get("role").String() != "system" {
|
||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
||||
}
|
||||
|
||||
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
|
||||
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
|
||||
}
|
||||
|
||||
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
|
||||
if messages[2].Get("role").String() != "tool" {
|
||||
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String())
|
||||
}
|
||||
if got := messages[2].Get("tool_call_id").String(); got != "call_1" {
|
||||
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
|
||||
}
|
||||
if got := messages[2].Get("content").String(); got != "tool ok" {
|
||||
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
|
||||
}
|
||||
|
||||
// User message comes after tool message
|
||||
if messages[3].Get("role").String() != "user" {
|
||||
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String())
|
||||
}
|
||||
// User message should contain both "before" and "after" text
|
||||
if got := messages[3].Get("content.0.text").String(); got != "before" {
|
||||
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
|
||||
}
|
||||
if got := messages[3].Get("content.1.text").String(); got != "after" {
|
||||
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "call_1", "content": {"foo": "bar"}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// system + assistant(tool_calls) + tool(result)
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[2].Get("role").String() != "tool" {
|
||||
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String())
|
||||
}
|
||||
|
||||
toolContent := messages[2].Get("content").String()
|
||||
parsed := gjson.Parse(toolContent)
|
||||
if parsed.Get("foo").String() != "bar" {
|
||||
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "pre"},
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
|
||||
{"type": "text", "text": "post"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// New behavior: content + tool_calls unified in single assistant message
|
||||
// Expect: system + assistant(content[pre,post] + tool_calls)
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[0].Get("role").String() != "system" {
|
||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
||||
}
|
||||
|
||||
assistantMsg := messages[1]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
// Should have both content and tool_calls in same message
|
||||
if !assistantMsg.Get("tool_calls").Exists() {
|
||||
t.Fatalf("Expected assistant message to have tool_calls")
|
||||
}
|
||||
if got := assistantMsg.Get("tool_calls.0.id").String(); got != "call_1" {
|
||||
t.Fatalf("Expected tool_call id %q, got %q", "call_1", got)
|
||||
}
|
||||
if got := assistantMsg.Get("tool_calls.0.function.name").String(); got != "do_work" {
|
||||
t.Fatalf("Expected tool_call name %q, got %q", "do_work", got)
|
||||
}
|
||||
|
||||
// Content should have both pre and post text
|
||||
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
|
||||
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
|
||||
}
|
||||
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
|
||||
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "t1"},
|
||||
{"type": "text", "text": "pre"},
|
||||
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
|
||||
{"type": "thinking", "thinking": "t2"},
|
||||
{"type": "text", "text": "post"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// New behavior: all content, thinking, and tool_calls unified in single assistant message
|
||||
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
assistantMsg := messages[1]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
// Should have content with both pre and post
|
||||
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
|
||||
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
|
||||
}
|
||||
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
|
||||
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
|
||||
}
|
||||
|
||||
// Should have tool_calls
|
||||
if !assistantMsg.Get("tool_calls").Exists() {
|
||||
t.Fatalf("Expected assistant message to have tool_calls")
|
||||
}
|
||||
|
||||
// Should have combined reasoning_content from both thinking blocks
|
||||
if got := assistantMsg.Get("reasoning_content").String(); got != "t1\n\nt2" {
|
||||
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
|
||||
}
|
||||
}
|
||||
@@ -299,17 +299,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
inputTokens = promptTokens.Int()
|
||||
outputTokens = completionTokens.Int()
|
||||
}
|
||||
// Send message_delta with usage
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
||||
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
||||
param.MessageDeltaSent = true
|
||||
|
||||
emitMessageStopIfNeeded(param, &results)
|
||||
}
|
||||
// Send message_delta with usage
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
||||
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
||||
param.MessageDeltaSent = true
|
||||
|
||||
emitMessageStopIfNeeded(param, &results)
|
||||
|
||||
}
|
||||
|
||||
return results
|
||||
@@ -480,15 +479,15 @@ func collectOpenAIReasoningTexts(node gjson.Result) []string {
|
||||
|
||||
switch node.Type {
|
||||
case gjson.String:
|
||||
if text := strings.TrimSpace(node.String()); text != "" {
|
||||
if text := node.String(); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
case gjson.JSON:
|
||||
if text := node.Get("text"); text.Exists() {
|
||||
if trimmed := strings.TrimSpace(text.String()); trimmed != "" {
|
||||
texts = append(texts, trimmed)
|
||||
if textStr := text.String(); textStr != "" {
|
||||
texts = append(texts, textStr)
|
||||
}
|
||||
} else if raw := strings.TrimSpace(node.Raw); raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
|
||||
} else if raw := node.Raw; raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
|
||||
texts = append(texts, raw)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,6 +163,14 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
var chatCompletionsTools []interface{}
|
||||
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
// Built-in tools (e.g. {"type":"web_search"}) are already compatible with the Chat Completions schema.
|
||||
// Only function tools need structural conversion because Chat Completions nests details under "function".
|
||||
toolType := tool.Get("type").String()
|
||||
if toolType != "" && toolType != "function" && tool.IsObject() {
|
||||
chatCompletionsTools = append(chatCompletionsTools, tool.Value())
|
||||
return true
|
||||
}
|
||||
|
||||
chatTool := `{"type":"function","function":{}}`
|
||||
|
||||
// Convert tool structure from responses format to chat completions format
|
||||
|
||||
@@ -390,6 +390,11 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
|
||||
|
||||
// If schema has properties but none are required, add a minimal placeholder.
|
||||
if propsVal.IsObject() && !hasRequiredProperties {
|
||||
// DO NOT add placeholder if it's a top-level schema (parentPath is empty)
|
||||
// or if we've already added a placeholder reason above.
|
||||
if parentPath == "" {
|
||||
continue
|
||||
}
|
||||
placeholderPath := joinPath(propsPath, "_")
|
||||
if !gjson.Get(jsonStr, placeholderPath).Exists() {
|
||||
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
|
||||
|
||||
@@ -127,8 +127,10 @@ func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing
|
||||
"type": "object",
|
||||
"description": "Accepts: null | object",
|
||||
"properties": {
|
||||
"_": { "type": "boolean" },
|
||||
"kind": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["_"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
@@ -71,10 +71,13 @@ func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool)
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "generationConfig.thinkingConfig.include_thoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
if !gjson.GetBytes(updated, "generationConfig.thinkingConfig.includeThoughts").Exists() &&
|
||||
!gjson.GetBytes(updated, "generationConfig.thinkingConfig.include_thoughts").Exists() {
|
||||
valuePath := "generationConfig.thinkingConfig.include_thoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
}
|
||||
return updated
|
||||
@@ -99,10 +102,13 @@ func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *boo
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
if !gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.includeThoughts").Exists() &&
|
||||
!gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts").Exists() {
|
||||
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
}
|
||||
return updated
|
||||
@@ -130,15 +136,15 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool)
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
if !gjson.GetBytes(updated, "generationConfig.thinkingConfig.includeThoughts").Exists() &&
|
||||
!gjson.GetBytes(updated, "generationConfig.thinkingConfig.include_thoughts").Exists() {
|
||||
valuePath := "generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
}
|
||||
if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts")
|
||||
}
|
||||
if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
}
|
||||
@@ -167,15 +173,15 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
if !gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.includeThoughts").Exists() &&
|
||||
!gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts").Exists() {
|
||||
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
}
|
||||
if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
}
|
||||
if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
}
|
||||
@@ -251,9 +257,14 @@ func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) {
|
||||
|
||||
// modelsWithDefaultThinking lists models that should have thinking enabled by default
|
||||
// when no explicit thinkingConfig is provided.
|
||||
// Note: Gemini 3 models are NOT included here because per Google's official documentation:
|
||||
// - thinkingLevel defaults to "high" (dynamic thinking)
|
||||
// - includeThoughts defaults to false
|
||||
//
|
||||
// We should not override these API defaults; let users explicitly configure if needed.
|
||||
var modelsWithDefaultThinking = map[string]bool{
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-image-preview": true,
|
||||
// "gemini-3-pro-preview": true,
|
||||
// "gemini-3-pro-image-preview": true,
|
||||
// "gemini-3-flash-preview": true,
|
||||
}
|
||||
|
||||
@@ -288,37 +299,73 @@ func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadata applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal))
|
||||
// or numeric budget suffix (e.g., model(1000)) which gets converted to a thinkingLevel.
|
||||
func ApplyGemini3ThinkingLevelFromMetadata(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
// Use the alias from metadata if available for model type detection
|
||||
lookupModel := ResolveOriginalModel(model, metadata)
|
||||
if !IsGemini3Model(lookupModel) && !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
|
||||
// Determine which model to use for validation
|
||||
checkModel := model
|
||||
if IsGemini3Model(lookupModel) {
|
||||
checkModel = lookupModel
|
||||
}
|
||||
|
||||
// First try to get effort string from metadata
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
if ok && effort != "" {
|
||||
if level, valid := ValidateGemini3ThinkingLevel(checkModel, effort); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
}
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
|
||||
// Fallback: check for numeric budget and convert to thinkingLevel
|
||||
budget, _, _, matched := ThinkingFromMetadata(metadata)
|
||||
if matched && budget != nil {
|
||||
if level, valid := ThinkingBudgetToGemini3Level(checkModel, *budget); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadataCLI applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal))
|
||||
// or numeric budget suffix (e.g., model(1000)) which gets converted to a thinkingLevel.
|
||||
func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
// Use the alias from metadata if available for model type detection
|
||||
lookupModel := ResolveOriginalModel(model, metadata)
|
||||
if !IsGemini3Model(lookupModel) && !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
|
||||
// Determine which model to use for validation
|
||||
checkModel := model
|
||||
if IsGemini3Model(lookupModel) {
|
||||
checkModel = lookupModel
|
||||
}
|
||||
|
||||
// First try to get effort string from metadata
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
if ok && effort != "" {
|
||||
if level, valid := ValidateGemini3ThinkingLevel(checkModel, effort); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
}
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
|
||||
// Fallback: check for numeric budget and convert to thinkingLevel
|
||||
budget, _, _, matched := ThinkingFromMetadata(metadata)
|
||||
if matched && budget != nil {
|
||||
if level, valid := ThinkingBudgetToGemini3Level(checkModel, *budget); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -326,15 +373,17 @@ func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
|
||||
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||
if !ModelHasDefaultThinking(model) {
|
||||
func ApplyDefaultThinkingIfNeededCLI(model string, metadata map[string]any, body []byte) []byte {
|
||||
// Use the alias from metadata if available for model property lookup
|
||||
lookupModel := ResolveOriginalModel(model, metadata)
|
||||
if !ModelHasDefaultThinking(lookupModel) && !ModelHasDefaultThinking(model) {
|
||||
return body
|
||||
}
|
||||
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
|
||||
return body
|
||||
}
|
||||
// Gemini 3 models use thinkingLevel instead of thinkingBudget
|
||||
if IsGemini3Model(model) {
|
||||
if IsGemini3Model(lookupModel) || IsGemini3Model(model) {
|
||||
// Don't set a default - let the API use its dynamic default ("high")
|
||||
// Only set includeThoughts
|
||||
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
56
internal/util/sanitize_test.go
Normal file
56
internal/util/sanitize_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeFunctionName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Normal", "valid_name", "valid_name"},
|
||||
{"With Dots", "name.with.dots", "name.with.dots"},
|
||||
{"With Colons", "name:with:colons", "name:with:colons"},
|
||||
{"With Dashes", "name-with-dashes", "name-with-dashes"},
|
||||
{"Mixed Allowed", "name.with_dots:colons-dashes", "name.with_dots:colons-dashes"},
|
||||
{"Invalid Characters", "name!with@invalid#chars", "name_with_invalid_chars"},
|
||||
{"Spaces", "name with spaces", "name_with_spaces"},
|
||||
{"Non-ASCII", "name_with_你好_chars", "name_with____chars"},
|
||||
{"Starts with digit", "123name", "_123name"},
|
||||
{"Starts with dot", ".name", "_.name"},
|
||||
{"Starts with colon", ":name", "_:name"},
|
||||
{"Starts with dash", "-name", "_-name"},
|
||||
{"Starts with invalid char", "!name", "_name"},
|
||||
{"Exactly 64 chars", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
|
||||
{"Too long (65 chars)", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charactX", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
|
||||
{"Very long", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_limit_for_function_names", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_l"},
|
||||
{"Starts with digit (64 chars total)", "1234567890123456789012345678901234567890123456789012345678901234", "_123456789012345678901234567890123456789012345678901234567890123"},
|
||||
{"Starts with invalid char (64 chars total)", "!234567890123456789012345678901234567890123456789012345678901234", "_234567890123456789012345678901234567890123456789012345678901234"},
|
||||
{"Empty", "", ""},
|
||||
{"Single character invalid", "@", "_"},
|
||||
{"Single character valid", "a", "a"},
|
||||
{"Single character digit", "1", "_1"},
|
||||
{"Single character underscore", "_", "_"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := SanitizeFunctionName(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("SanitizeFunctionName(%q) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
// Verify Gemini compliance
|
||||
if len(got) > 64 {
|
||||
t.Errorf("SanitizeFunctionName(%q) result too long: %d", tt.input, len(got))
|
||||
}
|
||||
if len(got) > 0 {
|
||||
first := got[0]
|
||||
if !((first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_') {
|
||||
t.Errorf("SanitizeFunctionName(%q) result starts with invalid char: %c", tt.input, first)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,9 +12,18 @@ func ModelSupportsThinking(model string) bool {
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
// First check the global dynamic registry
|
||||
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil {
|
||||
return info.Thinking != nil
|
||||
}
|
||||
// Fallback: check static model definitions
|
||||
if info := registry.LookupStaticModelInfo(model); info != nil {
|
||||
return info.Thinking != nil
|
||||
}
|
||||
// Fallback: check Antigravity static config
|
||||
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil {
|
||||
return cfg.Thinking != nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -63,11 +72,19 @@ func thinkingRangeFromRegistry(model string) (found bool, min int, max int, zero
|
||||
if model == "" {
|
||||
return false, 0, 0, false, false
|
||||
}
|
||||
info := registry.GetGlobalRegistry().GetModelInfo(model)
|
||||
if info == nil || info.Thinking == nil {
|
||||
return false, 0, 0, false, false
|
||||
// First check global dynamic registry
|
||||
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil && info.Thinking != nil {
|
||||
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
|
||||
}
|
||||
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
|
||||
// Fallback: check static model definitions
|
||||
if info := registry.LookupStaticModelInfo(model); info != nil && info.Thinking != nil {
|
||||
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
|
||||
}
|
||||
// Fallback: check Antigravity static config
|
||||
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil && cfg.Thinking != nil {
|
||||
return true, cfg.Thinking.Min, cfg.Thinking.Max, cfg.Thinking.ZeroAllowed, cfg.Thinking.DynamicAllowed
|
||||
}
|
||||
return false, 0, 0, false, false
|
||||
}
|
||||
|
||||
// GetModelThinkingLevels returns the discrete reasoning effort levels for the model.
|
||||
|
||||
@@ -4,16 +4,56 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var functionNameSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)
|
||||
|
||||
// SanitizeFunctionName ensures a function name matches the requirements for Gemini/Vertex AI.
|
||||
// It replaces invalid characters with underscores, ensures it starts with a letter or underscore,
|
||||
// and truncates it to 64 characters if necessary.
|
||||
// Regex Rule: [^a-zA-Z0-9_.:-] replaced with _.
|
||||
func SanitizeFunctionName(name string) string {
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Replace invalid characters with underscore
|
||||
sanitized := functionNameSanitizer.ReplaceAllString(name, "_")
|
||||
|
||||
// Ensure it starts with a letter or underscore
|
||||
// Re-reading requirements: Must start with a letter or an underscore.
|
||||
if len(sanitized) > 0 {
|
||||
first := sanitized[0]
|
||||
if !((first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_') {
|
||||
// If it starts with an allowed character but not allowed at the beginning (digit, dot, colon, dash),
|
||||
// we must prepend an underscore.
|
||||
|
||||
// To stay within the 64-character limit while prepending, we must truncate first.
|
||||
if len(sanitized) >= 64 {
|
||||
sanitized = sanitized[:63]
|
||||
}
|
||||
sanitized = "_" + sanitized
|
||||
}
|
||||
} else {
|
||||
sanitized = "_"
|
||||
}
|
||||
|
||||
// Truncate to 64 characters
|
||||
if len(sanitized) > 64 {
|
||||
sanitized = sanitized[:64]
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// SetLogLevel configures the logrus log level based on the configuration.
|
||||
// It sets the log level to DebugLevel if debug mode is enabled, otherwise to InfoLevel.
|
||||
func SetLogLevel(cfg *config.Config) {
|
||||
@@ -53,36 +93,23 @@ func ResolveAuthDir(authDir string) (string, error) {
|
||||
return filepath.Clean(authDir), nil
|
||||
}
|
||||
|
||||
// CountAuthFiles returns the number of JSON auth files located under the provided directory.
|
||||
// The function resolves leading tildes to the user's home directory and performs a case-insensitive
|
||||
// match on the ".json" suffix so that files saved with uppercase extensions are also counted.
|
||||
func CountAuthFiles(authDir string) int {
|
||||
dir, err := ResolveAuthDir(authDir)
|
||||
// CountAuthFiles returns the number of auth records available through the provided Store.
|
||||
// For filesystem-backed stores, this reflects the number of JSON auth files under the configured directory.
|
||||
func CountAuthFiles[T any](ctx context.Context, store interface {
|
||||
List(context.Context) ([]T, error)
|
||||
}) int {
|
||||
if store == nil {
|
||||
return 0
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
entries, err := store.List(ctx)
|
||||
if err != nil {
|
||||
log.Debugf("countAuthFiles: failed to resolve auth directory: %v", err)
|
||||
log.Debugf("countAuthFiles: failed to list auth records: %v", err)
|
||||
return 0
|
||||
}
|
||||
if dir == "" {
|
||||
return 0
|
||||
}
|
||||
count := 0
|
||||
walkErr := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
log.Debugf("countAuthFiles: error accessing %s: %v", path, err)
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
|
||||
count++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if walkErr != nil {
|
||||
log.Debugf("countAuthFiles: walk error: %v", walkErr)
|
||||
}
|
||||
return count
|
||||
return len(entries)
|
||||
}
|
||||
|
||||
// WritablePath returns the cleaned WRITABLE_PATH environment variable when it is set.
|
||||
|
||||
@@ -54,6 +54,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
|
||||
changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
|
||||
}
|
||||
if oldCfg.NonStreamKeepAliveInterval != newCfg.NonStreamKeepAliveInterval {
|
||||
changes = append(changes, fmt.Sprintf("nonstream-keepalive-interval: %d -> %d", oldCfg.NonStreamKeepAliveInterval, newCfg.NonStreamKeepAliveInterval))
|
||||
}
|
||||
|
||||
// Quota-exceeded behavior
|
||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||
|
||||
@@ -231,10 +231,11 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
|
||||
RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{"key-1"},
|
||||
ForceModelPrefix: false,
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{"key-1"},
|
||||
ForceModelPrefix: false,
|
||||
NonStreamKeepAliveInterval: 0,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
@@ -267,10 +268,11 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{" key-1 ", "key-2"},
|
||||
ForceModelPrefix: true,
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{" key-1 ", "key-2"},
|
||||
ForceModelPrefix: true,
|
||||
NonStreamKeepAliveInterval: 5,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -285,6 +287,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, details, "ws-auth: false -> true")
|
||||
expectContains(t, details, "force-model-prefix: false -> true")
|
||||
expectContains(t, details, "nonstream-keepalive-interval: 0 -> 5")
|
||||
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, details, "api-keys count: 1 -> 2")
|
||||
|
||||
@@ -80,6 +80,9 @@ func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMa
|
||||
continue
|
||||
}
|
||||
key := name + "->" + alias
|
||||
if mapping.Fork {
|
||||
key += "|fork"
|
||||
}
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package synthesizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
@@ -59,6 +60,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
|
||||
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if entry.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(entry.Priority)
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
@@ -103,6 +107,9 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea
|
||||
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if ck.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(ck.Priority)
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
@@ -147,6 +154,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
||||
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if ck.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(ck.Priority)
|
||||
}
|
||||
if ck.BaseURL != "" {
|
||||
attrs["base_url"] = ck.BaseURL
|
||||
}
|
||||
@@ -202,6 +212,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if compat.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
@@ -233,6 +246,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if compat.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||
}
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
@@ -275,6 +291,9 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor
|
||||
"base_url": base,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if compat.Priority != 0 {
|
||||
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
|
||||
@@ -146,10 +146,12 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
||||
c.Header("Content-Type", "application/json")
|
||||
alt := h.GetAlt(c)
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
@@ -159,13 +161,18 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
||||
// Decompress gzipped responses - Claude API sometimes returns gzip without Content-Encoding header
|
||||
// This fixes title generation and other non-streaming responses that arrive compressed
|
||||
if len(resp) >= 2 && resp[0] == 0x1f && resp[1] == 0x8b {
|
||||
gzReader, err := gzip.NewReader(bytes.NewReader(resp))
|
||||
if err != nil {
|
||||
log.Warnf("failed to decompress gzipped Claude response: %v", err)
|
||||
gzReader, errGzip := gzip.NewReader(bytes.NewReader(resp))
|
||||
if errGzip != nil {
|
||||
log.Warnf("failed to decompress gzipped Claude response: %v", errGzip)
|
||||
} else {
|
||||
defer gzReader.Close()
|
||||
if decompressed, err := io.ReadAll(gzReader); err != nil {
|
||||
log.Warnf("failed to read decompressed Claude response: %v", err)
|
||||
defer func() {
|
||||
if errClose := gzReader.Close(); errClose != nil {
|
||||
log.Warnf("failed to close Claude gzip reader: %v", errClose)
|
||||
}
|
||||
}()
|
||||
decompressed, errRead := io.ReadAll(gzReader)
|
||||
if errRead != nil {
|
||||
log.Warnf("failed to read decompressed Claude response: %v", errRead)
|
||||
} else {
|
||||
resp = decompressed
|
||||
}
|
||||
|
||||
@@ -336,7 +336,9 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
||||
c.Header("Content-Type", "application/json")
|
||||
alt := h.GetAlt(c)
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -113,6 +114,19 @@ func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// NonStreamingKeepAliveInterval returns the keep-alive interval for non-streaming responses.
|
||||
// Returning 0 disables keep-alives (default when unset).
|
||||
func NonStreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||
seconds := 0
|
||||
if cfg != nil {
|
||||
seconds = cfg.NonStreamKeepAliveInterval
|
||||
}
|
||||
if seconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
||||
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||
retries := defaultStreamingBootstrapRetries
|
||||
@@ -293,6 +307,53 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
}
|
||||
}
|
||||
|
||||
// StartNonStreamingKeepAlive emits blank lines every 5 seconds while waiting for a non-streaming response.
|
||||
// It returns a stop function that must be called before writing the final response.
|
||||
func (h *BaseAPIHandler) StartNonStreamingKeepAlive(c *gin.Context, ctx context.Context) func() {
|
||||
if h == nil || c == nil {
|
||||
return func() {}
|
||||
}
|
||||
interval := NonStreamingKeepAliveInterval(h.Cfg)
|
||||
if interval <= 0 {
|
||||
return func() {}
|
||||
}
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return func() {}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
var stopOnce sync.Once
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return func() {
|
||||
stopOnce.Do(func() {
|
||||
close(stopChan)
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// appendAPIResponse preserves any previously captured API response and appends new data.
|
||||
func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
if c == nil || len(data) == 0 {
|
||||
|
||||
@@ -56,6 +56,14 @@ func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, co
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Calls() int {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
@@ -524,7 +524,9 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
||||
|
||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
|
||||
@@ -103,20 +103,17 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
defer func() {
|
||||
cliCancel()
|
||||
}()
|
||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||
|
||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
stopKeepAlive()
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
cliCancel(errMsg.Error)
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(resp)
|
||||
return
|
||||
|
||||
// no legacy fallback
|
||||
|
||||
cliCancel()
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses for Gemini models.
|
||||
|
||||
@@ -21,6 +21,7 @@ type ManagementTokenRequester interface {
|
||||
RequestIFlowToken(*gin.Context)
|
||||
RequestIFlowCookieToken(*gin.Context)
|
||||
GetAuthStatus(c *gin.Context)
|
||||
PostOAuthCallback(c *gin.Context)
|
||||
}
|
||||
|
||||
type managementTokenRequester struct {
|
||||
@@ -65,3 +66,7 @@ func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
|
||||
func (m *managementTokenRequester) GetAuthStatus(c *gin.Context) {
|
||||
m.handler.GetAuthStatus(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) PostOAuthCallback(c *gin.Context) {
|
||||
m.handler.PostOAuthCallback(c)
|
||||
}
|
||||
|
||||
@@ -60,6 +60,11 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := antigravityCallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
||||
|
||||
state, err := misc.GenerateRandomState()
|
||||
@@ -67,7 +72,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
return nil, fmt.Errorf("antigravity: failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
srv, port, cbChan, errServer := startAntigravityCallbackServer()
|
||||
srv, port, cbChan, errServer := startAntigravityCallbackServer(callbackPort)
|
||||
if errServer != nil {
|
||||
return nil, fmt.Errorf("antigravity: failed to start callback server: %w", errServer)
|
||||
}
|
||||
@@ -224,13 +229,16 @@ type callbackResult struct {
|
||||
State string
|
||||
}
|
||||
|
||||
func startAntigravityCallbackServer() (*http.Server, int, <-chan callbackResult, error) {
|
||||
addr := fmt.Sprintf(":%d", antigravityCallbackPort)
|
||||
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
||||
if port <= 0 {
|
||||
port = antigravityCallbackPort
|
||||
}
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, 0, nil, err
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
port = listener.Addr().(*net.TCPAddr).Port
|
||||
resultCh := make(chan callbackResult, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@@ -374,7 +382,7 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
|
||||
// Call loadCodeAssist to get the project
|
||||
loadReqBody := map[string]any{
|
||||
"metadata": map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
@@ -434,8 +442,134 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return "", fmt.Errorf("no cloudaicompanionProject in response")
|
||||
tierID := "legacy-tier"
|
||||
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||
for _, rawTier := range tiers {
|
||||
tier, okTier := rawTier.(map[string]any)
|
||||
if !okTier {
|
||||
continue
|
||||
}
|
||||
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||
tierID = strings.TrimSpace(id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
|
||||
// It returns an empty string when the operation times out or completes without a project ID.
|
||||
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
fmt.Println("Antigravity: onboarding user...", tierID)
|
||||
requestBody := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": map[string]string{
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(requestBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
maxAttempts := 5
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||
|
||||
reqCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if reqCtx == nil {
|
||||
reqCtx = context.Background()
|
||||
}
|
||||
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
|
||||
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if errRequest != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
||||
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("close body error: %v", errClose)
|
||||
}
|
||||
cancel()
|
||||
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var data map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
if done, okDone := data["done"].(bool); okDone && done {
|
||||
projectID := ""
|
||||
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||
case map[string]any:
|
||||
if id, okID := projectValue["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
case string:
|
||||
projectID = strings.TrimSpace(projectValue)
|
||||
}
|
||||
}
|
||||
|
||||
if projectID != "" {
|
||||
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no project_id in response")
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||
if len(responsePreview) > 500 {
|
||||
responsePreview = responsePreview[:500]
|
||||
}
|
||||
|
||||
responseErr := responsePreview
|
||||
if len(responseErr) > 200 {
|
||||
responseErr = responseErr[:200]
|
||||
}
|
||||
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -47,6 +47,11 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
pkceCodes, err := claude.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude pkce generation failed: %w", err)
|
||||
@@ -57,7 +62,7 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
return nil, fmt.Errorf("claude state generation failed: %w", err)
|
||||
}
|
||||
|
||||
oauthServer := claude.NewOAuthServer(a.CallbackPort)
|
||||
oauthServer := claude.NewOAuthServer(callbackPort)
|
||||
if err = oauthServer.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "already in use") {
|
||||
return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err)
|
||||
@@ -84,15 +89,15 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
fmt.Println("Opening browser for Claude authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -47,6 +47,11 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
pkceCodes, err := codex.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codex pkce generation failed: %w", err)
|
||||
@@ -57,7 +62,7 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
return nil, fmt.Errorf("codex state generation failed: %w", err)
|
||||
}
|
||||
|
||||
oauthServer := codex.NewOAuthServer(a.CallbackPort)
|
||||
oauthServer := codex.NewOAuthServer(callbackPort)
|
||||
if err = oauthServer.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "already in use") {
|
||||
return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err)
|
||||
@@ -83,15 +88,15 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
fmt.Println("Opening browser for Codex authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -77,15 +79,23 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
||||
if metadataEqualIgnoringTimestamps(existing, raw) {
|
||||
return path, nil
|
||||
}
|
||||
} else if errRead != nil && !os.IsNotExist(errRead) {
|
||||
file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
|
||||
if errOpen != nil {
|
||||
return "", fmt.Errorf("auth filestore: open existing failed: %w", errOpen)
|
||||
}
|
||||
if _, errWrite := file.Write(raw); errWrite != nil {
|
||||
_ = file.Close()
|
||||
return "", fmt.Errorf("auth filestore: write existing failed: %w", errWrite)
|
||||
}
|
||||
if errClose := file.Close(); errClose != nil {
|
||||
return "", fmt.Errorf("auth filestore: close existing failed: %w", errClose)
|
||||
}
|
||||
return path, nil
|
||||
} else if !os.IsNotExist(errRead) {
|
||||
return "", fmt.Errorf("auth filestore: read existing failed: %w", errRead)
|
||||
}
|
||||
tmp := path + ".tmp"
|
||||
if errWrite := os.WriteFile(tmp, raw, 0o600); errWrite != nil {
|
||||
return "", fmt.Errorf("auth filestore: write temp failed: %w", errWrite)
|
||||
}
|
||||
if errRename := os.Rename(tmp, path); errRename != nil {
|
||||
return "", fmt.Errorf("auth filestore: rename failed: %w", errRename)
|
||||
if errWrite := os.WriteFile(path, raw, 0o600); errWrite != nil {
|
||||
return "", fmt.Errorf("auth filestore: write file failed: %w", errWrite)
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("auth filestore: nothing to persist for %s", auth.ID)
|
||||
@@ -178,6 +188,30 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
if provider == "" {
|
||||
provider = "unknown"
|
||||
}
|
||||
if provider == "antigravity" {
|
||||
projectID := ""
|
||||
if pid, ok := metadata["project_id"].(string); ok {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
if projectID == "" {
|
||||
accessToken := ""
|
||||
if token, ok := metadata["access_token"].(string); ok {
|
||||
accessToken = strings.TrimSpace(token)
|
||||
}
|
||||
if accessToken != "" {
|
||||
fetchedProjectID, errFetch := FetchAntigravityProjectID(context.Background(), accessToken, http.DefaultClient)
|
||||
if errFetch == nil && strings.TrimSpace(fetchedProjectID) != "" {
|
||||
metadata["project_id"] = strings.TrimSpace(fetchedProjectID)
|
||||
if raw, errMarshal := json.Marshal(metadata); errMarshal == nil {
|
||||
if file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600); errOpen == nil {
|
||||
_, _ = file.Write(raw)
|
||||
_ = file.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat file: %w", err)
|
||||
@@ -266,92 +300,28 @@ func (s *FileTokenStore) baseDirSnapshot() string {
|
||||
return s.baseDir
|
||||
}
|
||||
|
||||
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata.
|
||||
// This function is kept for backward compatibility but can cause refresh loops.
|
||||
func jsonEqual(a, b []byte) bool {
|
||||
var objA any
|
||||
var objB any
|
||||
if err := json.Unmarshal(a, &objA); err != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(b, &objB); err != nil {
|
||||
return false
|
||||
}
|
||||
return deepEqualJSON(objA, objB)
|
||||
}
|
||||
|
||||
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
|
||||
// ignoring fields that change on every refresh but don't affect functionality.
|
||||
// This prevents unnecessary file writes that would trigger watcher events and
|
||||
// create refresh loops.
|
||||
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs, ignoring volatile fields that
|
||||
// change on every refresh but don't affect authentication logic.
|
||||
func metadataEqualIgnoringTimestamps(a, b []byte) bool {
|
||||
var objA, objB map[string]any
|
||||
if err := json.Unmarshal(a, &objA); err != nil {
|
||||
var objA map[string]any
|
||||
var objB map[string]any
|
||||
if errUnmarshalA := json.Unmarshal(a, &objA); errUnmarshalA != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(b, &objB); err != nil {
|
||||
if errUnmarshalB := json.Unmarshal(b, &objB); errUnmarshalB != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fields to ignore: these change on every refresh but don't affect authentication logic.
|
||||
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
|
||||
// - access_token: Google OAuth returns a new access_token on each refresh, this is expected
|
||||
// and shouldn't trigger file writes (the new token will be fetched again when needed)
|
||||
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"}
|
||||
for _, field := range ignoredFields {
|
||||
delete(objA, field)
|
||||
delete(objB, field)
|
||||
}
|
||||
|
||||
return deepEqualJSON(objA, objB)
|
||||
stripVolatileMetadataFields(objA)
|
||||
stripVolatileMetadataFields(objB)
|
||||
return reflect.DeepEqual(objA, objB)
|
||||
}
|
||||
|
||||
func deepEqualJSON(a, b any) bool {
|
||||
switch valA := a.(type) {
|
||||
case map[string]any:
|
||||
valB, ok := b.(map[string]any)
|
||||
if !ok || len(valA) != len(valB) {
|
||||
return false
|
||||
}
|
||||
for key, subA := range valA {
|
||||
subB, ok1 := valB[key]
|
||||
if !ok1 || !deepEqualJSON(subA, subB) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case []any:
|
||||
sliceB, ok := b.([]any)
|
||||
if !ok || len(valA) != len(sliceB) {
|
||||
return false
|
||||
}
|
||||
for i := range valA {
|
||||
if !deepEqualJSON(valA[i], sliceB[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case float64:
|
||||
valB, ok := b.(float64)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case string:
|
||||
valB, ok := b.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case bool:
|
||||
valB, ok := b.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case nil:
|
||||
return b == nil
|
||||
default:
|
||||
return false
|
||||
func stripVolatileMetadataFields(metadata map[string]any) {
|
||||
if metadata == nil {
|
||||
return
|
||||
}
|
||||
// These fields change on refresh and would otherwise trigger watcher reload loops.
|
||||
for _, field := range []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"} {
|
||||
delete(metadata, field)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,8 +45,9 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
|
||||
geminiAuth := gemini.NewGeminiAuth()
|
||||
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
|
||||
NoBrowser: opts.NoBrowser,
|
||||
Prompt: opts.Prompt,
|
||||
NoBrowser: opts.NoBrowser,
|
||||
CallbackPort: opts.CallbackPort,
|
||||
Prompt: opts.Prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemini authentication failed: %w", err)
|
||||
|
||||
@@ -42,9 +42,14 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := iflow.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
authSvc := iflow.NewIFlowAuth(cfg)
|
||||
|
||||
oauthServer := iflow.NewOAuthServer(iflow.CallbackPort)
|
||||
oauthServer := iflow.NewOAuthServer(callbackPort)
|
||||
if err := oauthServer.Start(); err != nil {
|
||||
if strings.Contains(err.Error(), "already in use") {
|
||||
return nil, fmt.Errorf("iflow authentication server port in use: %w", err)
|
||||
@@ -64,21 +69,21 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
return nil, fmt.Errorf("iflow auth: failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflow.CallbackPort)
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, callbackPort)
|
||||
|
||||
if !opts.NoBrowser {
|
||||
fmt.Println("Opening browser for iFlow authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", err)
|
||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,10 +14,11 @@ var ErrRefreshNotSupported = errors.New("cliproxy auth: refresh not supported")
|
||||
// LoginOptions captures generic knobs shared across authenticators.
|
||||
// Provider-specific logic can inspect Metadata for extra parameters.
|
||||
type LoginOptions struct {
|
||||
NoBrowser bool
|
||||
ProjectID string
|
||||
Metadata map[string]string
|
||||
Prompt func(prompt string) (string, error)
|
||||
NoBrowser bool
|
||||
ProjectID string
|
||||
CallbackPort int
|
||||
Metadata map[string]string
|
||||
Prompt func(prompt string) (string, error)
|
||||
}
|
||||
|
||||
// Authenticator manages login and optional refresh flows for a provider.
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -31,6 +34,9 @@ type ProviderExecutor interface {
|
||||
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
|
||||
// CountTokens returns the token count for the given request.
|
||||
CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
|
||||
// HttpRequest injects provider credentials into the supplied HTTP request and executes it.
|
||||
// Callers must close the response body when non-nil.
|
||||
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// RefreshEvaluator allows runtime state to override refresh decisions.
|
||||
@@ -265,7 +271,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
||||
if len(normalized) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -275,14 +280,12 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
||||
return m.executeWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
@@ -303,7 +306,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
if len(normalized) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -313,14 +315,12 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
||||
return m.executeCountWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
@@ -341,7 +341,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
if len(normalized) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
@@ -351,14 +350,12 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
return m.executeStreamWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||
if errStream == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errStream
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
@@ -372,6 +369,167 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
@@ -388,22 +546,8 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
if accountType == "api_key" {
|
||||
if proxyInfo != "" {
|
||||
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
if proxyInfo != "" {
|
||||
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -450,22 +594,8 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
if accountType == "api_key" {
|
||||
if proxyInfo != "" {
|
||||
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
if proxyInfo != "" {
|
||||
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -512,22 +642,8 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
if accountType == "api_key" {
|
||||
if proxyInfo != "" {
|
||||
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
if proxyInfo != "" {
|
||||
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -1227,6 +1343,77 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
return authCopy, executor, nil
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
p := strings.TrimSpace(strings.ToLower(provider))
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
providerSet[p] = struct{}{}
|
||||
}
|
||||
if len(providerSet) == 0 {
|
||||
return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
candidates := make([]*Auth, 0, len(m.auths))
|
||||
modelKey := strings.TrimSpace(model)
|
||||
registryRef := registry.GetGlobalRegistry()
|
||||
for _, candidate := range m.auths {
|
||||
if candidate == nil || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := providerSet[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, used := tried[candidate.ID]; used {
|
||||
continue
|
||||
}
|
||||
if _, ok := m.executors[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
|
||||
if errPick != nil {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, "", errPick
|
||||
}
|
||||
if selected == nil {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
|
||||
executor, okExecutor := m.executors[providerKey]
|
||||
if !okExecutor {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
authCopy := selected.Clone()
|
||||
m.mu.RUnlock()
|
||||
if !selected.indexAssigned {
|
||||
m.mu.Lock()
|
||||
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||
current.EnsureIndex()
|
||||
authCopy = current.Clone()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
return authCopy, executor, providerKey, nil
|
||||
}
|
||||
|
||||
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||
if m.store == nil || auth == nil {
|
||||
return nil
|
||||
@@ -1536,6 +1723,9 @@ func (m *Manager) markRefreshPending(id string, now time.Time) bool {
|
||||
}
|
||||
|
||||
func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
m.mu.RLock()
|
||||
auth := m.auths[id]
|
||||
var exec ProviderExecutor
|
||||
@@ -1548,6 +1738,10 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
}
|
||||
cloned := auth.Clone()
|
||||
updated, err := exec.Refresh(ctx, cloned)
|
||||
if err != nil && errors.Is(err, context.Canceled) {
|
||||
log.Debugf("refresh canceled for %s, %s", auth.Provider, auth.ID)
|
||||
return
|
||||
}
|
||||
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
|
||||
now := time.Now()
|
||||
if err != nil {
|
||||
@@ -1606,6 +1800,23 @@ type RequestPreparer interface {
|
||||
PrepareRequest(req *http.Request, auth *Auth) error
|
||||
}
|
||||
|
||||
func executorKeyFromAuth(auth *Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
providerKey := strings.TrimSpace(auth.Attributes["provider_key"])
|
||||
compatName := strings.TrimSpace(auth.Attributes["compat_name"])
|
||||
if compatName != "" {
|
||||
if providerKey == "" {
|
||||
providerKey = compatName
|
||||
}
|
||||
return strings.ToLower(providerKey)
|
||||
}
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
}
|
||||
|
||||
// logEntryWithRequestID returns a logrus entry with request_id field if available in context.
|
||||
func logEntryWithRequestID(ctx context.Context) *log.Entry {
|
||||
if ctx == nil {
|
||||
@@ -1617,6 +1828,59 @@ func logEntryWithRequestID(ctx context.Context) *log.Entry {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
|
||||
func debugLogAuthSelection(entry *log.Entry, auth *Auth, provider string, model string) {
|
||||
if !log.IsLevelEnabled(log.DebugLevel) {
|
||||
return
|
||||
}
|
||||
if entry == nil || auth == nil {
|
||||
return
|
||||
}
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
suffix := ""
|
||||
if proxyInfo != "" {
|
||||
suffix = " " + proxyInfo
|
||||
}
|
||||
switch accountType {
|
||||
case "api_key":
|
||||
entry.Debugf("Use API key %s for model %s%s", util.HideAPIKey(accountInfo), model, suffix)
|
||||
case "oauth":
|
||||
ident := formatOauthIdentity(auth, provider, accountInfo)
|
||||
entry.Debugf("Use OAuth %s for model %s%s", ident, model, suffix)
|
||||
}
|
||||
}
|
||||
|
||||
func formatOauthIdentity(auth *Auth, provider string, accountInfo string) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
// Prefer the auth's provider when available.
|
||||
providerName := strings.TrimSpace(auth.Provider)
|
||||
if providerName == "" {
|
||||
providerName = strings.TrimSpace(provider)
|
||||
}
|
||||
// Only log the basename to avoid leaking host paths.
|
||||
// FileName may be unset for some auth backends; fall back to ID.
|
||||
authFile := strings.TrimSpace(auth.FileName)
|
||||
if authFile == "" {
|
||||
authFile = strings.TrimSpace(auth.ID)
|
||||
}
|
||||
if authFile != "" {
|
||||
authFile = filepath.Base(authFile)
|
||||
}
|
||||
parts := make([]string, 0, 3)
|
||||
if providerName != "" {
|
||||
parts = append(parts, "provider="+providerName)
|
||||
}
|
||||
if authFile != "" {
|
||||
parts = append(parts, "auth_file="+authFile)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return accountInfo
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// InjectCredentials delegates per-provider HTTP request preparation when supported.
|
||||
// If the registered executor for the auth provider implements RequestPreparer,
|
||||
// it will be invoked to modify the request (e.g., add headers).
|
||||
@@ -1628,7 +1892,7 @@ func (m *Manager) InjectCredentials(req *http.Request, authID string) error {
|
||||
a := m.auths[authID]
|
||||
var exec ProviderExecutor
|
||||
if a != nil {
|
||||
exec = m.executors[a.Provider]
|
||||
exec = m.executors[executorKeyFromAuth(a)]
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
if a == nil || exec == nil {
|
||||
@@ -1639,3 +1903,80 @@ func (m *Manager) InjectCredentials(req *http.Request, authID string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrepareHttpRequest injects provider credentials into the supplied HTTP request.
|
||||
func (m *Manager) PrepareHttpRequest(ctx context.Context, auth *Auth, req *http.Request) error {
|
||||
if m == nil {
|
||||
return &Error{Code: "provider_not_found", Message: "manager is nil"}
|
||||
}
|
||||
if auth == nil {
|
||||
return &Error{Code: "auth_not_found", Message: "auth is nil"}
|
||||
}
|
||||
if req == nil {
|
||||
return &Error{Code: "invalid_request", Message: "http request is nil"}
|
||||
}
|
||||
if ctx != nil {
|
||||
*req = *req.WithContext(ctx)
|
||||
}
|
||||
providerKey := executorKeyFromAuth(auth)
|
||||
if providerKey == "" {
|
||||
return &Error{Code: "provider_not_found", Message: "auth provider is empty"}
|
||||
}
|
||||
exec := m.executorFor(providerKey)
|
||||
if exec == nil {
|
||||
return &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + providerKey}
|
||||
}
|
||||
preparer, ok := exec.(RequestPreparer)
|
||||
if !ok || preparer == nil {
|
||||
return &Error{Code: "not_supported", Message: "executor does not support http request preparation"}
|
||||
}
|
||||
return preparer.PrepareRequest(req, auth)
|
||||
}
|
||||
|
||||
// NewHttpRequest constructs a new HTTP request and injects provider credentials into it.
|
||||
func (m *Manager) NewHttpRequest(ctx context.Context, auth *Auth, method, targetURL string, body []byte, headers http.Header) (*http.Request, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
method = strings.TrimSpace(method)
|
||||
if method == "" {
|
||||
method = http.MethodGet
|
||||
}
|
||||
var reader io.Reader
|
||||
if body != nil {
|
||||
reader = bytes.NewReader(body)
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, method, targetURL, reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if headers != nil {
|
||||
httpReq.Header = headers.Clone()
|
||||
}
|
||||
if errPrepare := m.PrepareHttpRequest(ctx, auth, httpReq); errPrepare != nil {
|
||||
return nil, errPrepare
|
||||
}
|
||||
return httpReq, nil
|
||||
}
|
||||
|
||||
// HttpRequest injects provider credentials into the supplied HTTP request and executes it.
|
||||
func (m *Manager) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||
if m == nil {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "manager is nil"}
|
||||
}
|
||||
if auth == nil {
|
||||
return nil, &Error{Code: "auth_not_found", Message: "auth is nil"}
|
||||
}
|
||||
if req == nil {
|
||||
return nil, &Error{Code: "invalid_request", Message: "http request is nil"}
|
||||
}
|
||||
providerKey := executorKeyFromAuth(auth)
|
||||
if providerKey == "" {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "auth provider is empty"}
|
||||
}
|
||||
exec := m.executorFor(providerKey)
|
||||
if exec == nil {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + providerKey}
|
||||
}
|
||||
return exec.HttpRequest(ctx, auth, req)
|
||||
}
|
||||
|
||||
@@ -81,7 +81,9 @@ func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, meta
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
out[util.ModelMappingOriginalModelMetadataKey] = upstreamModel
|
||||
// Store the requested alias (e.g., "gp") so downstream can use it to look up
|
||||
// model metadata from the global registry where it was registered under this alias.
|
||||
out[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
return upstreamModel, out
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -103,13 +104,29 @@ func (e *modelCooldownError) Headers() http.Header {
|
||||
return headers
|
||||
}
|
||||
|
||||
func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) {
|
||||
available = make([]*Auth, 0, len(auths))
|
||||
func authPriority(auth *Auth) int {
|
||||
if auth == nil || auth.Attributes == nil {
|
||||
return 0
|
||||
}
|
||||
raw := strings.TrimSpace(auth.Attributes["priority"])
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
parsed, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||
available = make(map[int][]*Auth)
|
||||
for i := 0; i < len(auths); i++ {
|
||||
candidate := auths[i]
|
||||
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
|
||||
if !blocked {
|
||||
available = append(available, candidate)
|
||||
priority := authPriority(candidate)
|
||||
available[priority] = append(available[priority], candidate)
|
||||
continue
|
||||
}
|
||||
if reason == blockReasonCooldown {
|
||||
@@ -119,9 +136,6 @@ func collectAvailable(auths []*Auth, model string, now time.Time) (available []*
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(available) > 1 {
|
||||
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
||||
}
|
||||
return available, cooldownCount, earliest
|
||||
}
|
||||
|
||||
@@ -130,18 +144,35 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
|
||||
}
|
||||
|
||||
available, cooldownCount, earliest := collectAvailable(auths, model, now)
|
||||
if len(available) == 0 {
|
||||
availableByPriority, cooldownCount, earliest := collectAvailableByPriority(auths, model, now)
|
||||
if len(availableByPriority) == 0 {
|
||||
if cooldownCount == len(auths) && !earliest.IsZero() {
|
||||
providerForError := provider
|
||||
if providerForError == "mixed" {
|
||||
providerForError = ""
|
||||
}
|
||||
resetIn := earliest.Sub(now)
|
||||
if resetIn < 0 {
|
||||
resetIn = 0
|
||||
}
|
||||
return nil, newModelCooldownError(model, provider, resetIn)
|
||||
return nil, newModelCooldownError(model, providerForError, resetIn)
|
||||
}
|
||||
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
|
||||
bestPriority := 0
|
||||
found := false
|
||||
for priority := range availableByPriority {
|
||||
if !found || priority > bestPriority {
|
||||
bestPriority = priority
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
available := availableByPriority[bestPriority]
|
||||
if len(available) > 1 {
|
||||
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
||||
}
|
||||
return available, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
@@ -56,6 +57,69 @@ func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
auths := []*Auth{
|
||||
{ID: "c", Attributes: map[string]string{"priority": "0"}},
|
||||
{ID: "a", Attributes: map[string]string{"priority": "10"}},
|
||||
{ID: "b", Attributes: map[string]string{"priority": "10"}},
|
||||
}
|
||||
|
||||
want := []string{"a", "b", "a", "b"}
|
||||
for i, id := range want {
|
||||
got, err := selector.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
if got.ID != id {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id)
|
||||
}
|
||||
if got.ID == "c" {
|
||||
t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &FillFirstSelector{}
|
||||
now := time.Now()
|
||||
model := "test-model"
|
||||
|
||||
high := &Auth{
|
||||
ID: "high",
|
||||
Attributes: map[string]string{"priority": "10"},
|
||||
ModelStates: map[string]*ModelState{
|
||||
model: {
|
||||
Status: StatusActive,
|
||||
Unavailable: true,
|
||||
NextRetryAfter: now.Add(30 * time.Minute),
|
||||
Quota: QuotaState{
|
||||
Exceeded: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}}
|
||||
|
||||
got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low})
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() error = %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() auth = nil")
|
||||
}
|
||||
if got.ID != "low" {
|
||||
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
|
||||
selector := &RoundRobinSelector{}
|
||||
auths := []*Auth{
|
||||
|
||||
@@ -5,6 +5,9 @@ import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
// ModelInfo re-exports the registry model info structure.
|
||||
type ModelInfo = registry.ModelInfo
|
||||
|
||||
// ModelRegistryHook re-exports the registry hook interface for external integrations.
|
||||
type ModelRegistryHook = registry.ModelRegistryHook
|
||||
|
||||
// ModelRegistry describes registry operations consumed by external callers.
|
||||
type ModelRegistry interface {
|
||||
RegisterClient(clientID, clientProvider string, models []*ModelInfo)
|
||||
@@ -13,9 +16,15 @@ type ModelRegistry interface {
|
||||
ClearModelQuotaExceeded(clientID, modelID string)
|
||||
ClientSupportsModel(clientID, modelID string) bool
|
||||
GetAvailableModels(handlerType string) []map[string]any
|
||||
GetAvailableModelsByProvider(provider string) []*ModelInfo
|
||||
}
|
||||
|
||||
// GlobalModelRegistry returns the shared registry instance.
|
||||
func GlobalModelRegistry() ModelRegistry {
|
||||
return registry.GetGlobalRegistry()
|
||||
}
|
||||
|
||||
// SetGlobalModelRegistryHook registers an optional hook on the shared global registry instance.
|
||||
func SetGlobalModelRegistryHook(hook ModelRegistryHook) {
|
||||
registry.GetGlobalRegistry().SetHook(hook)
|
||||
}
|
||||
|
||||
@@ -1231,7 +1231,13 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode
|
||||
if len(mappings) == 0 {
|
||||
return models
|
||||
}
|
||||
forward := make(map[string]string, len(mappings))
|
||||
|
||||
type mappingEntry struct {
|
||||
alias string
|
||||
fork bool
|
||||
}
|
||||
|
||||
forward := make(map[string][]mappingEntry, len(mappings))
|
||||
for i := range mappings {
|
||||
name := strings.TrimSpace(mappings[i].Name)
|
||||
alias := strings.TrimSpace(mappings[i].Alias)
|
||||
@@ -1242,14 +1248,12 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(name)
|
||||
if _, exists := forward[key]; exists {
|
||||
continue
|
||||
}
|
||||
forward[key] = alias
|
||||
forward[key] = append(forward[key], mappingEntry{alias: alias, fork: mappings[i].Fork})
|
||||
}
|
||||
if len(forward) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
out := make([]*ModelInfo, 0, len(models))
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for _, model := range models {
|
||||
@@ -1260,25 +1264,61 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
mappedID := id
|
||||
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
|
||||
mappedID = strings.TrimSpace(to)
|
||||
}
|
||||
uniqueKey := strings.ToLower(mappedID)
|
||||
if _, exists := seen[uniqueKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[uniqueKey] = struct{}{}
|
||||
if mappedID == id {
|
||||
key := strings.ToLower(id)
|
||||
entries := forward[key]
|
||||
if len(entries) == 0 {
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, model)
|
||||
continue
|
||||
}
|
||||
clone := *model
|
||||
clone.ID = mappedID
|
||||
if clone.Name != "" {
|
||||
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
|
||||
|
||||
keepOriginal := false
|
||||
for _, entry := range entries {
|
||||
if entry.fork {
|
||||
keepOriginal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if keepOriginal {
|
||||
if _, exists := seen[key]; !exists {
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, model)
|
||||
}
|
||||
}
|
||||
|
||||
addedAlias := false
|
||||
for _, entry := range entries {
|
||||
mappedID := strings.TrimSpace(entry.alias)
|
||||
if mappedID == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(mappedID, id) {
|
||||
continue
|
||||
}
|
||||
aliasKey := strings.ToLower(mappedID)
|
||||
if _, exists := seen[aliasKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[aliasKey] = struct{}{}
|
||||
clone := *model
|
||||
clone.ID = mappedID
|
||||
if clone.Name != "" {
|
||||
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
|
||||
}
|
||||
out = append(out, &clone)
|
||||
addedAlias = true
|
||||
}
|
||||
|
||||
if !keepOriginal && !addedAlias {
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, model)
|
||||
}
|
||||
out = append(out, &clone)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
92
sdk/cliproxy/service_oauth_model_mappings_test.go
Normal file
92
sdk/cliproxy/service_oauth_model_mappings_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestApplyOAuthModelMappings_Rename(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OAuthModelMappings: map[string][]config.ModelNameMapping{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5"},
|
||||
},
|
||||
},
|
||||
}
|
||||
models := []*ModelInfo{
|
||||
{ID: "gpt-5", Name: "models/gpt-5"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelMappings(cfg, "codex", "oauth", models)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(out))
|
||||
}
|
||||
if out[0].ID != "g5" {
|
||||
t.Fatalf("expected model id %q, got %q", "g5", out[0].ID)
|
||||
}
|
||||
if out[0].Name != "models/g5" {
|
||||
t.Fatalf("expected model name %q, got %q", "models/g5", out[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelMappings_ForkAddsAlias(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OAuthModelMappings: map[string][]config.ModelNameMapping{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
models := []*ModelInfo{
|
||||
{ID: "gpt-5", Name: "models/gpt-5"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelMappings(cfg, "codex", "oauth", models)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("expected 2 models, got %d", len(out))
|
||||
}
|
||||
if out[0].ID != "gpt-5" {
|
||||
t.Fatalf("expected first model id %q, got %q", "gpt-5", out[0].ID)
|
||||
}
|
||||
if out[1].ID != "g5" {
|
||||
t.Fatalf("expected second model id %q, got %q", "g5", out[1].ID)
|
||||
}
|
||||
if out[1].Name != "models/g5" {
|
||||
t.Fatalf("expected forked model name %q, got %q", "models/g5", out[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelMappings_ForkAddsMultipleAliases(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OAuthModelMappings: map[string][]config.ModelNameMapping{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5", Fork: true},
|
||||
{Name: "gpt-5", Alias: "g5-2", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
models := []*ModelInfo{
|
||||
{ID: "gpt-5", Name: "models/gpt-5"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelMappings(cfg, "codex", "oauth", models)
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("expected 3 models, got %d", len(out))
|
||||
}
|
||||
if out[0].ID != "gpt-5" {
|
||||
t.Fatalf("expected first model id %q, got %q", "gpt-5", out[0].ID)
|
||||
}
|
||||
if out[1].ID != "g5" {
|
||||
t.Fatalf("expected second model id %q, got %q", "g5", out[1].ID)
|
||||
}
|
||||
if out[1].Name != "models/g5" {
|
||||
t.Fatalf("expected forked model name %q, got %q", "models/g5", out[1].Name)
|
||||
}
|
||||
if out[2].ID != "g5-2" {
|
||||
t.Fatalf("expected third model id %q, got %q", "g5-2", out[2].ID)
|
||||
}
|
||||
if out[2].Name != "models/g5-2" {
|
||||
t.Fatalf("expected forked model name %q, got %q", "models/g5-2", out[2].Name)
|
||||
}
|
||||
}
|
||||
54
test/builtin_tools_translation_test.go
Normal file
54
test/builtin_tools_translation_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestOpenAIToCodex_PreservesBuiltinTools(t *testing.T) {
|
||||
in := []byte(`{
|
||||
"model":"gpt-5",
|
||||
"messages":[{"role":"user","content":"hi"}],
|
||||
"tools":[{"type":"web_search","search_context_size":"high"}],
|
||||
"tool_choice":{"type":"web_search"}
|
||||
}`)
|
||||
|
||||
out := sdktranslator.TranslateRequest(sdktranslator.FormatOpenAI, sdktranslator.FormatCodex, "gpt-5", in, false)
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.#").Int(); got != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d: %s", got, string(out))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.type").String(); got != "web_search" {
|
||||
t.Fatalf("expected tools[0].type=web_search, got %q: %s", got, string(out))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.search_context_size").String(); got != "high" {
|
||||
t.Fatalf("expected tools[0].search_context_size=high, got %q: %s", got, string(out))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tool_choice.type").String(); got != "web_search" {
|
||||
t.Fatalf("expected tool_choice.type=web_search, got %q: %s", got, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesToOpenAI_PreservesBuiltinTools(t *testing.T) {
|
||||
in := []byte(`{
|
||||
"model":"gpt-5",
|
||||
"input":[{"role":"user","content":[{"type":"input_text","text":"hi"}]}],
|
||||
"tools":[{"type":"web_search","search_context_size":"low"}]
|
||||
}`)
|
||||
|
||||
out := sdktranslator.TranslateRequest(sdktranslator.FormatOpenAIResponse, sdktranslator.FormatOpenAI, "gpt-5", in, false)
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.#").Int(); got != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d: %s", got, string(out))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.type").String(); got != "web_search" {
|
||||
t.Fatalf("expected tools[0].type=web_search, got %q: %s", got, string(out))
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.search_context_size").String(); got != "low" {
|
||||
t.Fatalf("expected tools[0].search_context_size=low, got %q: %s", got, string(out))
|
||||
}
|
||||
}
|
||||
211
test/model_alias_thinking_suffix_test.go
Normal file
211
test/model_alias_thinking_suffix_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// TestModelAliasThinkingSuffix tests the 32 test cases defined in docs/thinking_suffix_test_cases.md
|
||||
// These tests verify the thinking suffix parsing and application logic across different providers.
|
||||
func TestModelAliasThinkingSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
id int
|
||||
name string
|
||||
provider string
|
||||
requestModel string
|
||||
suffixType string
|
||||
expectedField string // "thinkingBudget", "thinkingLevel", "budget_tokens", "reasoning_effort", "enable_thinking"
|
||||
expectedValue any
|
||||
upstreamModel string // The upstream model after alias resolution
|
||||
isAlias bool
|
||||
}{
|
||||
// === 1. Antigravity Provider ===
|
||||
// 1.1 Budget-only models (Gemini 2.5)
|
||||
{1, "antigravity_original_numeric", "antigravity", "gemini-2.5-computer-use-preview-10-2025(1000)", "numeric", "thinkingBudget", 1000, "gemini-2.5-computer-use-preview-10-2025", false},
|
||||
{2, "antigravity_alias_numeric", "antigravity", "gp(1000)", "numeric", "thinkingBudget", 1000, "gemini-2.5-computer-use-preview-10-2025", true},
|
||||
// 1.2 Budget+Levels models (Gemini 3)
|
||||
{3, "antigravity_original_numeric_to_level", "antigravity", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{4, "antigravity_original_level", "antigravity", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{5, "antigravity_alias_numeric_to_level", "antigravity", "gf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
{6, "antigravity_alias_level", "antigravity", "gf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
|
||||
// === 2. Gemini CLI Provider ===
|
||||
// 2.1 Budget-only models
|
||||
{7, "gemini_cli_original_numeric", "gemini-cli", "gemini-2.5-pro(8192)", "numeric", "thinkingBudget", 8192, "gemini-2.5-pro", false},
|
||||
{8, "gemini_cli_alias_numeric", "gemini-cli", "g25p(8192)", "numeric", "thinkingBudget", 8192, "gemini-2.5-pro", true},
|
||||
// 2.2 Budget+Levels models
|
||||
{9, "gemini_cli_original_numeric_to_level", "gemini-cli", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{10, "gemini_cli_original_level", "gemini-cli", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{11, "gemini_cli_alias_numeric_to_level", "gemini-cli", "gf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
{12, "gemini_cli_alias_level", "gemini-cli", "gf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
|
||||
// === 3. Vertex Provider ===
|
||||
// 3.1 Budget-only models
|
||||
{13, "vertex_original_numeric", "vertex", "gemini-2.5-pro(16384)", "numeric", "thinkingBudget", 16384, "gemini-2.5-pro", false},
|
||||
{14, "vertex_alias_numeric", "vertex", "vg25p(16384)", "numeric", "thinkingBudget", 16384, "gemini-2.5-pro", true},
|
||||
// 3.2 Budget+Levels models
|
||||
{15, "vertex_original_numeric_to_level", "vertex", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{16, "vertex_original_level", "vertex", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{17, "vertex_alias_numeric_to_level", "vertex", "vgf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
{18, "vertex_alias_level", "vertex", "vgf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
|
||||
// === 4. AI Studio Provider ===
|
||||
// 4.1 Budget-only models
|
||||
{19, "aistudio_original_numeric", "aistudio", "gemini-2.5-pro(12000)", "numeric", "thinkingBudget", 12000, "gemini-2.5-pro", false},
|
||||
{20, "aistudio_alias_numeric", "aistudio", "ag25p(12000)", "numeric", "thinkingBudget", 12000, "gemini-2.5-pro", true},
|
||||
// 4.2 Budget+Levels models
|
||||
{21, "aistudio_original_numeric_to_level", "aistudio", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{22, "aistudio_original_level", "aistudio", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
|
||||
{23, "aistudio_alias_numeric_to_level", "aistudio", "agf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
{24, "aistudio_alias_level", "aistudio", "agf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
|
||||
|
||||
// === 5. Claude Provider ===
|
||||
{25, "claude_original_numeric", "claude", "claude-sonnet-4-5-20250929(16384)", "numeric", "budget_tokens", 16384, "claude-sonnet-4-5-20250929", false},
|
||||
{26, "claude_alias_numeric", "claude", "cs45(16384)", "numeric", "budget_tokens", 16384, "claude-sonnet-4-5-20250929", true},
|
||||
|
||||
// === 6. Codex Provider ===
|
||||
{27, "codex_original_level", "codex", "gpt-5(high)", "level", "reasoning_effort", "high", "gpt-5", false},
|
||||
{28, "codex_alias_level", "codex", "g5(high)", "level", "reasoning_effort", "high", "gpt-5", true},
|
||||
|
||||
// === 7. Qwen Provider ===
|
||||
{29, "qwen_original_level", "qwen", "qwen3-coder-plus(high)", "level", "enable_thinking", true, "qwen3-coder-plus", false},
|
||||
{30, "qwen_alias_level", "qwen", "qcp(high)", "level", "enable_thinking", true, "qwen3-coder-plus", true},
|
||||
|
||||
// === 8. iFlow Provider ===
|
||||
{31, "iflow_original_level", "iflow", "glm-4.7(high)", "level", "reasoning_effort", "high", "glm-4.7", false},
|
||||
{32, "iflow_alias_level", "iflow", "glm(high)", "level", "reasoning_effort", "high", "glm-4.7", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Step 1: Parse model suffix (simulates SDK layer normalization)
|
||||
// For "gp(1000)" -> requestedModel="gp", metadata={thinking_budget: 1000}
|
||||
requestedModel, metadata := util.NormalizeThinkingModel(tt.requestModel)
|
||||
|
||||
// Verify suffix was parsed
|
||||
if metadata == nil && (tt.suffixType == "numeric" || tt.suffixType == "level") {
|
||||
t.Errorf("Case #%d: NormalizeThinkingModel(%q) metadata is nil", tt.id, tt.requestModel)
|
||||
return
|
||||
}
|
||||
|
||||
// Step 2: Simulate OAuth model mapping
|
||||
// Real flow: applyOAuthModelMapping stores requestedModel (the alias) in metadata
|
||||
if tt.isAlias {
|
||||
if metadata == nil {
|
||||
metadata = make(map[string]any)
|
||||
}
|
||||
metadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
}
|
||||
|
||||
// Step 3: Verify metadata extraction
|
||||
switch tt.suffixType {
|
||||
case "numeric":
|
||||
budget, _, _, matched := util.ThinkingFromMetadata(metadata)
|
||||
if !matched {
|
||||
t.Errorf("Case #%d: ThinkingFromMetadata did not match", tt.id)
|
||||
return
|
||||
}
|
||||
if budget == nil {
|
||||
t.Errorf("Case #%d: expected budget in metadata", tt.id)
|
||||
return
|
||||
}
|
||||
// For thinkingBudget/budget_tokens, verify the parsed budget value
|
||||
if tt.expectedField == "thinkingBudget" || tt.expectedField == "budget_tokens" {
|
||||
expectedBudget := tt.expectedValue.(int)
|
||||
if *budget != expectedBudget {
|
||||
t.Errorf("Case #%d: budget = %d, want %d", tt.id, *budget, expectedBudget)
|
||||
}
|
||||
}
|
||||
// For thinkingLevel (Gemini 3), verify conversion from budget to level
|
||||
if tt.expectedField == "thinkingLevel" {
|
||||
level, ok := util.ThinkingBudgetToGemini3Level(tt.upstreamModel, *budget)
|
||||
if !ok {
|
||||
t.Errorf("Case #%d: ThinkingBudgetToGemini3Level failed", tt.id)
|
||||
return
|
||||
}
|
||||
expectedLevel := tt.expectedValue.(string)
|
||||
if level != expectedLevel {
|
||||
t.Errorf("Case #%d: converted level = %q, want %q", tt.id, level, expectedLevel)
|
||||
}
|
||||
}
|
||||
|
||||
case "level":
|
||||
_, _, effort, matched := util.ThinkingFromMetadata(metadata)
|
||||
if !matched {
|
||||
t.Errorf("Case #%d: ThinkingFromMetadata did not match", tt.id)
|
||||
return
|
||||
}
|
||||
if effort == nil {
|
||||
t.Errorf("Case #%d: expected effort in metadata", tt.id)
|
||||
return
|
||||
}
|
||||
if tt.expectedField == "thinkingLevel" || tt.expectedField == "reasoning_effort" {
|
||||
expectedEffort := tt.expectedValue.(string)
|
||||
if *effort != expectedEffort {
|
||||
t.Errorf("Case #%d: effort = %q, want %q", tt.id, *effort, expectedEffort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Test Gemini-specific thinkingLevel conversion for Gemini 3 models
|
||||
if tt.expectedField == "thinkingLevel" && util.IsGemini3Model(tt.upstreamModel) {
|
||||
body := []byte(`{"request":{"contents":[]}}`)
|
||||
|
||||
// Build metadata simulating real OAuth flow:
|
||||
// - requestedModel (alias like "gf") is stored in model_mapping_original_model
|
||||
// - upstreamModel is passed as the model parameter
|
||||
testMetadata := make(map[string]any)
|
||||
if tt.isAlias {
|
||||
// Real flow: applyOAuthModelMapping stores requestedModel (the alias)
|
||||
testMetadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
}
|
||||
// Copy parsed metadata (thinking_budget, reasoning_effort, etc.)
|
||||
for k, v := range metadata {
|
||||
testMetadata[k] = v
|
||||
}
|
||||
|
||||
result := util.ApplyGemini3ThinkingLevelFromMetadataCLI(tt.upstreamModel, testMetadata, body)
|
||||
levelVal := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
|
||||
expectedLevel := tt.expectedValue.(string)
|
||||
if !levelVal.Exists() {
|
||||
t.Errorf("Case #%d: expected thinkingLevel in result", tt.id)
|
||||
} else if levelVal.String() != expectedLevel {
|
||||
t.Errorf("Case #%d: thinkingLevel = %q, want %q", tt.id, levelVal.String(), expectedLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Test Gemini 2.5 thinkingBudget application using real ApplyThinkingMetadataCLI flow
|
||||
if tt.expectedField == "thinkingBudget" && util.IsGemini25Model(tt.upstreamModel) {
|
||||
body := []byte(`{"request":{"contents":[]}}`)
|
||||
|
||||
// Build metadata simulating real OAuth flow:
|
||||
// - requestedModel (alias like "gp") is stored in model_mapping_original_model
|
||||
// - upstreamModel is passed as the model parameter
|
||||
testMetadata := make(map[string]any)
|
||||
if tt.isAlias {
|
||||
// Real flow: applyOAuthModelMapping stores requestedModel (the alias)
|
||||
testMetadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
}
|
||||
// Copy parsed metadata (thinking_budget, reasoning_effort, etc.)
|
||||
for k, v := range metadata {
|
||||
testMetadata[k] = v
|
||||
}
|
||||
|
||||
// Use the exported ApplyThinkingMetadataCLI which includes the fallback logic
|
||||
result := executor.ApplyThinkingMetadataCLI(body, testMetadata, tt.upstreamModel)
|
||||
budgetVal := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
|
||||
expectedBudget := tt.expectedValue.(int)
|
||||
if !budgetVal.Exists() {
|
||||
t.Errorf("Case #%d: expected thinkingBudget in result", tt.id)
|
||||
} else if int(budgetVal.Int()) != expectedBudget {
|
||||
t.Errorf("Case #%d: thinkingBudget = %d, want %d", tt.id, int(budgetVal.Int()), expectedBudget)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user