mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 12:30:50 +08:00
Compare commits
284 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e998b1229a | ||
|
|
bbed134bd1 | ||
|
|
cb56cb250e | ||
|
|
e0381a6ae0 | ||
|
|
2c01b2ef64 | ||
|
|
e947266743 | ||
|
|
c6b0e85b54 | ||
|
|
26efbed05c | ||
|
|
96340bf136 | ||
|
|
b055e00c1a | ||
|
|
857c880f99 | ||
|
|
ce7474d953 | ||
|
|
70fdd70b84 | ||
|
|
08ab6a7d77 | ||
|
|
9fa2a7e9df | ||
|
|
d443c86620 | ||
|
|
7be3f1c36c | ||
|
|
f6ab6d97b9 | ||
|
|
bc866bac49 | ||
|
|
50e6d845f4 | ||
|
|
a8cb01819d | ||
|
|
530273906b | ||
|
|
06ddf575d9 | ||
|
|
3099114cbb | ||
|
|
44b63f0767 | ||
|
|
6705d20194 | ||
|
|
a38a9c0b0f | ||
|
|
8286caa366 | ||
|
|
bd1ec8424d | ||
|
|
225e2c6797 | ||
|
|
d8fc485513 | ||
|
|
f137eb0ac4 | ||
|
|
f39a460487 | ||
|
|
ee171bc563 | ||
|
|
a95428f204 | ||
|
|
3ca5fb1046 | ||
|
|
a091d12f4e | ||
|
|
457924828a | ||
|
|
aca2ef6359 | ||
|
|
ade7194792 | ||
|
|
3a436e116a | ||
|
|
336867853b | ||
|
|
6403ff4ec4 | ||
|
|
d222469b44 | ||
|
|
7646a2b877 | ||
|
|
62090f2568 | ||
|
|
c281f4cbaf | ||
|
|
09455f9e85 | ||
|
|
c8e72ba0dc | ||
|
|
375ef252ab | ||
|
|
ee552f8720 | ||
|
|
2e88c4858e | ||
|
|
3f50da85c1 | ||
|
|
8be06255f7 | ||
|
|
72274099aa | ||
|
|
dcae098e23 | ||
|
|
2eb05ec640 | ||
|
|
3ce0d76aa4 | ||
|
|
a00b79d9be | ||
|
|
33e53a2a56 | ||
|
|
cd5b80785f | ||
|
|
54f71aa273 | ||
|
|
3f949b7f84 | ||
|
|
443c4538bb | ||
|
|
a7fc2ee4cf | ||
|
|
8e749ac22d | ||
|
|
69e09d9bc7 | ||
|
|
06ad527e8c | ||
|
|
b7409dd2de | ||
|
|
5ba325a8fc | ||
|
|
d502840f91 | ||
|
|
99238a4b59 | ||
|
|
6d43a2ff9a | ||
|
|
3faa1ca9af | ||
|
|
9d975e0375 | ||
|
|
2a6d8b78d4 | ||
|
|
671558a822 | ||
|
|
26fbb77901 | ||
|
|
a277302262 | ||
|
|
969c1a5b72 | ||
|
|
872339bceb | ||
|
|
5dc0dbc7aa | ||
|
|
2b7ba54a2f | ||
|
|
007c3304f2 | ||
|
|
e76ba0ede9 | ||
|
|
c06ac07e23 | ||
|
|
66769ec657 | ||
|
|
f413feec61 | ||
|
|
2e538e3486 | ||
|
|
9617a7b0d6 | ||
|
|
7569320770 | ||
|
|
8d25cf0d75 | ||
|
|
64e85e7019 | ||
|
|
6d1e20e940 | ||
|
|
0c0aae1eac | ||
|
|
5dcf7cb846 | ||
|
|
e52b542e22 | ||
|
|
8f6abb8a86 | ||
|
|
ed8eaae964 | ||
|
|
4e572ec8b9 | ||
|
|
24bc9cba67 | ||
|
|
1084b53fba | ||
|
|
83b90e106f | ||
|
|
5106caf641 | ||
|
|
b84ccc6e7a | ||
|
|
e19ddb53e7 | ||
|
|
5bf89dd757 | ||
|
|
2a0100b2d6 | ||
|
|
4442574e53 | ||
|
|
c020fa60d0 | ||
|
|
b078be4613 | ||
|
|
71a6dffbb6 | ||
|
|
27b43ed63f | ||
|
|
f6a3a1d0ba | ||
|
|
830fd8eac2 | ||
|
|
a86d501dc2 | ||
|
|
24e8e20b59 | ||
|
|
a87f09bad2 | ||
|
|
dbcbe48ead | ||
|
|
63908869f6 | ||
|
|
f6d625114c | ||
|
|
7dc40ba6d4 | ||
|
|
fcd6475377 | ||
|
|
4070c9de81 | ||
|
|
1e9e4a86a2 | ||
|
|
406a27271a | ||
|
|
9f9a4fc2af | ||
|
|
3fc410a253 | ||
|
|
781bc1521b | ||
|
|
05d201ece8 | ||
|
|
cd0c94f48a | ||
|
|
453e744abf | ||
|
|
653439698e | ||
|
|
24970baa57 | ||
|
|
89254cfc97 | ||
|
|
6bd9a034f7 | ||
|
|
26fc65b051 | ||
|
|
ed5ec5b55c | ||
|
|
df777650ac | ||
|
|
9855615f1e | ||
|
|
93414f1baa | ||
|
|
10f8c795ac | ||
|
|
3e4858a624 | ||
|
|
1231dc9cda | ||
|
|
c84ff42bcd | ||
|
|
8a5db02165 | ||
|
|
d7afb6eb0c | ||
|
|
bbd1fe890a | ||
|
|
f607231efa | ||
|
|
2039062845 | ||
|
|
99478d13a8 | ||
|
|
bc6c4cdbfc | ||
|
|
69d3a80fc3 | ||
|
|
404546ce93 | ||
|
|
9e268ad103 | ||
|
|
6dd1cf1dd6 | ||
|
|
9058d406a3 | ||
|
|
9d9b9e7a0d | ||
|
|
13aa82f3f3 | ||
|
|
05e55d7dc5 | ||
|
|
1b358c931c | ||
|
|
e04b02113a | ||
|
|
3275494fde | ||
|
|
ca09db21ff | ||
|
|
c1f8211acb | ||
|
|
718ff7a73f | ||
|
|
fa70b220e9 | ||
|
|
98fa2a1597 | ||
|
|
0e7c79ba23 | ||
|
|
b6ba15fcbd | ||
|
|
e44167d7a4 | ||
|
|
1bfa75f780 | ||
|
|
bbcb5552f3 | ||
|
|
1b8cb7b77b | ||
|
|
774f1fbc17 | ||
|
|
cfa8ddb59f | ||
|
|
39597267ae | ||
|
|
393e38f2c0 | ||
|
|
d1220de02d | ||
|
|
13eb5268de | ||
|
|
88798816f2 | ||
|
|
598f0af19b | ||
|
|
a33f5d31fc | ||
|
|
506699fba1 | ||
|
|
68a27772b3 | ||
|
|
de87fb622b | ||
|
|
f27672f6cf | ||
|
|
28420c14e4 | ||
|
|
0bd221ff41 | ||
|
|
5fda6f8ef3 | ||
|
|
9b956f6338 | ||
|
|
09923f654c | ||
|
|
ae7b972649 | ||
|
|
47885e3710 | ||
|
|
4b9a260b37 | ||
|
|
2c743c8f0b | ||
|
|
9f2c278ee6 | ||
|
|
aea337cfe2 | ||
|
|
811f8f8b4f | ||
|
|
27734a23b1 | ||
|
|
1b8e538a77 | ||
|
|
41c2385aca | ||
|
|
d605985f45 | ||
|
|
d52b28b147 | ||
|
|
4afe1f42ca | ||
|
|
7481c0eaa0 | ||
|
|
ffdfad8482 | ||
|
|
6586f08584 | ||
|
|
f49e887fe6 | ||
|
|
a5b3ff11fd | ||
|
|
084558f200 | ||
|
|
b602eae215 | ||
|
|
d02bf9c243 | ||
|
|
26a5f67df2 | ||
|
|
600fd42a83 | ||
|
|
670685139a | ||
|
|
52b6306388 | ||
|
|
521ec6f1b8 | ||
|
|
b0c5d9640a | ||
|
|
ef8e94e992 | ||
|
|
9df96a4bb4 | ||
|
|
28a428ae2f | ||
|
|
b326ec3641 | ||
|
|
fcecbc7d46 | ||
|
|
f4007f53ba | ||
|
|
5a812a1e93 | ||
|
|
5e624cc7b1 | ||
|
|
3af24597ee | ||
|
|
e0be6c5786 | ||
|
|
88b101ebf5 | ||
|
|
d9a65745df | ||
|
|
97ab623d42 | ||
|
|
14aa6cc7e8 | ||
|
|
3bc489254b | ||
|
|
4c07ea41c3 | ||
|
|
f6720f8dfa | ||
|
|
e19ab3a066 | ||
|
|
8f1dd69e72 | ||
|
|
f26da24a2f | ||
|
|
8e4fbcaa7d | ||
|
|
09c339953d | ||
|
|
367a05bdf6 | ||
|
|
d20b71deb9 | ||
|
|
712ce9f781 | ||
|
|
a4a3274a55 | ||
|
|
716aa71f6e | ||
|
|
e8976f9898 | ||
|
|
8496cc2444 | ||
|
|
5ef2d59e05 | ||
|
|
07bb89ae80 | ||
|
|
27a5ad8ec2 | ||
|
|
707b07c5f5 | ||
|
|
4a764afd76 | ||
|
|
ecf49d574b | ||
|
|
5a75ef8ffd | ||
|
|
07279f8746 | ||
|
|
71f788b13a | ||
|
|
59c62dc580 | ||
|
|
d5310a3300 | ||
|
|
f0a3eb574e | ||
|
|
bb15855443 | ||
|
|
14ce6aebd1 | ||
|
|
2fe83723f2 | ||
|
|
cd8c86c6fb | ||
|
|
52d5fd1a67 | ||
|
|
b6ad243e9e | ||
|
|
660aabc437 | ||
|
|
566120e8d5 | ||
|
|
f3f0f1717d | ||
|
|
7621ec609e | ||
|
|
9f511f0024 | ||
|
|
374faa2640 | ||
|
|
1c52a89535 | ||
|
|
e7cedbee6e | ||
|
|
b8194e717c | ||
|
|
15c3cc3a50 | ||
|
|
d131435e25 | ||
|
|
6e43669498 | ||
|
|
5ab3032335 | ||
|
|
1215c635a0 | ||
|
|
fc054db51a | ||
|
|
6e2306a5f2 | ||
|
|
b09e2115d1 | ||
|
|
07d21463ca |
@@ -13,8 +13,6 @@ Dockerfile
|
||||
docs/*
|
||||
README.md
|
||||
README_CN.md
|
||||
MANAGEMENT_API.md
|
||||
MANAGEMENT_API_CN.md
|
||||
LICENSE
|
||||
|
||||
# Runtime data folders (should be mounted as volumes)
|
||||
@@ -25,6 +23,14 @@ config.yaml
|
||||
|
||||
# Development/editor
|
||||
bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.codex/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -7,6 +7,13 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is it a request payload issue?**
|
||||
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
|
||||
[ ] No, it's another issue.
|
||||
|
||||
**If it's a request payload issue, you MUST know**
|
||||
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
|
||||
23
.github/workflows/pr-test-build.yml
vendored
Normal file
23
.github/workflows/pr-test-build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: pr-test-build
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: Build
|
||||
run: |
|
||||
go build -o test-output ./cmd/server
|
||||
rm -f test-output
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -11,11 +11,15 @@ bin/*
|
||||
logs/*
|
||||
conv/*
|
||||
temp/*
|
||||
refs/*
|
||||
|
||||
# Storage backends
|
||||
pgstore/*
|
||||
gitstore/*
|
||||
objectstore/*
|
||||
|
||||
# Static assets
|
||||
static/*
|
||||
refs/*
|
||||
|
||||
# Authentication data
|
||||
auths/*
|
||||
@@ -29,8 +33,17 @@ GEMINI.md
|
||||
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.codex/*
|
||||
.claude/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
25
README.md
25
README.md
@@ -10,14 +10,29 @@ So you can use local or multi-account CLI access with OpenAI(include Responses)/
|
||||
|
||||
## Sponsor
|
||||
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
|
||||
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.7 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
|
||||
Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
|
||||
---
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the "cliproxyapi" promo code during recharge to get 10% off.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
|
||||
<td>Thanks to Cubence for sponsoring this project! Cubence is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. Cubence provides special discounts for our software users: register using <a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">this link</a> and enter the "CLIPROXYAPI" promo code during recharge to get 10% off.</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
## Overview
|
||||
|
||||
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
||||
@@ -59,7 +74,7 @@ CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and A
|
||||
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- Security-first design with localhost-only management endpoints
|
||||
|
||||
**→ [Complete Amp CLI Integration Guide](docs/amp-cli-integration.md)**
|
||||
**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**
|
||||
|
||||
## SDK Docs
|
||||
|
||||
@@ -99,6 +114,10 @@ CLI wrapper for instant switching between multiple Claude accounts and alternati
|
||||
|
||||
Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
|
||||
|
||||
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
||||
|
||||
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
26
README_CN.md
26
README_CN.md
@@ -10,14 +10,30 @@
|
||||
|
||||
## 赞助商
|
||||
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
|
||||
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
|
||||
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7,为开发者提供顶尖的编码体验。
|
||||
|
||||
智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
||||
|
||||
---
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用<a href="https://www.packyapi.com/register?aff=cliproxyapi">此链接</a>注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
|
||||
<td>感谢 Cubence 对本项目的赞助!Cubence 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。Cubence 为本软件用户提供了特别优惠:使用<a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">此链接</a>注册,并在充值时输入 "CLIPROXYAPI" 优惠码即可享受九折优惠。</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
|
||||
@@ -57,7 +73,7 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支
|
||||
- 智能模型回退与自动路由
|
||||
- 以安全为先的设计,管理端点仅限 localhost
|
||||
|
||||
**→ [Amp CLI 完整集成指南](docs/amp-cli-integration_CN.md)**
|
||||
**→ [Amp CLI 完整集成指南](https://help.router-for.me/cn/agent-client/amp-cli.html)**
|
||||
|
||||
## SDK 文档
|
||||
|
||||
@@ -97,6 +113,10 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
|
||||
|
||||
基于 macOS 平台的原生 CLIProxyAPI GUI:配置供应商、模型映射以及OAuth端点,无需 API 密钥。
|
||||
|
||||
### [Quotio](https://github.com/nguyenphutrong/quotio)
|
||||
|
||||
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
|
||||
BIN
assets/cubence.png
Normal file
BIN
assets/cubence.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 51 KiB |
BIN
assets/packycode.png
Normal file
BIN
assets/packycode.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.1 KiB |
@@ -405,7 +405,7 @@ func main() {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if err = logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to configure log output: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -25,6 +25,9 @@ remote-management:
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
@@ -32,19 +35,30 @@ auth-dir: "~/.cli-proxy-api"
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
- "your-api-key-3"
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
|
||||
# When true, write application logs to rotating files instead of stdout
|
||||
logging-to-file: false
|
||||
|
||||
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
|
||||
# files are deleted until within the limit. Set to 0 to disable.
|
||||
logs-max-total-size-mb: 0
|
||||
|
||||
# When false, disable in-memory usage statistics aggregation
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -56,16 +70,29 @@ quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||
|
||||
# Routing strategy for selecting credentials when multiple match.
|
||||
routing:
|
||||
strategy: "round-robin" # round-robin (default), fill-first
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
|
||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||
# streaming:
|
||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
||||
|
||||
# Gemini API keys
|
||||
# gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
|
||||
# base-url: "https://generativelanguage.googleapis.com"
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# models:
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
|
||||
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
||||
@@ -76,10 +103,14 @@ ws-auth: false
|
||||
# Codex API keys
|
||||
# codex-api-key:
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "gpt-5-codex" # upstream model name
|
||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gpt-5.1" # exclude specific models (exact match)
|
||||
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
|
||||
@@ -90,13 +121,14 @@ ws-auth: false
|
||||
# claude-api-key:
|
||||
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
|
||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||
@@ -106,6 +138,7 @@ ws-auth: false
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
|
||||
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -120,14 +153,15 @@ ws-auth: false
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# models: # optional: map aliases to upstream model names
|
||||
# - name: "gemini-2.0-flash" # upstream model name
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "vertex-flash" # client-visible alias
|
||||
# - name: "gemini-1.5-pro"
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "vertex-pro"
|
||||
|
||||
# Amp Integration
|
||||
@@ -136,8 +170,20 @@ ws-auth: false
|
||||
# upstream-url: "https://ampcode.com"
|
||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||
# upstream-api-key: ""
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||
# restrict-management-to-localhost: true
|
||||
# # Per-client upstream API key mapping
|
||||
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
|
||||
# # Useful when different clients need to use different Amp accounts/quotas.
|
||||
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
|
||||
# upstream-api-keys:
|
||||
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
|
||||
# api-keys: # Client keys that use this upstream key
|
||||
# - "your-api-key-1"
|
||||
# - "your-api-key-2"
|
||||
# - upstream-api-key: "amp_key_for_team_b"
|
||||
# api-keys:
|
||||
# - "your-api-key-3"
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||
# restrict-management-to-localhost: false
|
||||
# # Force model mappings to run before checking local API keys (default: false)
|
||||
# force-model-mappings: false
|
||||
# # Amp Model Mappings
|
||||
@@ -145,12 +191,42 @@ ws-auth: false
|
||||
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||
# # but you have a similar model available (e.g., Claude Sonnet 4).
|
||||
# model-mappings:
|
||||
# - from: "claude-opus-4.5" # Model requested by Amp CLI
|
||||
# to: "claude-sonnet-4" # Route to this available model instead
|
||||
# - from: "gpt-5"
|
||||
# to: "gemini-2.5-pro"
|
||||
# - from: "claude-3-opus-20240229"
|
||||
# to: "claude-3-5-sonnet-20241022"
|
||||
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
|
||||
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
|
||||
# - from: "claude-sonnet-4-5-20250929"
|
||||
# to: "gemini-claude-sonnet-4-5-thinking"
|
||||
# - from: "claude-haiku-4-5-20251001"
|
||||
# to: "gemini-2.5-flash"
|
||||
|
||||
# Global OAuth model name mappings (per channel)
|
||||
# These mappings rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# oauth-model-mappings:
|
||||
# gemini-cli:
|
||||
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||
# alias: "g2.5p" # client-visible alias
|
||||
# vertex:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# antigravity:
|
||||
# - name: "gemini-3-pro-preview"
|
||||
# alias: "g3p"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
# codex:
|
||||
# - name: "gpt-5"
|
||||
# alias: "g5"
|
||||
# qwen:
|
||||
# - name: "qwen3-coder-plus"
|
||||
# alias: "qwen-plus"
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
|
||||
# OAuth provider excluded models
|
||||
# oauth-excluded-models:
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
# Amp CLI Integration Guide
|
||||
|
||||
This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions, enabling you to use your existing Google/ChatGPT/Claude subscriptions (via OAuth) with Amp's CLI.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate)
|
||||
- [Architecture](#architecture)
|
||||
- [Configuration](#configuration)
|
||||
- [Model Mapping Configuration](#model-mapping-configuration)
|
||||
- [Setup](#setup)
|
||||
- [Usage](#usage)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The Amp CLI integration adds specialized routing to support Amp's API patterns while maintaining full compatibility with all existing CLIProxyAPI features. This allows you to use both traditional CLIProxyAPI features and Amp CLI with the same proxy server.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers
|
||||
- **Management proxy**: Forwards OAuth and account management requests to Amp's control plane
|
||||
- **Smart fallback**: Automatically routes unconfigured models to ampcode.com
|
||||
- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- **Secret management**: Configurable precedence (config > env > file) with 5-minute caching
|
||||
- **Security-first**: Management routes restricted to localhost by default
|
||||
- **Automatic gzip handling**: Decompresses responses from Amp upstream
|
||||
|
||||
### What You Can Do
|
||||
|
||||
- Use Amp CLI with your Google account (Gemini 3 Pro Preview, Gemini 2.5 Pro, Gemini 2.5 Flash)
|
||||
- Use Amp CLI with your ChatGPT Plus/Pro subscription (GPT-5, GPT-5 Codex models)
|
||||
- Use Amp CLI with your Claude Pro/Max subscription (Claude Sonnet 4.5, Opus 4.1)
|
||||
- Use Amp IDE extensions (VS Code, Cursor, Windsurf, etc.) with the same proxy
|
||||
- Run multiple CLI tools (Factory + Amp) through one proxy server
|
||||
- Route unconfigured models automatically through ampcode.com
|
||||
|
||||
### Which Providers Should You Authenticate?
|
||||
|
||||
**Important**: The providers you need to authenticate depend on which models and features your installed version of Amp currently uses. Amp employs different providers for various agent modes and specialized subagents:
|
||||
|
||||
- **Smart mode**: Uses Google/Gemini models (Gemini 3 Pro)
|
||||
- **Rush mode**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Oracle subagent**: Uses OpenAI/GPT models (GPT-5 medium reasoning)
|
||||
- **Librarian subagent**: Uses Anthropic/Claude models (Claude Sonnet 4.5)
|
||||
- **Search subagent**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Review feature**: Uses Google/Gemini models (Gemini 2.5 Flash-Lite)
|
||||
|
||||
For the most current information about which models Amp uses, see the **[Amp Models Documentation](https://ampcode.com/models)**.
|
||||
|
||||
#### Fallback Behavior
|
||||
|
||||
CLIProxyAPI uses a smart fallback system:
|
||||
|
||||
1. **Provider authenticated locally** (`--login`, `--codex-login`, `--claude-login`):
|
||||
- Requests use **your OAuth subscription** (ChatGPT Plus/Pro, Claude Pro/Max, Google account)
|
||||
- You benefit from your subscription's included usage quotas
|
||||
- No Amp credits consumed
|
||||
|
||||
2. **Provider NOT authenticated locally**:
|
||||
- Requests automatically forward to **ampcode.com**
|
||||
- Uses Amp's backend provider connections
|
||||
- **Requires Amp credits** if the provider is paid (OpenAI, Anthropic paid tiers)
|
||||
- May result in errors if Amp credit balance is insufficient
|
||||
|
||||
**Recommendation**: Authenticate all providers you have subscriptions for to maximize value and minimize Amp credit usage. If you don't have subscriptions to all providers Amp uses, ensure you have sufficient Amp credits available for fallback requests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Request Flow
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO ↓
|
||||
│ │ ├─ Model mapping configured?
|
||||
│ │ │ YES → Rewrite model → Use local handler (free)
|
||||
│ │ │ NO → Forward to ampcode.com (uses Amp credits)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### Components
|
||||
|
||||
The Amp integration is implemented as a modular routing module (`internal/api/modules/amp/`) with these components:
|
||||
|
||||
1. **Route Aliases** (`routes.go`): Maps Amp-style paths to standard handlers
|
||||
2. **Reverse Proxy** (`proxy.go`): Forwards management requests to ampcode.com
|
||||
3. **Fallback Handler** (`fallback_handlers.go`): Routes unconfigured models to ampcode.com
|
||||
4. **Secret Management** (`secret.go`): Multi-source API key resolution with caching
|
||||
5. **Main Module** (`amp.go`): Orchestrates registration and configuration
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Add these fields to your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# Amp upstream control plane (required for management routes)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# Optional: Override API key (otherwise uses env or file)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# Security: restrict management routes to localhost (recommended)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### Model Mapping Configuration
|
||||
|
||||
When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally.
|
||||
|
||||
```yaml
|
||||
# Route unavailable models to alternatives
|
||||
amp-model-mappings:
|
||||
# Example: Route Claude Opus 4.5 requests to Claude Sonnet 4
|
||||
- from: "claude-opus-4.5"
|
||||
to: "claude-sonnet-4"
|
||||
|
||||
# Example: Route GPT-5 requests to Gemini 2.5 Pro
|
||||
- from: "gpt-5"
|
||||
to: "gemini-2.5-pro"
|
||||
|
||||
# Example: Map older model names to newer versions
|
||||
- from: "claude-3-opus-20240229"
|
||||
to: "claude-3-5-sonnet-20241022"
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
1. Amp CLI requests a model (e.g., `claude-opus-4.5`)
|
||||
2. CLIProxyAPI checks if a local provider is available for that model
|
||||
3. If not available, it checks the model mappings
|
||||
4. If a mapping exists, the request is rewritten to use the target model
|
||||
5. The request is then handled locally (free, using your OAuth subscription)
|
||||
|
||||
**Benefits:**
|
||||
- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com
|
||||
- **Hot-reload**: Mappings can be updated without restarting the proxy
|
||||
- **Structured logging**: Clear logs show when mappings are applied
|
||||
|
||||
**Routing Decision Logs:**
|
||||
|
||||
The proxy logs each routing decision with structured fields:
|
||||
|
||||
```
|
||||
[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free)
|
||||
[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free)
|
||||
[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits)
|
||||
```
|
||||
|
||||
### Secret Resolution Precedence
|
||||
|
||||
The Amp module resolves API keys using this precedence order:
|
||||
|
||||
| Source | Key | Priority | Cache |
|
||||
|--------|-----|----------|-------|
|
||||
| Config file | `amp-upstream-api-key` | High | No |
|
||||
| Environment | `AMP_API_KEY` | Medium | No |
|
||||
| Amp secrets file | `~/.local/share/amp/secrets.json` | Low | 5 min |
|
||||
|
||||
**Recommendation**: Use the Amp secrets file (lowest precedence) for normal usage. This file is automatically managed by `amp login`.
|
||||
|
||||
### Security Settings
|
||||
|
||||
**`amp-restrict-management-to-localhost`** (default: `true`)
|
||||
|
||||
When enabled, management routes (`/api/auth`, `/api/user`, `/api/threads`, etc.) only accept connections from localhost (127.0.0.1, ::1). This prevents:
|
||||
- Drive-by browser attacks
|
||||
- Remote access to management endpoints
|
||||
- CORS-based attacks
|
||||
- Header spoofing attacks (e.g., `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### How It Works
|
||||
|
||||
This restriction uses the **actual TCP connection address** (`RemoteAddr`), not HTTP headers like `X-Forwarded-For`. This prevents header spoofing attacks but has important implications:
|
||||
|
||||
- ✅ **Works for direct connections**: Running CLIProxyAPI directly on your machine or server
|
||||
- ⚠️ **May not work behind reverse proxies**: If deploying behind nginx, Cloudflare, or other proxies, the connection will appear to come from the proxy's IP, not localhost
|
||||
|
||||
#### Reverse Proxy Deployments
|
||||
|
||||
If you need to run CLIProxyAPI behind a reverse proxy (nginx, Caddy, Cloudflare Tunnel, etc.):
|
||||
|
||||
1. **Disable the localhost restriction**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **Use alternative security measures**:
|
||||
- Firewall rules restricting access to management routes
|
||||
- Proxy-level authentication (HTTP Basic Auth, OAuth)
|
||||
- Network-level isolation (VPN, Tailscale, Cloudflare Access)
|
||||
- Bind CLIProxyAPI to `127.0.0.1` only and access via SSH tunnel
|
||||
|
||||
3. **Example nginx configuration** (blocks external access to management routes):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**Important**: Only disable `amp-restrict-management-to-localhost` if you understand the security implications and have other protections in place.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Configure CLIProxyAPI
|
||||
|
||||
Create or edit `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp integration
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# Other standard settings...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. Authenticate with Providers
|
||||
|
||||
Run OAuth login for the providers you want to use:
|
||||
|
||||
**Google Account (Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro (GPT-5, GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max (Claude Sonnet 4.5, Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
Tokens are saved to:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. Start the Proxy
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
Or run in background with tmux (recommended for remote servers):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. Configure Amp CLI
|
||||
|
||||
#### Option A: Settings File
|
||||
|
||||
Edit `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### Option B: Environment Variable
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. Login and Use Amp
|
||||
|
||||
Login through the proxy (proxied to ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
Use Amp as normal:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (Optional) Configure Amp IDE Extension
|
||||
|
||||
The proxy also works with Amp IDE extensions for VS Code, Cursor, Windsurf, etc.
|
||||
|
||||
1. Open Amp extension settings in your IDE
|
||||
2. Set **Amp URL** to `http://localhost:8317`
|
||||
3. Login with your Amp account
|
||||
4. Start using Amp in your IDE
|
||||
|
||||
Both CLI and IDE can use the proxy simultaneously.
|
||||
|
||||
## Usage
|
||||
|
||||
### Supported Routes
|
||||
|
||||
#### Provider Aliases (Always Available)
|
||||
|
||||
These routes work even without `amp-upstream-url` configured:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI calls these routes with your OAuth-authenticated models configured in CLIProxyAPI.
|
||||
|
||||
#### Management Routes (Require `amp-upstream-url`)
|
||||
|
||||
These routes are proxied to ampcode.com:
|
||||
|
||||
- `/api/auth` - Authentication
|
||||
- `/api/user` - User profile
|
||||
- `/api/meta` - Metadata
|
||||
- `/api/threads` - Conversation threads
|
||||
- `/api/telemetry` - Usage telemetry
|
||||
- `/api/internal` - Internal APIs
|
||||
|
||||
**Security**: Restricted to localhost by default.
|
||||
|
||||
### Model Fallback Behavior
|
||||
|
||||
When Amp requests a model:
|
||||
|
||||
1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider?
|
||||
2. **If YES**: Route to local handler (use your OAuth subscription)
|
||||
3. **If NO**: Check if a model mapping exists
|
||||
4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free)
|
||||
5. **If no mapping**: Forward to ampcode.com (uses Amp credits)
|
||||
|
||||
This enables seamless mixed usage:
|
||||
- Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions
|
||||
- Models with mappings configured → Routed to alternative local models (free)
|
||||
- Models you haven't configured and have no mapping → Amp's default providers (uses credits)
|
||||
|
||||
### Example API Calls
|
||||
|
||||
**Chat completion with local OAuth:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**Management endpoint (localhost only):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
|---------|--------------|-----|
|
||||
| 404 on `/api/provider/...` | Incorrect route path | Ensure exact path: `/api/provider/{provider}/v1...` |
|
||||
| 403 on `/api/user` | Non-localhost request | Run from same machine or disable `amp-restrict-management-to-localhost` (not recommended) |
|
||||
| 401/403 from provider | Missing/expired OAuth | Re-run `--codex-login` or `--claude-login` |
|
||||
| Amp gzip errors | Response decompression issue | Update to latest build; auto-decompression should handle this |
|
||||
| Models not using proxy | Wrong Amp URL | Verify `amp.url` setting or `AMP_URL` environment variable |
|
||||
| CORS errors | Protected management endpoint | Use CLI/terminal, not browser |
|
||||
|
||||
### Diagnostics
|
||||
|
||||
**Check proxy logs:**
|
||||
```bash
|
||||
# If logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# If running in tmux
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**Enable debug mode** (temporarily):
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**Test basic connectivity:**
|
||||
```bash
|
||||
# Check if proxy is running
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# Check Amp-specific route
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**Verify Amp configuration:**
|
||||
```bash
|
||||
# Check if Amp is using proxy
|
||||
amp config get amp.url
|
||||
|
||||
# Or check environment
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### Security Checklist
|
||||
|
||||
- ✅ Keep `amp-restrict-management-to-localhost: true` (default)
|
||||
- ✅ Don't expose proxy publicly (bind to localhost or use firewall/VPN)
|
||||
- ✅ Use the Amp secrets file (`~/.local/share/amp/secrets.json`) managed by `amp login`
|
||||
- ✅ Rotate OAuth tokens periodically by re-running login commands
|
||||
- ✅ Store config and auth-dir on encrypted disk if handling sensitive data
|
||||
- ✅ Keep proxy binary up to date for security fixes
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [CLIProxyAPI Main Documentation](https://help.router-for.me/)
|
||||
- [Amp CLI Official Manual](https://ampcode.com/manual)
|
||||
- [Management API Reference](https://help.router-for.me/management/api)
|
||||
- [SDK Documentation](sdk-usage.md)
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This integration is for personal/educational use. Using reverse proxies or alternate API bases may violate provider Terms of Service. You are solely responsible for how you use this software. Accounts may be rate-limited, locked, or banned. No warranties. Use at your own risk.
|
||||
@@ -1,392 +0,0 @@
|
||||
# Amp CLI 集成指南
|
||||
|
||||
本指南说明如何在 Amp CLI 和 Amp IDE 扩展中使用 CLIProxyAPI,通过 OAuth 让你能够把已有的 Google/ChatGPT/Claude 订阅与 Amp 的 CLI 一起使用。
|
||||
|
||||
## 目录
|
||||
|
||||
- [概述](#概述)
|
||||
- [应该认证哪些服务提供商?](#应该认证哪些服务提供商)
|
||||
- [架构](#架构)
|
||||
- [配置](#配置)
|
||||
- [设置](#设置)
|
||||
- [用法](#用法)
|
||||
- [故障排查](#故障排查)
|
||||
|
||||
## 概述
|
||||
|
||||
Amp CLI 集成为 Amp 的 API 模式添加了专用路由,同时保持与现有 CLIProxyAPI 功能的完全兼容。这样你可以在同一个代理服务器上同时使用传统 CLIProxyAPI 功能和 Amp CLI。
|
||||
|
||||
### 主要特性
|
||||
|
||||
- **提供者路由别名**:将 Amp 的 `/api/provider/{provider}/v1...` 路径映射到 CLIProxyAPI 处理器
|
||||
- **管理代理**:将 OAuth 和账号管理请求转发到 Amp 控制平面
|
||||
- **智能回退**:自动将未配置的模型路由到 ampcode.com
|
||||
- **密钥管理**:可配置优先级(配置 > 环境变量 > 文件),缓存 5 分钟
|
||||
- **安全优先**:管理路由默认限制为 localhost
|
||||
- **自动 gzip 处理**:自动解压来自 Amp 上游的响应
|
||||
|
||||
### 你可以做什么
|
||||
|
||||
- 使用 Amp CLI 搭配你的 Google 账号(Gemini 3 Pro Preview、Gemini 2.5 Pro、Gemini 2.5 Flash)
|
||||
- 使用 Amp CLI 搭配你的 ChatGPT Plus/Pro 订阅(GPT-5、GPT-5 Codex 模型)
|
||||
- 使用 Amp CLI 搭配你的 Claude Pro/Max 订阅(Claude Sonnet 4.5、Opus 4.1)
|
||||
- 将 Amp IDE 扩展(VS Code、Cursor、Windsurf 等)与同一个代理一起使用
|
||||
- 通过一个代理同时运行多个 CLI 工具(Factory + Amp)
|
||||
- 将未配置的模型自动路由到 ampcode.com
|
||||
|
||||
### 应该认证哪些服务提供商?
|
||||
|
||||
**重要**:需要认证的提供商取决于你安装的 Amp 版本当前使用的模型和功能。Amp 的不同智能模式和子代理会使用不同的提供商:
|
||||
|
||||
- **Smart 模式**:使用 Google/Gemini 模型(Gemini 3 Pro)
|
||||
- **Rush 模式**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Oracle 子代理**:使用 OpenAI/GPT 模型(GPT-5 medium reasoning)
|
||||
- **Librarian 子代理**:使用 Anthropic/Claude 模型(Claude Sonnet 4.5)
|
||||
- **Search 子代理**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Review 功能**:使用 Google/Gemini 模型(Gemini 2.5 Flash-Lite)
|
||||
|
||||
有关 Amp 当前使用哪些模型的最新信息,请参阅 **[Amp 模型文档](https://ampcode.com/models)**。
|
||||
|
||||
#### 回退行为
|
||||
|
||||
CLIProxyAPI 采用智能回退机制:
|
||||
|
||||
1. **本地已认证提供商**(`--login`、`--codex-login`、`--claude-login`):
|
||||
- 请求使用**你的 OAuth 订阅**(ChatGPT Plus/Pro、Claude Pro/Max、Google 账号)
|
||||
- 享受订阅自带的额度
|
||||
- 不消耗 Amp 额度
|
||||
|
||||
2. **本地未认证提供商**:
|
||||
- 请求自动转发到 **ampcode.com**
|
||||
- 使用 Amp 的后端提供商连接
|
||||
- 如果提供商是付费的(OpenAI、Anthropic 付费档),**需要消耗 Amp 额度**
|
||||
- 若 Amp 额度不足,可能产生错误
|
||||
|
||||
**建议**:对你有订阅的所有提供商都进行认证,以最大化价值并尽量减少 Amp 额度消耗。如果没有覆盖 Amp 使用的全部提供商,请确保为回退请求准备足够的 Amp 额度。
|
||||
|
||||
## 架构
|
||||
|
||||
### 请求流
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO → Forward to ampcode.com (reverse proxy)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### 组件
|
||||
|
||||
Amp 集成以模块化路由模块(`internal/api/modules/amp/`)实现,包含以下组件:
|
||||
|
||||
1. **路由别名**(`routes.go`):将 Amp 风格的路径映射到标准处理器
|
||||
2. **反向代理**(`proxy.go`):将管理请求转发到 ampcode.com
|
||||
3. **回退处理器**(`fallback_handlers.go`):将未配置的模型路由到 ampcode.com
|
||||
4. **密钥管理**(`secret.go`):多来源 API 密钥解析并带缓存
|
||||
5. **主模块**(`amp.go`):负责注册和配置
|
||||
|
||||
## 配置
|
||||
|
||||
### 基础配置
|
||||
|
||||
在 `config.yaml` 中新增以下字段:
|
||||
|
||||
```yaml
|
||||
# Amp 上游控制平面(管理路由必需)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# 可选:覆盖 API key(否则使用环境变量或文件)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# 安全性:将管理路由限制为 localhost(推荐)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### 密钥解析优先级
|
||||
|
||||
Amp 模块以如下优先级解析 API key:
|
||||
|
||||
| 来源 | 键名 | 优先级 | 缓存 |
|
||||
|------|------|--------|------|
|
||||
| 配置文件 | `amp-upstream-api-key` | 高 | 无 |
|
||||
| 环境变量 | `AMP_API_KEY` | 中 | 无 |
|
||||
| Amp 密钥文件 | `~/.local/share/amp/secrets.json` | 低 | 5 分钟 |
|
||||
|
||||
**建议**:日常使用时采用 Amp 密钥文件(最低优先级)。该文件由 `amp login` 自动管理。
|
||||
|
||||
### 安全设置
|
||||
|
||||
**`amp-restrict-management-to-localhost`**(默认:`true`)
|
||||
|
||||
启用后,管理路由(`/api/auth`、`/api/user`、`/api/threads` 等)只接受来自 localhost(127.0.0.1、::1)的连接,可防止:
|
||||
- 浏览器探测式攻击
|
||||
- 对管理端点的远程访问
|
||||
- 基于 CORS 的攻击
|
||||
- 伪造头攻击(例如 `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### 工作原理
|
||||
|
||||
此限制使用**实际的 TCP 连接地址**(`RemoteAddr`),而非 `X-Forwarded-For` 等 HTTP 头,能防止头部伪造,但有重要影响:
|
||||
|
||||
- ✅ **直接连接可用**:在本机或服务器直接运行 CLIProxyAPI 时适用
|
||||
- ⚠️ **可能不适用于反向代理场景**:部署在 nginx、Cloudflare 等代理后,请求源会显示为代理 IP 而非 localhost
|
||||
|
||||
#### 反向代理部署
|
||||
|
||||
若需要在反向代理(nginx、Caddy、Cloudflare Tunnel 等)后运行 CLIProxyAPI:
|
||||
|
||||
1. **关闭 localhost 限制**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **使用替代安全措施**:
|
||||
- 防火墙规则限制管理路由访问
|
||||
- 代理层认证(HTTP Basic Auth、OAuth)
|
||||
- 网络隔离(VPN、Tailscale、Cloudflare Access)
|
||||
- 将 CLIProxyAPI 仅绑定 `127.0.0.1`,并通过 SSH 隧道访问
|
||||
|
||||
3. **nginx 示例配置**(阻止外部访问管理路由):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**重要**:只有在理解安全影响并已采取其他防护措施时,才关闭 `amp-restrict-management-to-localhost`。
|
||||
|
||||
## 设置
|
||||
|
||||
### 1. 配置 CLIProxyAPI
|
||||
|
||||
创建或编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp 集成
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# 其他常规设置...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. 认证提供商
|
||||
|
||||
为要使用的提供商执行 OAuth 登录:
|
||||
|
||||
**Google 账号(Gemini 2.5 Pro、Gemini 2.5 Flash、Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro(GPT-5、GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max(Claude Sonnet 4.5、Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
令牌会保存到:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. 启动代理
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
或使用 tmux 在后台运行(推荐用于远程服务器):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. 配置 Amp CLI
|
||||
|
||||
#### 方案 A:配置文件
|
||||
|
||||
编辑 `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### 方案 B:环境变量
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. 登录并使用 Amp
|
||||
|
||||
通过代理登录(请求会被代理到 ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
像平常一样使用 Amp:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (可选)配置 Amp IDE 扩展
|
||||
|
||||
该代理同样适用于 VS Code、Cursor、Windsurf 等 Amp IDE 扩展。
|
||||
|
||||
1. 在 IDE 中打开 Amp 扩展设置
|
||||
2. 将 **Amp URL** 设置为 `http://localhost:8317`
|
||||
3. 用你的 Amp 账号登录
|
||||
4. 在 IDE 中开始使用 Amp
|
||||
|
||||
CLI 和 IDE 可同时使用该代理。
|
||||
|
||||
## 用法
|
||||
|
||||
### 支持的路由
|
||||
|
||||
#### 提供商别名(始终可用)
|
||||
|
||||
这些路由即使未配置 `amp-upstream-url` 也可使用:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI 会使用你在 CLIProxyAPI 中通过 OAuth 认证的模型来调用这些路由。
|
||||
|
||||
#### 管理路由(需要 `amp-upstream-url`)
|
||||
|
||||
这些路由会被代理到 ampcode.com:
|
||||
|
||||
- `/api/auth` - 认证
|
||||
- `/api/user` - 用户资料
|
||||
- `/api/meta` - 元数据
|
||||
- `/api/threads` - 会话线程
|
||||
- `/api/telemetry` - 使用遥测
|
||||
- `/api/internal` - 内部 API
|
||||
|
||||
**安全性**:默认限制为 localhost。
|
||||
|
||||
### 模型回退行为
|
||||
|
||||
当 Amp 请求模型时:
|
||||
|
||||
1. **检查本地配置**:CLIProxyAPI 是否有该模型提供商的 OAuth 令牌?
|
||||
2. **如果有**:路由到本地处理器(使用你的 OAuth 订阅)
|
||||
3. **如果没有**:转发到 ampcode.com(使用 Amp 的默认路由)
|
||||
|
||||
这实现了无缝混用:
|
||||
- 你已配置的模型(Gemini、ChatGPT、Claude)→ 你的 OAuth 订阅
|
||||
- 未配置的模型 → Amp 的默认提供商
|
||||
|
||||
### 示例 API 调用
|
||||
|
||||
**使用本地 OAuth 的聊天补全:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**管理端点(仅限 localhost):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 常见问题
|
||||
|
||||
| 症状 | 可能原因 | 解决方案 |
|
||||
|------|----------|----------|
|
||||
| `/api/provider/...` 返回 404 | 路径错误 | 确保路径准确:`/api/provider/{provider}/v1...` |
|
||||
| `/api/user` 返回 403 | 非 localhost 请求 | 在同一机器上访问,或关闭 `amp-restrict-management-to-localhost`(不推荐) |
|
||||
| 提供商返回 401/403 | OAuth 缺失或过期 | 重新运行 `--codex-login` 或 `--claude-login` |
|
||||
| Amp gzip 错误 | 响应解压问题 | 更新到最新构建;自动解压应能处理 |
|
||||
| 模型未走代理 | Amp URL 设置错误 | 检查 `amp.url` 设置或 `AMP_URL` 环境变量 |
|
||||
| CORS 错误 | 受保护的管理端点 | 使用 CLI/终端而非浏览器 |
|
||||
|
||||
### 诊断
|
||||
|
||||
**查看代理日志:**
|
||||
```bash
|
||||
# 若 logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# 若运行在 tmux 中
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**临时开启调试模式:**
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**测试基础连通性:**
|
||||
```bash
|
||||
# 检查代理是否运行
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# 检查 Amp 特定路由
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**验证 Amp 配置:**
|
||||
```bash
|
||||
# 检查 Amp 是否使用代理
|
||||
amp config get amp.url
|
||||
|
||||
# 或检查环境变量
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### 安全清单
|
||||
|
||||
- ✅ 保持 `amp-restrict-management-to-localhost: true`(默认)
|
||||
- ✅ 不要将代理暴露到公共网络(绑定到 localhost 或使用防火墙/VPN)
|
||||
- ✅ 使用 `amp login` 管理的 Amp 密钥文件(`~/.local/share/amp/secrets.json`)
|
||||
- ✅ 定期重新登录轮换 OAuth 令牌
|
||||
- ✅ 若处理敏感数据,使用加密磁盘存储配置和 auth-dir
|
||||
- ✅ 保持代理二进制为最新版本以获取安全修复
|
||||
|
||||
## 其他资源
|
||||
|
||||
- [CLIProxyAPI 主文档](https://help.router-for.me/)
|
||||
- [Amp CLI 官方手册](https://ampcode.com/manual)
|
||||
- [管理 API 参考](https://help.router-for.me/management/api)
|
||||
- [SDK 文档](sdk-usage.md)
|
||||
|
||||
## 免责声明
|
||||
|
||||
此集成仅用于个人或教育用途。使用反向代理或替代 API 基址可能违反提供商的服务条款。你需要对自己的使用方式负责。账号可能会被限速、锁定或封禁。软件不附带任何保证,使用风险自负。
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
|
||||
10
go.mod
10
go.mod
@@ -18,8 +18,8 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.7.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -68,9 +68,9 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
24
go.sum
24
go.sum
@@ -160,22 +160,22 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
538
internal/api/handlers/management/api_tools.go
Normal file
538
internal/api/handlers/management/api_tools.go
Normal file
@@ -0,0 +1,538 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const defaultAPICallTimeout = 60 * time.Second
|
||||
|
||||
const (
|
||||
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOAuthClientSecret = "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"
|
||||
)
|
||||
|
||||
var geminiOAuthScopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
type apiCallRequest struct {
|
||||
AuthIndexSnake *string `json:"auth_index"`
|
||||
AuthIndexCamel *string `json:"authIndex"`
|
||||
AuthIndexPascal *string `json:"AuthIndex"`
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Header map[string]string `json:"header"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type apiCallResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Header map[string][]string `json:"header"`
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
// APICall makes a generic HTTP request on behalf of the management API caller.
|
||||
// It is protected by the management middleware.
|
||||
//
|
||||
// Endpoint:
|
||||
//
|
||||
// POST /v0/management/api-call
|
||||
//
|
||||
// Authentication:
|
||||
//
|
||||
// Same as other management APIs (requires a management key and remote-management rules).
|
||||
// You can provide the key via:
|
||||
// - Authorization: Bearer <key>
|
||||
// - X-Management-Key: <key>
|
||||
//
|
||||
// Request JSON:
|
||||
// - auth_index / authIndex / AuthIndex (optional):
|
||||
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
|
||||
// If omitted or not found, credential-specific proxy/token substitution is skipped.
|
||||
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
|
||||
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
|
||||
// - header (optional): Request headers map.
|
||||
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
|
||||
// 1) metadata.access_token
|
||||
// 2) attributes.api_key
|
||||
// 3) metadata.token / metadata.id_token / metadata.cookie
|
||||
// Example: {"Authorization":"Bearer $TOKEN$"}.
|
||||
// Note: if you need to override the HTTP Host header, set header["Host"].
|
||||
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
|
||||
//
|
||||
// Proxy selection (highest priority first):
|
||||
// 1. Selected credential proxy_url
|
||||
// 2. Global config proxy-url
|
||||
// 3. Direct connect (environment proxies are not used)
|
||||
//
|
||||
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
|
||||
// - status_code: Upstream HTTP status code.
|
||||
// - header: Upstream response headers.
|
||||
// - body: Upstream response body as string.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
|
||||
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
|
||||
// -H "Content-Type: application/json" \
|
||||
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
|
||||
//
|
||||
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
|
||||
// -H "Authorization: Bearer 831227" \
|
||||
// -H "Content-Type: application/json" \
|
||||
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
|
||||
func (h *Handler) APICall(c *gin.Context) {
|
||||
var body apiCallRequest
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
method := strings.ToUpper(strings.TrimSpace(body.Method))
|
||||
if method == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
|
||||
return
|
||||
}
|
||||
|
||||
urlStr := strings.TrimSpace(body.URL)
|
||||
if urlStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
|
||||
return
|
||||
}
|
||||
parsedURL, errParseURL := url.Parse(urlStr)
|
||||
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
|
||||
return
|
||||
}
|
||||
|
||||
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
|
||||
auth := h.authByIndex(authIndex)
|
||||
|
||||
reqHeaders := body.Header
|
||||
if reqHeaders == nil {
|
||||
reqHeaders = map[string]string{}
|
||||
}
|
||||
|
||||
var hostOverride string
|
||||
var token string
|
||||
var tokenResolved bool
|
||||
var tokenErr error
|
||||
for key, value := range reqHeaders {
|
||||
if !strings.Contains(value, "$TOKEN$") {
|
||||
continue
|
||||
}
|
||||
if !tokenResolved {
|
||||
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
|
||||
tokenResolved = true
|
||||
}
|
||||
if auth != nil && token == "" {
|
||||
if tokenErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
|
||||
return
|
||||
}
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if body.Data != "" {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
}
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
|
||||
if errNewRequest != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
|
||||
return
|
||||
}
|
||||
|
||||
for key, value := range reqHeaders {
|
||||
if strings.EqualFold(key, "host") {
|
||||
hostOverride = strings.TrimSpace(value)
|
||||
continue
|
||||
}
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
if hostOverride != "" {
|
||||
req.Host = hostOverride
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
}
|
||||
httpClient.Transport = h.apiCallTransport(auth)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.WithError(errDo).Debug("management APICall request failed")
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
respBody, errReadAll := io.ReadAll(resp.Body)
|
||||
if errReadAll != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, apiCallResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header,
|
||||
Body: string(respBody),
|
||||
})
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...*string) string {
|
||||
for _, v := range values {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
if out := strings.TrimSpace(*v); out != "" {
|
||||
return out
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func tokenValueForAuth(auth *coreauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
|
||||
return v
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
||||
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if provider == "gemini-cli" {
|
||||
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
|
||||
return token, errToken
|
||||
}
|
||||
|
||||
return tokenValueForAuth(auth), nil
|
||||
}
|
||||
|
||||
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
metadata, updater := geminiOAuthMetadata(auth)
|
||||
if len(metadata) == 0 {
|
||||
return "", fmt.Errorf("gemini oauth metadata missing")
|
||||
}
|
||||
|
||||
base := make(map[string]any)
|
||||
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
|
||||
base = cloneMap(tokenRaw)
|
||||
}
|
||||
|
||||
var token oauth2.Token
|
||||
if len(base) > 0 {
|
||||
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
|
||||
_ = json.Unmarshal(raw, &token)
|
||||
}
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
token.AccessToken = stringValue(metadata, "access_token")
|
||||
}
|
||||
if token.RefreshToken == "" {
|
||||
token.RefreshToken = stringValue(metadata, "refresh_token")
|
||||
}
|
||||
if token.TokenType == "" {
|
||||
token.TokenType = stringValue(metadata, "token_type")
|
||||
}
|
||||
if token.Expiry.IsZero() {
|
||||
if expiry := stringValue(metadata, "expiry"); expiry != "" {
|
||||
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
|
||||
token.Expiry = ts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conf := &oauth2.Config{
|
||||
ClientID: geminiOAuthClientID,
|
||||
ClientSecret: geminiOAuthClientSecret,
|
||||
Scopes: geminiOAuthScopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
ctxToken := ctx
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
Transport: h.apiCallTransport(auth),
|
||||
}
|
||||
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
|
||||
|
||||
src := conf.TokenSource(ctxToken, &token)
|
||||
currentToken, errToken := src.Token()
|
||||
if errToken != nil {
|
||||
return "", errToken
|
||||
}
|
||||
|
||||
merged := buildOAuthTokenMap(base, currentToken)
|
||||
fields := buildOAuthTokenFields(currentToken, merged)
|
||||
if updater != nil {
|
||||
updater(fields)
|
||||
}
|
||||
return strings.TrimSpace(currentToken.AccessToken), nil
|
||||
}
|
||||
|
||||
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
|
||||
if auth == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
||||
snapshot := shared.MetadataSnapshot()
|
||||
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
|
||||
}
|
||||
return auth.Metadata, func(fields map[string]any) {
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
for k, v := range fields {
|
||||
auth.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(metadata map[string]any, key string) string {
|
||||
if len(metadata) == 0 || key == "" {
|
||||
return ""
|
||||
}
|
||||
if v, ok := metadata[key].(string); ok {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
|
||||
merged := cloneMap(base)
|
||||
if merged == nil {
|
||||
merged = make(map[string]any)
|
||||
}
|
||||
if tok == nil {
|
||||
return merged
|
||||
}
|
||||
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
|
||||
var tokenMap map[string]any
|
||||
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
|
||||
for k, v := range tokenMap {
|
||||
merged[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
|
||||
fields := make(map[string]any, 5)
|
||||
if tok != nil && tok.AccessToken != "" {
|
||||
fields["access_token"] = tok.AccessToken
|
||||
}
|
||||
if tok != nil && tok.TokenType != "" {
|
||||
fields["token_type"] = tok.TokenType
|
||||
}
|
||||
if tok != nil && tok.RefreshToken != "" {
|
||||
fields["refresh_token"] = tok.RefreshToken
|
||||
}
|
||||
if tok != nil && !tok.Expiry.IsZero() {
|
||||
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
|
||||
}
|
||||
if len(merged) > 0 {
|
||||
fields["token"] = cloneMap(merged)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func tokenValueFromMetadata(metadata map[string]any) string {
|
||||
if len(metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
|
||||
switch typed := tokenRaw.(type) {
|
||||
case string:
|
||||
if v := strings.TrimSpace(typed); v != "" {
|
||||
return v
|
||||
}
|
||||
case map[string]any:
|
||||
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
case map[string]string:
|
||||
if v := strings.TrimSpace(typed["access_token"]); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
|
||||
authIndex = strings.TrimSpace(authIndex)
|
||||
if authIndex == "" || h == nil || h.authManager == nil {
|
||||
return nil
|
||||
}
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
if auth.Index == authIndex {
|
||||
return auth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
||||
var proxyCandidates []string
|
||||
if auth != nil {
|
||||
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||
}
|
||||
}
|
||||
if h != nil && h.cfg != nil {
|
||||
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, proxyStr := range proxyCandidates {
|
||||
if transport := buildProxyTransport(proxyStr); transport != nil {
|
||||
return transport
|
||||
}
|
||||
}
|
||||
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
return &http.Transport{Proxy: nil}
|
||||
}
|
||||
clone := transport.Clone()
|
||||
clone.Proxy = nil
|
||||
return clone
|
||||
}
|
||||
|
||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
proxyStr = strings.TrimSpace(proxyStr)
|
||||
if proxyStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
proxyURL, errParse := url.Parse(proxyStr)
|
||||
if errParse != nil {
|
||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
||||
return nil
|
||||
}
|
||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
||||
log.Debug("proxy URL missing scheme/host")
|
||||
return nil
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
var proxyAuth *proxy.Auth
|
||||
if proxyURL.User != nil {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
||||
return nil
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
return nil
|
||||
}
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -35,10 +36,6 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
var (
|
||||
oauthStatus = make(map[string]string)
|
||||
)
|
||||
|
||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||
|
||||
const (
|
||||
@@ -200,6 +197,19 @@ func stopCallbackForwarder(port int) {
|
||||
stopForwarderInstance(port, forwarder)
|
||||
}
|
||||
|
||||
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||
if forwarder == nil {
|
||||
return
|
||||
}
|
||||
callbackForwardersMu.Lock()
|
||||
if current := callbackForwarders[port]; current == forwarder {
|
||||
delete(callbackForwarders, port)
|
||||
}
|
||||
callbackForwardersMu.Unlock()
|
||||
|
||||
stopForwarderInstance(port, forwarder)
|
||||
}
|
||||
|
||||
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||
if forwarder == nil || forwarder.server == nil {
|
||||
return
|
||||
@@ -266,6 +276,54 @@ func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetAuthFileModels returns the models supported by a specific auth file
|
||||
func (h *Handler) GetAuthFileModels(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
if name == "" {
|
||||
c.JSON(400, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Try to find auth ID via authManager
|
||||
var authID string
|
||||
if h.authManager != nil {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name || auth.ID == name {
|
||||
authID = auth.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if authID == "" {
|
||||
authID = name // fallback to filename as ID
|
||||
}
|
||||
|
||||
// Get models from registry
|
||||
reg := registry.GetGlobalRegistry()
|
||||
models := reg.GetModelsForClient(authID)
|
||||
|
||||
result := make([]gin.H, 0, len(models))
|
||||
for _, m := range models {
|
||||
entry := gin.H{
|
||||
"id": m.ID,
|
||||
}
|
||||
if m.DisplayName != "" {
|
||||
entry["display_name"] = m.DisplayName
|
||||
}
|
||||
if m.Type != "" {
|
||||
entry["type"] = m.Type
|
||||
}
|
||||
if m.OwnedBy != "" {
|
||||
entry["owned_by"] = m.OwnedBy
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"models": result})
|
||||
}
|
||||
|
||||
// List auth files from disk when the auth manager is unavailable.
|
||||
func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
||||
entries, err := os.ReadDir(h.cfg.AuthDir)
|
||||
@@ -369,9 +427,46 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
log.WithError(err).Warnf("failed to stat auth file %s", path)
|
||||
}
|
||||
}
|
||||
if claims := extractCodexIDTokenClaims(auth); claims != nil {
|
||||
entry["id_token"] = claims
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
||||
return nil
|
||||
}
|
||||
idTokenRaw, ok := auth.Metadata["id_token"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
idToken := strings.TrimSpace(idTokenRaw)
|
||||
if idToken == "" {
|
||||
return nil
|
||||
}
|
||||
claims, err := codex.ParseJWTToken(idToken)
|
||||
if err != nil || claims == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := gin.H{}
|
||||
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
|
||||
result["chatgpt_account_id"] = v
|
||||
}
|
||||
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
|
||||
result["plan_type"] = v
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func authEmail(auth *coreauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
@@ -737,7 +832,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "anthropic")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
|
||||
if errTarget != nil {
|
||||
@@ -745,7 +843,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -754,7 +853,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(anthropicCallbackPort)
|
||||
defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
// Helper: wait for callback file
|
||||
@@ -762,8 +861,11 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "anthropic") {
|
||||
return nil, errOAuthSessionNotPending
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
data, errRead := os.ReadFile(path)
|
||||
@@ -781,6 +883,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
// Wait up to 5 minutes
|
||||
resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
|
||||
if errWait != nil {
|
||||
if errors.Is(errWait, errOAuthSessionNotPending) {
|
||||
return
|
||||
}
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
return
|
||||
@@ -788,13 +893,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errStr := resultMap["error"]; errStr != "" {
|
||||
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad request"
|
||||
SetOAuthSessionError(state, "Bad request")
|
||||
return
|
||||
}
|
||||
if resultMap["state"] != state {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -827,7 +932,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errDo != nil {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -838,7 +943,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
var tResp struct {
|
||||
@@ -851,7 +956,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
bundle := &claude.ClaudeAuthBundle{
|
||||
@@ -876,7 +981,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -885,10 +990,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Claude services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("anthropic")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -919,7 +1024,10 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
RegisterOAuthSession(state, "gemini")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/google/callback")
|
||||
if errTarget != nil {
|
||||
@@ -927,7 +1035,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start gemini callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -936,7 +1045,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(geminiCallbackPort)
|
||||
defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
// Wait for callback file written by server route
|
||||
@@ -945,9 +1054,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var authCode string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "gemini") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -956,13 +1068,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
authCode = m["code"]
|
||||
if authCode == "" {
|
||||
log.Errorf("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -974,7 +1086,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
token, err := conf.Exchange(ctx, authCode)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to exchange token: %v", err)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -985,7 +1097,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Could not get user info: %v", errNewRequest)
|
||||
oauthStatus[state] = "Could not get user info"
|
||||
SetOAuthSessionError(state, "Could not get user info")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -994,7 +1106,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
resp, errDo := authHTTPClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to execute request"
|
||||
SetOAuthSessionError(state, "Failed to execute request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1006,7 +1118,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1015,7 +1127,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
fmt.Printf("Authenticated user email: %s\n", email)
|
||||
} else {
|
||||
fmt.Println("Failed to get user email from token")
|
||||
oauthStatus[state] = "Failed to get user email from token"
|
||||
}
|
||||
|
||||
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
||||
@@ -1023,7 +1134,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
jsonData, _ := json.Marshal(token)
|
||||
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
||||
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
||||
oauthStatus[state] = "Failed to unmarshal token"
|
||||
SetOAuthSessionError(state, "Failed to unmarshal token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1046,10 +1157,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
// Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
|
||||
gemAuth := geminiAuth.NewGeminiAuth()
|
||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{
|
||||
NoBrowser: true,
|
||||
})
|
||||
if errGetClient != nil {
|
||||
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
||||
oauthStatus[state] = "Failed to get authenticated client"
|
||||
SetOAuthSessionError(state, "Failed to get authenticated client")
|
||||
return
|
||||
}
|
||||
fmt.Println("Authentication successful.")
|
||||
@@ -1059,12 +1172,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
@@ -1072,26 +1185,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Onboarding did not return a project ID")
|
||||
oauthStatus[state] = "Failed to resolve project ID"
|
||||
SetOAuthSessionError(state, "Failed to resolve project ID")
|
||||
return
|
||||
}
|
||||
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
oauthStatus[state] = "Cloud AI API not enabled"
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1114,15 +1227,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("gemini")
|
||||
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1158,7 +1271,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "codex")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
|
||||
if errTarget != nil {
|
||||
@@ -1166,7 +1282,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start codex callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -1175,7 +1292,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(codexCallbackPort)
|
||||
defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
// Wait for callback file
|
||||
@@ -1183,10 +1300,13 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var code string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "codex") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -1196,12 +1316,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad Request"
|
||||
SetOAuthSessionError(state, "Bad Request")
|
||||
return
|
||||
}
|
||||
if m["state"] != state {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
return
|
||||
}
|
||||
@@ -1232,14 +1352,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
@@ -1250,7 +1370,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
return
|
||||
}
|
||||
@@ -1288,7 +1408,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1297,10 +1417,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Codex services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("codex")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1341,7 +1461,10 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
params.Set("state", state)
|
||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
||||
|
||||
RegisterOAuthSession(state, "antigravity")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
|
||||
if errTarget != nil {
|
||||
@@ -1349,7 +1472,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -1358,16 +1482,19 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(antigravityCallbackPort)
|
||||
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var authCode string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "antigravity") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||
@@ -1376,18 +1503,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||
log.Errorf("Authentication failed: state mismatch")
|
||||
oauthStatus[state] = "Authentication failed: state mismatch"
|
||||
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
authCode = strings.TrimSpace(payload["code"])
|
||||
if authCode == "" {
|
||||
log.Error("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -1406,7 +1533,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
||||
oauthStatus[state] = "Failed to build token request"
|
||||
SetOAuthSessionError(state, "Failed to build token request")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
@@ -1414,7 +1541,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute token request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1426,7 +1553,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1438,7 +1565,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
}
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1447,7 +1574,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errInfoReq != nil {
|
||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
||||
oauthStatus[state] = "Failed to build user info request"
|
||||
SetOAuthSessionError(state, "Failed to build user info request")
|
||||
return
|
||||
}
|
||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||
@@ -1455,7 +1582,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoResp, errInfo := httpClient.Do(infoReq)
|
||||
if errInfo != nil {
|
||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
||||
oauthStatus[state] = "Failed to execute user info request"
|
||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1474,7 +1601,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
} else {
|
||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1522,11 +1649,12 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("antigravity")
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if projectID != "" {
|
||||
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||
@@ -1534,7 +1662,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
fmt.Println("You can now use Antigravity services through this CLI")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1556,11 +1683,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
|
||||
RegisterOAuthSession(state, "qwen")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||
if errPollForToken != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||
return
|
||||
}
|
||||
@@ -1579,16 +1708,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Qwen services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1601,7 +1729,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
||||
|
||||
RegisterOAuthSession(state, "iflow")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
|
||||
if errTarget != nil {
|
||||
@@ -1609,7 +1740,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start iflow callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
|
||||
return
|
||||
@@ -1618,7 +1750,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(iflowauth.CallbackPort)
|
||||
defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
|
||||
}
|
||||
fmt.Println("Waiting for authentication...")
|
||||
|
||||
@@ -1626,8 +1758,11 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var resultMap map[string]string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "iflow") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||
return
|
||||
}
|
||||
@@ -1640,26 +1775,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||
return
|
||||
}
|
||||
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(resultMap["code"])
|
||||
if code == "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: code missing")
|
||||
return
|
||||
}
|
||||
|
||||
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||
if errExchange != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||
return
|
||||
}
|
||||
@@ -1681,7 +1816,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1691,10 +1826,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use iFlow services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("iflow")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1722,6 +1857,17 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for duplicate BXAuth before authentication
|
||||
bxAuth := iflowauth.ExtractBXAuth(cookieValue)
|
||||
if existingFile, err := iflowauth.CheckDuplicateBXAuth(h.cfg.AuthDir, bxAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to check duplicate"})
|
||||
return
|
||||
} else if existingFile != "" {
|
||||
existingFileName := filepath.Base(existingFile)
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "duplicate BXAuth found", "existing_file": existingFileName})
|
||||
return
|
||||
}
|
||||
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue)
|
||||
if errAuth != nil {
|
||||
@@ -1744,11 +1890,12 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
tokenStorage.Email = email
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fmt.Sprintf("iflow-%s.json", fileName),
|
||||
ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
||||
Provider: "iflow",
|
||||
FileName: fmt.Sprintf("iflow-%s.json", fileName),
|
||||
FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{
|
||||
"email": email,
|
||||
@@ -2118,16 +2265,24 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
}
|
||||
|
||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if err, ok := oauthStatus[state]; ok {
|
||||
if err != "" {
|
||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "wait"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if state == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
delete(oauthStatus, state)
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
|
||||
_, status, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
if status != "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
@@ -145,71 +145,74 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
||||
type geminiKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.GeminiKey `json:"value"`
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *geminiKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
|
||||
if value.APIKey == "" {
|
||||
// Treat empty API key as delete.
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
if match != "" {
|
||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||
removed := false
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if !removed && h.cfg.GeminiKey[i].APIKey == match {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.GeminiKey[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.GeminiKey = out
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
if match != "" {
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if h.cfg.GeminiKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey[*body.Index] = value
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if h.cfg.GeminiKey[i].APIKey == match {
|
||||
h.cfg.GeminiKey[i] = value
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry := h.cfg.GeminiKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.APIKey)
|
||||
if trimmed == "" {
|
||||
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.APIKey = trimmed
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
h.cfg.GeminiKey[targetIndex] = entry
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||
@@ -268,35 +271,70 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
||||
type claudeKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Models *[]config.ClaudeModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.ClaudeKey `json:"value"`
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *claudeKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
normalizeClaudeKey(&value)
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
|
||||
h.cfg.ClaudeKey[*body.Index] = value
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if body.Match != nil {
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.ClaudeKey {
|
||||
if h.cfg.ClaudeKey[i].APIKey == *body.Match {
|
||||
h.cfg.ClaudeKey[i] = value
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
if h.cfg.ClaudeKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.ClaudeKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
normalizeClaudeKey(&entry)
|
||||
h.cfg.ClaudeKey[targetIndex] = entry
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
||||
if val := c.Query("api-key"); val != "" {
|
||||
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
||||
@@ -356,62 +394,73 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchOpenAICompat(c *gin.Context) {
|
||||
type openAICompatPatch struct {
|
||||
Name *string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"`
|
||||
Models *[]config.OpenAICompatibilityModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
}
|
||||
var body struct {
|
||||
Name *string `json:"name"`
|
||||
Index *int `json:"index"`
|
||||
Value *config.OpenAICompatibility `json:"value"`
|
||||
Name *string `json:"name"`
|
||||
Index *int `json:"index"`
|
||||
Value *openAICompatPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
normalizeOpenAICompatibilityEntry(body.Value)
|
||||
// If base-url becomes empty, delete the provider instead of updating
|
||||
if strings.TrimSpace(body.Value.BaseURL) == "" {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...)
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if targetIndex == -1 && body.Name != nil {
|
||||
match := strings.TrimSpace(*body.Name)
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if h.cfg.OpenAICompatibility[i].Name == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.OpenAICompatibility[targetIndex]
|
||||
if body.Value.Name != nil {
|
||||
entry.Name = strings.TrimSpace(*body.Value.Name)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||
if trimmed == "" {
|
||||
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...)
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Name != nil {
|
||||
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
||||
removed := false
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.OpenAICompatibility[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.OpenAICompatibility = out
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
entry.BaseURL = trimmed
|
||||
}
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility[*body.Index] = *body.Value
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
if body.Value.APIKeyEntries != nil {
|
||||
entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...)
|
||||
}
|
||||
if body.Name != nil {
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if h.cfg.OpenAICompatibility[i].Name == *body.Name {
|
||||
h.cfg.OpenAICompatibility[i] = *body.Value
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
normalizeOpenAICompatibilityEntry(&entry)
|
||||
h.cfg.OpenAICompatibility[targetIndex] = entry
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
||||
if name := c.Query("name"); name != "" {
|
||||
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
||||
@@ -548,11 +597,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
filtered := make([]config.CodexKey, 0, len(arr))
|
||||
for i := range arr {
|
||||
entry := arr[i]
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
normalizeCodexKey(&entry)
|
||||
if entry.BaseURL == "" {
|
||||
continue
|
||||
}
|
||||
@@ -563,66 +608,77 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchCodexKey(c *gin.Context) {
|
||||
type codexKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Models *[]config.CodexModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.CodexKey `json:"value"`
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *codexKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.Headers = config.NormalizeHeaders(value.Headers)
|
||||
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
|
||||
// If base-url becomes empty, delete instead of update
|
||||
if value.BaseURL == "" {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...)
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||
removed := false
|
||||
for i := range h.cfg.CodexKey {
|
||||
if !removed && h.cfg.CodexKey[i].APIKey == *body.Match {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.CodexKey[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.CodexKey = out
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey[*body.Index] = value
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
for i := range h.cfg.CodexKey {
|
||||
if h.cfg.CodexKey[i].APIKey == *body.Match {
|
||||
h.cfg.CodexKey[i] = value
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.CodexKey {
|
||||
if h.cfg.CodexKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.CodexKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||
if trimmed == "" {
|
||||
h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.BaseURL = trimmed
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
normalizeCodexKey(&entry)
|
||||
h.cfg.CodexKey[targetIndex] = entry
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
||||
if val := c.Query("api-key"); val != "" {
|
||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||
@@ -707,6 +763,32 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
func normalizeCodexKey(entry *config.CodexKey) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.Prefix = strings.TrimSpace(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
normalized := make([]config.CodexModel, 0, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
model.Name = strings.TrimSpace(model.Name)
|
||||
model.Alias = strings.TrimSpace(model.Alias)
|
||||
if model.Name == "" && model.Alias == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, model)
|
||||
}
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
// GetAmpCode returns the complete ampcode configuration.
|
||||
func (h *Handler) GetAmpCode(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
@@ -858,3 +940,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
|
||||
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
||||
}
|
||||
|
||||
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
|
||||
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
|
||||
}
|
||||
|
||||
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
|
||||
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
// Normalize entries: trim whitespace, filter empty
|
||||
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
|
||||
// Matching is done by upstream-api-key value.
|
||||
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
existing := make(map[string]int)
|
||||
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
|
||||
}
|
||||
|
||||
for _, newEntry := range body.Value {
|
||||
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
|
||||
UpstreamAPIKey: upstreamKey,
|
||||
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
|
||||
}
|
||||
if idx, ok := existing[upstreamKey]; ok {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
|
||||
} else {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
|
||||
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
|
||||
}
|
||||
}
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
|
||||
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
|
||||
// If "value" is an empty array, clears all entries.
|
||||
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
|
||||
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []string `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
if body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "missing value"})
|
||||
return
|
||||
}
|
||||
|
||||
// Empty array means clear all
|
||||
if len(body.Value) == 0 {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = nil
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
|
||||
toRemove := make(map[string]bool)
|
||||
for _, key := range body.Value {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
toRemove[trimmed] = true
|
||||
}
|
||||
if len(toRemove) == 0 {
|
||||
c.JSON(400, gin.H{"error": "empty value"})
|
||||
return
|
||||
}
|
||||
|
||||
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
|
||||
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
|
||||
newEntries = append(newEntries, entry)
|
||||
}
|
||||
}
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
|
||||
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
apiKeys := normalizeAPIKeysList(entry.APIKeys)
|
||||
out = append(out, config.AmpUpstreamAPIKeyEntry{
|
||||
UpstreamAPIKey: upstreamKey,
|
||||
APIKeys: apiKeys,
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
|
||||
func normalizeAPIKeysList(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed != "" {
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -59,6 +59,11 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
|
||||
return NewHandler(cfg, "", manager)
|
||||
}
|
||||
|
||||
// SetConfig updates the in-memory config reference when the server hot-reloads.
|
||||
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }
|
||||
|
||||
|
||||
@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetRequestLogByID finds and downloads a request log file by its request ID.
|
||||
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
|
||||
func (h *Handler) GetRequestLogByID(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(c.Param("id"))
|
||||
if requestID == "" {
|
||||
requestID = strings.TrimSpace(c.Query("id"))
|
||||
}
|
||||
if requestID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
|
||||
return
|
||||
}
|
||||
if strings.ContainsAny(requestID, "/\\") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
suffix := "-" + requestID + ".log"
|
||||
var matchedFile string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, suffix) {
|
||||
matchedFile = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchedFile == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
dirAbs, errAbs := filepath.Abs(dir)
|
||||
if errAbs != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
||||
return
|
||||
}
|
||||
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
|
||||
prefix := dirAbs + string(os.PathSeparator)
|
||||
if !strings.HasPrefix(fullPath, prefix) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
||||
return
|
||||
}
|
||||
|
||||
info, errStat := os.Stat(fullPath)
|
||||
if errStat != nil {
|
||||
if os.IsNotExist(errStat) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
||||
return
|
||||
}
|
||||
if info.IsDir() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.FileAttachment(fullPath, matchedFile)
|
||||
}
|
||||
|
||||
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
||||
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
||||
if h == nil {
|
||||
|
||||
100
internal/api/handlers/management/oauth_callback.go
Normal file
100
internal/api/handlers/management/oauth_callback.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type oauthCallbackRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (h *Handler) PostOAuthCallback(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req oauthCallbackRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
|
||||
return
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(req.State)
|
||||
code := strings.TrimSpace(req.Code)
|
||||
errMsg := strings.TrimSpace(req.Error)
|
||||
|
||||
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
|
||||
u, errParse := url.Parse(rawRedirect)
|
||||
if errParse != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
}
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error"))
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
||||
return
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
if code == "" && errMsg == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
|
||||
return
|
||||
}
|
||||
if sessionStatus != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(sessionProvider, canonicalProvider) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
|
||||
if errors.Is(errWrite, errOAuthSessionNotPending) {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
283
internal/api/handlers/management/oauth_sessions.go
Normal file
283
internal/api/handlers/management/oauth_sessions.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthSessionTTL = 10 * time.Minute
|
||||
maxOAuthStateLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidOAuthState = errors.New("invalid oauth state")
|
||||
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
|
||||
errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
||||
)
|
||||
|
||||
type oauthSession struct {
|
||||
Provider string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type oauthSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
sessions map[string]oauthSession
|
||||
}
|
||||
|
||||
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
||||
if ttl <= 0 {
|
||||
ttl = oauthSessionTTL
|
||||
}
|
||||
return &oauthSessionStore{
|
||||
ttl: ttl,
|
||||
sessions: make(map[string]oauthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
||||
for state, session := range s.sessions {
|
||||
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Register(state, provider string) {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if state == "" || provider == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
s.sessions[state] = oauthSession{
|
||||
Provider: provider,
|
||||
Status: "",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) SetError(state, message string) {
|
||||
state = strings.TrimSpace(state)
|
||||
message = strings.TrimSpace(message)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
if message == "" {
|
||||
message = "Authentication failed"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
session.Status = message
|
||||
session.ExpiresAt = now.Add(s.ttl)
|
||||
s.sessions[state] = session
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Complete(state string) {
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) CompleteProvider(provider string) int {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return 0
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
removed := 0
|
||||
for state, session := range s.sessions {
|
||||
if strings.EqualFold(session.Provider, provider) {
|
||||
delete(s.sessions, state)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
||||
state = strings.TrimSpace(state)
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if session.Status != "" {
|
||||
return false
|
||||
}
|
||||
if provider == "" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(session.Provider, provider)
|
||||
}
|
||||
|
||||
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
||||
|
||||
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
|
||||
|
||||
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
|
||||
|
||||
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
|
||||
|
||||
func CompleteOAuthSessionsByProvider(provider string) int {
|
||||
return oauthSessions.CompleteProvider(provider)
|
||||
}
|
||||
|
||||
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
||||
session, ok := oauthSessions.Get(state)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return session.Provider, session.Status, true
|
||||
}
|
||||
|
||||
func IsOAuthSessionPending(state, provider string) bool {
|
||||
return oauthSessions.IsPending(state, provider)
|
||||
}
|
||||
|
||||
func ValidateOAuthState(state string) error {
|
||||
trimmed := strings.TrimSpace(state)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
||||
}
|
||||
if len(trimmed) > maxOAuthStateLength {
|
||||
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
||||
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "..") {
|
||||
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
||||
}
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "anthropic", "claude":
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
return "iflow", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "qwen":
|
||||
return "qwen", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
}
|
||||
|
||||
type oauthCallbackFilePayload struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
if strings.TrimSpace(authDir) == "" {
|
||||
return "", fmt.Errorf("auth dir is empty")
|
||||
}
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
payload := oauthCallbackFilePayload{
|
||||
Code: strings.TrimSpace(code),
|
||||
State: strings.TrimSpace(state),
|
||||
Error: strings.TrimSpace(errorMessage),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
||||
return "", fmt.Errorf("write oauth callback file: %w", err)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !IsOAuthSessionPending(state, canonicalProvider) {
|
||||
return "", errOAuthSessionNotPending
|
||||
}
|
||||
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
||||
}
|
||||
@@ -1,12 +1,25 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
)
|
||||
|
||||
type usageExportPayload struct {
|
||||
Version int `json:"version"`
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
type usageImportPayload struct {
|
||||
Version int `json:"version"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
||||
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
|
||||
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
|
||||
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
if h != nil && h.usageStats != nil {
|
||||
snapshot = h.usageStats.Snapshot()
|
||||
}
|
||||
c.JSON(http.StatusOK, usageExportPayload{
|
||||
Version: 1,
|
||||
ExportedAt: time.Now().UTC(),
|
||||
Usage: snapshot,
|
||||
})
|
||||
}
|
||||
|
||||
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
|
||||
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
|
||||
if h == nil || h.usageStats == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
var payload usageImportPayload
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
|
||||
return
|
||||
}
|
||||
if payload.Version != 0 && payload.Version != 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
|
||||
return
|
||||
}
|
||||
|
||||
result := h.usageStats.MergeSnapshot(payload.Usage)
|
||||
snapshot := h.usageStats.Snapshot()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"added": result.Added,
|
||||
"skipped": result.Skipped,
|
||||
"total_requests": snapshot.TotalRequests,
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
}
|
||||
|
||||
return &RequestInfo{
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
RequestID: logging.GetGinRequestID(c),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,11 @@ import (
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
RequestID string // RequestID is the unique identifier for the request.
|
||||
}
|
||||
|
||||
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
||||
@@ -71,22 +72,64 @@ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(data)
|
||||
|
||||
// THEN: Handle logging based on response type
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
// For streaming responses: Send to async logging channel (non-blocking)
|
||||
if w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
} else {
|
||||
// For non-streaming responses: Buffer complete response
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.Write(data)
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
|
||||
if w.logger != nil && w.logger.IsEnabled() {
|
||||
return true
|
||||
}
|
||||
if !w.logOnErrorOnly {
|
||||
return false
|
||||
}
|
||||
status := w.statusCode
|
||||
if status == 0 {
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok && statusWriter != nil {
|
||||
status = statusWriter.Status()
|
||||
} else {
|
||||
status = http.StatusOK
|
||||
}
|
||||
}
|
||||
return status >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
// WriteString wraps the underlying ResponseWriter's WriteString method to capture response data.
|
||||
// Some handlers (and fmt/io helpers) write via io.StringWriter; without this override, those writes
|
||||
// bypass Write() and would be missing from request logs.
|
||||
func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
// CRITICAL: Write to client first (zero latency)
|
||||
n, err := w.ResponseWriter.WriteString(data)
|
||||
|
||||
// THEN: Capture for logging
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- []byte(data):
|
||||
default:
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.WriteString(data)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
|
||||
// It captures the status code, detects if the response is streaming based on the Content-Type header,
|
||||
// and initializes the appropriate logging mechanism (standard or streaming).
|
||||
@@ -107,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
w.requestInfo.Body,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
if err == nil {
|
||||
w.streamWriter = streamWriter
|
||||
@@ -160,12 +204,16 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check request body for streaming indicators
|
||||
if w.requestInfo.Body != nil {
|
||||
// If a concrete Content-Type is already set (e.g., application/json for error responses),
|
||||
// treat it as non-streaming instead of inferring from the request payload.
|
||||
if strings.TrimSpace(contentType) != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -221,7 +269,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.streamWriter != nil {
|
||||
if w.chunkChannel != nil {
|
||||
close(w.chunkChannel)
|
||||
w.chunkChannel = nil
|
||||
@@ -233,24 +281,19 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
}
|
||||
|
||||
// Write API Request and Response to the streaming log before closing
|
||||
if w.streamWriter != nil {
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
if forceLog {
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
w.streamWriter = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -305,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
@@ -319,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -333,28 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
}
|
||||
|
||||
// Status returns the HTTP response status code captured by the wrapper.
|
||||
// It defaults to 200 if WriteHeader has not been called.
|
||||
func (w *ResponseWriterWrapper) Status() int {
|
||||
if w.statusCode == 0 {
|
||||
return 200 // Default status code
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
// Size returns the size of the response body in bytes for non-streaming responses.
|
||||
// For streaming responses, it returns -1, as the total size is unknown.
|
||||
func (w *ResponseWriterWrapper) Size() int {
|
||||
if w.isStreaming {
|
||||
return -1 // Unknown size for streaming responses
|
||||
}
|
||||
return w.body.Len()
|
||||
}
|
||||
|
||||
// Written returns true if the response header has been written (i.e., a status code has been set).
|
||||
func (w *ResponseWriterWrapper) Written() bool {
|
||||
return w.statusCode != 0
|
||||
}
|
||||
|
||||
@@ -137,7 +137,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
||||
// Pass auth middleware to require valid API key for all management routes.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
@@ -187,9 +188,6 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
|
||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||
if !newSettings.RestrictManagementToLocalhost {
|
||||
log.Warnf("amp management routes now accessible from any IP - this is insecure!")
|
||||
}
|
||||
}
|
||||
|
||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||
@@ -229,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check API key change
|
||||
// Check API key change (both default and per-client mappings)
|
||||
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
||||
if apiKeyChanged {
|
||||
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
|
||||
if apiKeyChanged || upstreamAPIKeysChanged {
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||
if apiKeyChanged {
|
||||
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
if upstreamAPIKeysChanged {
|
||||
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
|
||||
}
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
@@ -253,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
|
||||
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
||||
if m.secretSource == nil {
|
||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
|
||||
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
mappedSource := NewMappedSecretSource(defaultSource)
|
||||
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
m.secretSource = mappedSource
|
||||
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
ms.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
|
||||
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
mappedSource := NewMappedSecretSource(ms)
|
||||
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
m.secretSource = mappedSource
|
||||
}
|
||||
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
@@ -281,16 +300,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for efficient comparison
|
||||
oldMap := make(map[string]string, len(old.ModelMappings))
|
||||
// Build map for efficient and robust comparison
|
||||
type mappingInfo struct {
|
||||
to string
|
||||
regex bool
|
||||
}
|
||||
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
|
||||
for _, mapping := range old.ModelMappings {
|
||||
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
|
||||
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
|
||||
to: strings.TrimSpace(mapping.To),
|
||||
regex: mapping.Regex,
|
||||
}
|
||||
}
|
||||
|
||||
for _, mapping := range new.ModelMappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
|
||||
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -308,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
|
||||
return oldKey != newKey
|
||||
}
|
||||
|
||||
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
|
||||
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||
if old == nil {
|
||||
return len(new.UpstreamAPIKeys) > 0
|
||||
}
|
||||
|
||||
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for comparison: upstreamKey -> set of clientKeys
|
||||
type entryInfo struct {
|
||||
upstreamKey string
|
||||
clientKeys map[string]struct{}
|
||||
}
|
||||
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
|
||||
for i, entry := range old.UpstreamAPIKeys {
|
||||
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
|
||||
for _, k := range entry.APIKeys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
clientKeys[trimmed] = struct{}{}
|
||||
}
|
||||
oldEntries[i] = entryInfo{
|
||||
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
|
||||
clientKeys: clientKeys,
|
||||
}
|
||||
}
|
||||
|
||||
for i, newEntry := range new.UpstreamAPIKeys {
|
||||
if i >= len(oldEntries) {
|
||||
return true
|
||||
}
|
||||
oldE := oldEntries[i]
|
||||
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
|
||||
return true
|
||||
}
|
||||
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
|
||||
for _, k := range newEntry.APIKeys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
newKeys[trimmed] = struct{}{}
|
||||
}
|
||||
if len(newKeys) != len(oldE.clientKeys) {
|
||||
return true
|
||||
}
|
||||
for k := range newKeys {
|
||||
if _, ok := oldE.clientKeys[k]; !ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||
return m.modelMapper
|
||||
|
||||
@@ -146,6 +146,9 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
m.lastConfig = &config.AmpCode{
|
||||
UpstreamAPIKey: "old-key",
|
||||
}
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
@@ -157,7 +160,7 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x"}}); err != nil {
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -309,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
|
||||
m := &AmpModule{}
|
||||
|
||||
oldCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||
},
|
||||
}
|
||||
newCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
|
||||
},
|
||||
}
|
||||
|
||||
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
|
||||
m := &AmpModule{}
|
||||
|
||||
oldCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||
},
|
||||
}
|
||||
newCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
|
||||
},
|
||||
}
|
||||
|
||||
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||
t.Fatal("expected no change when only whitespace/empty entries differ")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
fields["cost"] = "amp_credits"
|
||||
fields["source"] = "ampcode.com"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
|
||||
case RouteTypeNoProvider:
|
||||
fields["cost"] = "none"
|
||||
@@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
}
|
||||
|
||||
// Normalize model (handles dynamic thinking suffixes)
|
||||
normalizedModel, _ := util.NormalizeThinkingModel(modelName)
|
||||
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
|
||||
thinkingSuffix := ""
|
||||
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
|
||||
thinkingSuffix = modelName[len(normalizedModel):]
|
||||
}
|
||||
|
||||
resolveMappedModel := func() (string, []string) {
|
||||
if fh.modelMapper == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
mappedModel = strings.TrimSpace(mappedModel)
|
||||
if mappedModel == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||
// already specifies its own thinking suffix.
|
||||
if thinkingSuffix != "" {
|
||||
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
|
||||
if mappedThinkingMetadata == nil {
|
||||
mappedModel += thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return mappedModel, mappedProviders
|
||||
}
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
resolvedModel := normalizedModel
|
||||
@@ -147,21 +183,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - check if we have a provider for the mapped model
|
||||
mappedProviders := util.GetProviderName(mappedModel)
|
||||
if len(mappedProviders) > 0 {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
|
||||
// If no mapping applied, check for local providers
|
||||
@@ -174,21 +204,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - check if we have a provider for the mapped model
|
||||
mappedProviders := util.GetProviderName(mappedModel)
|
||||
if len(mappedProviders) > 0 {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
// Log: Model was mapped to another model
|
||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
filterAntropicBetaHeader(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
rewriter.Flush()
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel)
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
||||
} else if len(providers) > 0 {
|
||||
// Log: Using local provider (free)
|
||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||
|
||||
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
||||
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-amp-fallback")
|
||||
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
||||
})
|
||||
|
||||
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
||||
|
||||
handler := func(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": req.Model,
|
||||
"seen_model": req.Model,
|
||||
})
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
||||
|
||||
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Model string `json:"model"`
|
||||
SeenModel string `json:"seen_model"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response JSON: %v", err)
|
||||
}
|
||||
|
||||
if resp.Model != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
||||
}
|
||||
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -26,13 +27,15 @@ type ModelMapper interface {
|
||||
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
||||
type DefaultModelMapper struct {
|
||||
mu sync.RWMutex
|
||||
mappings map[string]string // from -> to (normalized lowercase keys)
|
||||
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||
regexps []regexMapping // regex rules evaluated in order
|
||||
}
|
||||
|
||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
m := &DefaultModelMapper{
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
}
|
||||
m.UpdateMappings(mappings)
|
||||
return m
|
||||
@@ -55,11 +58,23 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
// Check for direct mapping
|
||||
targetModel, exists := m.mappings[normalizedRequest]
|
||||
if !exists {
|
||||
return ""
|
||||
// Try regex mappings in order
|
||||
base, _ := util.NormalizeThinkingModel(requestedModel)
|
||||
for _, rm := range m.regexps {
|
||||
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
|
||||
targetModel = rm.to
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Verify target model has available providers
|
||||
providers := util.GetProviderName(targetModel)
|
||||
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
|
||||
providers := util.GetProviderName(normalizedTarget)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
@@ -77,6 +92,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
|
||||
// Clear and rebuild mappings
|
||||
m.mappings = make(map[string]string, len(mappings))
|
||||
m.regexps = make([]regexMapping, 0, len(mappings))
|
||||
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
@@ -87,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
if mapping.Regex {
|
||||
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
|
||||
pattern := "(?i)" + from
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
|
||||
continue
|
||||
}
|
||||
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
|
||||
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
|
||||
} else {
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.mappings) > 0 {
|
||||
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
||||
}
|
||||
if n := len(m.regexps); n > 0 {
|
||||
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMappings returns a copy of current mappings (for debugging/status).
|
||||
@@ -110,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type regexMapping struct {
|
||||
re *regexp.Regexp
|
||||
to string
|
||||
}
|
||||
|
||||
@@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
|
||||
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-thinking")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("gpt-5.2-alias")
|
||||
if result != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
||||
@@ -184,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
|
||||
t.Error("Original map was modified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-1")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Incoming model has reasoning suffix but should match base via regex
|
||||
result := mapper.MapModel("gpt-5(high)")
|
||||
if result != "gemini-2.5-pro" {
|
||||
t.Errorf("Expected gemini-2.5-pro, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-2")
|
||||
defer reg.UnregisterClient("test-client-regex-3")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
|
||||
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Exact match should win over regex
|
||||
result := mapper.MapModel("gpt-5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
|
||||
// Invalid regex should be skipped and not cause panic
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "(", To: "target", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("anything")
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty result due to invalid regex, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-4")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("claude-opus-4.5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,33 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func removeQueryValuesMatching(req *http.Request, key string, match string) {
|
||||
if req == nil || req.URL == nil || match == "" {
|
||||
return
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
values, ok := q[key]
|
||||
if !ok || len(values) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
kept := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
if v == match {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, v)
|
||||
}
|
||||
|
||||
if len(kept) == 0 {
|
||||
q.Del(key)
|
||||
} else {
|
||||
q[key] = kept
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
// readCloser wraps a reader and forwards Close to a separate closer.
|
||||
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
||||
type readCloser struct {
|
||||
@@ -41,6 +68,19 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
originalDirector(req)
|
||||
req.Host = parsed.Host
|
||||
|
||||
// Remove client's Authorization header - it was only used for CLI Proxy API authentication
|
||||
// We will set our own Authorization using the configured upstream-api-key
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Del("X-Api-Key")
|
||||
req.Header.Del("X-Goog-Api-Key")
|
||||
|
||||
// Remove query-based credentials if they match the authenticated client API key.
|
||||
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||
// breaking unrelated upstream query parameters.
|
||||
clientKey := getClientAPIKeyFromContext(req.Context())
|
||||
removeQueryValuesMatching(req, "key", clientKey)
|
||||
removeQueryValuesMatching(req, "auth_token", clientKey)
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
// Could generate one here if needed
|
||||
@@ -50,7 +90,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Users going through ampcode.com proxy are paying for the service and should get all features
|
||||
// including 1M context window (context-1m-2025-08-07)
|
||||
|
||||
// Inject API key from secret source (precedence: config > env > file)
|
||||
// Inject API key from secret source (only uses upstream-api-key from config)
|
||||
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
||||
req.Header.Set("X-Api-Key", key)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
|
||||
@@ -3,11 +3,15 @@ package amp
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// Helper: compress data with gzip
|
||||
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
|
||||
type captured struct {
|
||||
headers http.Header
|
||||
query string
|
||||
}
|
||||
got := make(chan captured, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer client-key")
|
||||
req.Header.Set("X-Api-Key", "client-key")
|
||||
req.Header.Set("X-Goog-Api-Key", "client-key")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
c := <-got
|
||||
|
||||
// These are client-provided credentials and must not reach the upstream.
|
||||
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
|
||||
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
|
||||
}
|
||||
|
||||
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
|
||||
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
|
||||
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
|
||||
}
|
||||
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
|
||||
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
|
||||
}
|
||||
|
||||
// Query-based credentials should be stripped only when they match the authenticated client key.
|
||||
// Should keep unrelated values and parameters.
|
||||
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
|
||||
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
|
||||
}
|
||||
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
|
||||
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
mapped := NewMappedSecretSource(defaultSource)
|
||||
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "u1" {
|
||||
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer u1" {
|
||||
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
mapped := NewMappedSecretSource(defaultSource)
|
||||
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "default" {
|
||||
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer default" {
|
||||
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
// Point proxy to a non-routable address to trigger error
|
||||
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
||||
|
||||
@@ -39,7 +39,13 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -16,6 +17,37 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// clientAPIKeyContextKey is the context key used to pass the client API key
|
||||
// from gin.Context to the request context for SecretSource lookup.
|
||||
type clientAPIKeyContextKey struct{}
|
||||
|
||||
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
|
||||
// into the request context so that SecretSource can look it up for per-client upstream routing.
|
||||
func clientAPIKeyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Extract the client API key from gin context (set by AuthMiddleware)
|
||||
if apiKey, exists := c.Get("apiKey"); exists {
|
||||
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
|
||||
// Inject into request context for SecretSource.Get(ctx) to read
|
||||
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// getClientAPIKeyFromContext retrieves the client API key from request context.
|
||||
// Returns empty string if not present.
|
||||
func getClientAPIKeyFromContext(ctx context.Context) string {
|
||||
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
|
||||
if keyStr, ok := val.(string); ok {
|
||||
return keyStr
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
||||
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
||||
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
@@ -95,10 +127,25 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
|
||||
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
auth(c)
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
|
||||
// The auth middleware validates Authorization header against configured API keys.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
@@ -107,10 +154,16 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
if !m.IsRestrictedToLocalhost() {
|
||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
// Apply authentication middleware - requires valid API key in Authorization header
|
||||
var authWithBypass gin.HandlerFunc
|
||||
if auth != nil {
|
||||
ampAPI.Use(auth)
|
||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
||||
}
|
||||
|
||||
// Inject client API key into request context for per-client upstream routing
|
||||
ampAPI.Use(clientAPIKeyMiddleware())
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||
@@ -154,7 +207,18 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
if authWithBypass != nil {
|
||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||
}
|
||||
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
|
||||
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
|
||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/settings", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
@@ -217,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
if auth != nil {
|
||||
ampProviders.Use(auth)
|
||||
}
|
||||
// Inject client API key into request context for per-client upstream routing
|
||||
ampProviders.Use(clientAPIKeyMiddleware())
|
||||
|
||||
provider := ampProviders.Group("/:provider")
|
||||
|
||||
@@ -262,7 +328,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,9 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
m.setProxy(proxy)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base)
|
||||
m.registerManagementRoutes(r, base, nil)
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
@@ -63,11 +65,17 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path.path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(path.method, path.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path.path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SecretSource provides Amp API keys with configurable precedence and caching
|
||||
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
|
||||
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
|
||||
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
|
||||
// When a request context contains a client API key that matches a configured mapping,
|
||||
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
|
||||
type MappedSecretSource struct {
|
||||
defaultSource SecretSource
|
||||
mu sync.RWMutex
|
||||
lookup map[string]string // clientKey -> upstreamKey
|
||||
}
|
||||
|
||||
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
|
||||
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
|
||||
return &MappedSecretSource{
|
||||
defaultSource: defaultSource,
|
||||
lookup: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves the Amp API key, checking per-client mappings first.
|
||||
// If the request context contains a client API key that matches a configured mapping,
|
||||
// returns the corresponding upstream key. Otherwise, falls back to the default source.
|
||||
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
|
||||
// Try to get client API key from request context
|
||||
clientKey := getClientAPIKeyFromContext(ctx)
|
||||
if clientKey != "" {
|
||||
s.mu.RLock()
|
||||
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
|
||||
s.mu.RUnlock()
|
||||
return upstreamKey, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
|
||||
// Fall back to default source
|
||||
return s.defaultSource.Get(ctx)
|
||||
}
|
||||
|
||||
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
|
||||
// If the same client key appears in multiple entries, logs a warning and uses the first one.
|
||||
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
|
||||
newLookup := make(map[string]string)
|
||||
|
||||
for _, entry := range entries {
|
||||
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
for _, clientKey := range entry.APIKeys {
|
||||
trimmedKey := strings.TrimSpace(clientKey)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := newLookup[trimmedKey]; exists {
|
||||
// Log warning for duplicate client key, first one wins
|
||||
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
|
||||
continue
|
||||
}
|
||||
newLookup[trimmedKey] = upstreamKey
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.lookup = newLookup
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
|
||||
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
|
||||
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(key)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
|
||||
func (s *MappedSecretSource) InvalidateCache() {
|
||||
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
)
|
||||
|
||||
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
||||
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
||||
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "u1" {
|
||||
t.Fatalf("want u1, got %q", got)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
|
||||
got, err = s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "default" {
|
||||
t.Fatalf("want default fallback, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
{
|
||||
UpstreamAPIKey: "u2",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "u1" {
|
||||
t.Fatalf("want u1 (first wins), got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
|
||||
hook := test.NewLocal(log.StandardLogger())
|
||||
defer hook.Reset()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
{
|
||||
UpstreamAPIKey: "u2",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
foundWarning := false
|
||||
for _, entry := range hook.AllEntries() {
|
||||
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
|
||||
foundWarning = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundWarning {
|
||||
t.Fatal("expected warning log for duplicate client key, but none was found")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,13 +209,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// Resolve logs directory relative to the configuration file directory.
|
||||
var requestLogger logging.RequestLogger
|
||||
var toggle func(bool)
|
||||
if optionState.requestLoggerFactory != nil {
|
||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||
}
|
||||
if requestLogger != nil {
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggle = setter.SetEnabled
|
||||
if !cfg.CommercialMode {
|
||||
if optionState.requestLoggerFactory != nil {
|
||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||
}
|
||||
if requestLogger != nil {
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggle = setter.SetEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,13 +232,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -334,8 +332,8 @@ func (s *Server) setupRoutes() {
|
||||
v1beta.Use(AuthMiddleware(s.accessManager))
|
||||
{
|
||||
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
@@ -358,10 +356,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
// Persist to a temporary file keyed by state
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -371,9 +370,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -383,9 +384,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -395,9 +398,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -407,9 +412,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -469,6 +476,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
||||
{
|
||||
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
||||
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
|
||||
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
|
||||
mgmt.GET("/config", s.mgmt.GetConfig)
|
||||
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||
@@ -491,6 +500,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
|
||||
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
|
||||
|
||||
mgmt.POST("/api-call", s.mgmt.APICall)
|
||||
|
||||
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
|
||||
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
||||
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
||||
@@ -513,6 +524,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
||||
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
||||
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
||||
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
|
||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||
@@ -539,6 +551,10 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
||||
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
|
||||
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
|
||||
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
|
||||
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
|
||||
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
@@ -568,6 +584,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||
|
||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
@@ -580,6 +597,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
}
|
||||
@@ -608,7 +626,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
@@ -837,11 +855,20 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
if err := logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to reconfigure log output: %v", err)
|
||||
} else {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
if oldCfg == nil {
|
||||
log.Debug("log output configuration refreshed")
|
||||
} else {
|
||||
if oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
}
|
||||
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -918,17 +945,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.OpenAICompatProviders = providerNames
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -46,6 +47,12 @@ var (
|
||||
type GeminiAuth struct {
|
||||
}
|
||||
|
||||
// WebLoginOptions customizes the interactive OAuth flow.
|
||||
type WebLoginOptions struct {
|
||||
NoBrowser bool
|
||||
Prompt func(string) (string, error)
|
||||
}
|
||||
|
||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||
func NewGeminiAuth() *GeminiAuth {
|
||||
return &GeminiAuth{}
|
||||
@@ -59,12 +66,12 @@ func NewGeminiAuth() *GeminiAuth {
|
||||
// - ctx: The context for the HTTP client
|
||||
// - ts: The Gemini token storage containing authentication tokens
|
||||
// - cfg: The configuration containing proxy settings
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
// - opts: Optional parameters to customize browser and prompt behavior
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client configured with authentication
|
||||
// - error: An error if the client configuration fails, nil otherwise
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) {
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
@@ -109,7 +116,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
// If no token is found in storage, initiate the web-based OAuth flow.
|
||||
if ts.Token == nil {
|
||||
fmt.Printf("Could not load token from file, starting OAuth flow.\n")
|
||||
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...)
|
||||
token, err = g.getTokenFromWeb(ctx, conf, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||
}
|
||||
@@ -205,15 +212,15 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
// Parameters:
|
||||
// - ctx: The context for the HTTP client
|
||||
// - config: The OAuth2 configuration
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
// - opts: Optional parameters to customize browser and prompt behavior
|
||||
//
|
||||
// Returns:
|
||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||
// - error: An error if the token acquisition fails, nil otherwise
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) {
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
codeChan := make(chan string)
|
||||
errChan := make(chan error)
|
||||
codeChan := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// Create a new HTTP server with its own multiplexer.
|
||||
mux := http.NewServeMux()
|
||||
@@ -223,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.URL.Query().Get("error"); err != "" {
|
||||
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
||||
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
||||
select {
|
||||
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
|
||||
errChan <- fmt.Errorf("code not found in callback")
|
||||
select {
|
||||
case errChan <- fmt.Errorf("code not found in callback"):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||
codeChan <- code
|
||||
select {
|
||||
case codeChan <- code:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
// Start the server in a goroutine.
|
||||
@@ -250,7 +266,12 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
// Open the authorization URL in the user's browser.
|
||||
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
if len(noBrowser) == 1 && !noBrowser[0] {
|
||||
noBrowser := false
|
||||
if opts != nil {
|
||||
noBrowser = opts.NoBrowser
|
||||
}
|
||||
|
||||
if !noBrowser {
|
||||
fmt.Println("Opening browser for authentication...")
|
||||
|
||||
// Check if browser is available
|
||||
@@ -281,13 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
|
||||
// Wait for the authorization code or an error.
|
||||
var authCode string
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-time.After(5 * time.Minute): // Timeout
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
timeoutTimer := time.NewTimer(5 * time.Minute)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
break waitForCallback
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
break waitForCallback
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := misc.ParseOAuthCallback(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
if parsed.Error != "" {
|
||||
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
|
||||
}
|
||||
if parsed.Code == "" {
|
||||
return nil, fmt.Errorf("code not found in callback")
|
||||
}
|
||||
authCode = parsed.Code
|
||||
break waitForCallback
|
||||
case <-timeoutTimer.C:
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the server.
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -36,3 +39,61 @@ func SanitizeIFlowFileName(raw string) string {
|
||||
}
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
|
||||
// ExtractBXAuth extracts the BXAuth value from a cookie string.
|
||||
func ExtractBXAuth(cookie string) string {
|
||||
parts := strings.Split(cookie, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "BXAuth=") {
|
||||
return strings.TrimPrefix(part, "BXAuth=")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CheckDuplicateBXAuth checks if the given BXAuth value already exists in any iflow auth file.
|
||||
// Returns the path of the existing file if found, empty string otherwise.
|
||||
func CheckDuplicateBXAuth(authDir, bxAuth string) (string, error) {
|
||||
if bxAuth == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(authDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("read auth dir failed: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "iflow-") || !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, name)
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var tokenData struct {
|
||||
Cookie string `json:"cookie"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
existingBXAuth := ExtractBXAuth(tokenData.Cookie)
|
||||
if existingBXAuth != "" && existingBXAuth == bxAuth {
|
||||
return filePath, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -494,11 +494,18 @@ func (ia *IFlowAuth) CreateCookieTokenStorage(data *IFlowTokenData) *IFlowTokenS
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only save the BXAuth field from the cookie
|
||||
bxAuth := ExtractBXAuth(data.Cookie)
|
||||
cookieToSave := ""
|
||||
if bxAuth != "" {
|
||||
cookieToSave = "BXAuth=" + bxAuth + ";"
|
||||
}
|
||||
|
||||
return &IFlowTokenStorage{
|
||||
APIKey: data.APIKey,
|
||||
Email: data.Email,
|
||||
Expire: data.Expire,
|
||||
Cookie: data.Cookie,
|
||||
Cookie: cookieToSave,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
Type: "iflow",
|
||||
}
|
||||
|
||||
164
internal/cache/signature_cache.go
vendored
Normal file
164
internal/cache/signature_cache.go
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SignatureEntry holds a cached thinking signature with timestamp
|
||||
type SignatureEntry struct {
|
||||
Signature string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
// SignatureCacheTTL is how long signatures are valid
|
||||
SignatureCacheTTL = 1 * time.Hour
|
||||
|
||||
// MaxEntriesPerSession limits memory usage per session
|
||||
MaxEntriesPerSession = 100
|
||||
|
||||
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
|
||||
SignatureTextHashLen = 16
|
||||
|
||||
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
||||
MinValidSignatureLen = 50
|
||||
)
|
||||
|
||||
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
|
||||
var signatureCache sync.Map
|
||||
|
||||
// sessionCache is the inner map type
|
||||
type sessionCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]SignatureEntry
|
||||
}
|
||||
|
||||
// hashText creates a stable, Unicode-safe key from text content
|
||||
func hashText(text string) string {
|
||||
h := sha256.Sum256([]byte(text))
|
||||
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
||||
}
|
||||
|
||||
// getOrCreateSession gets or creates a session cache
|
||||
func getOrCreateSession(sessionID string) *sessionCache {
|
||||
if val, ok := signatureCache.Load(sessionID); ok {
|
||||
return val.(*sessionCache)
|
||||
}
|
||||
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
|
||||
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
|
||||
return actual.(*sessionCache)
|
||||
}
|
||||
|
||||
// CacheSignature stores a thinking signature for a given session and text.
|
||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||
func CacheSignature(sessionID, text, signature string) {
|
||||
if sessionID == "" || text == "" || signature == "" {
|
||||
return
|
||||
}
|
||||
if len(signature) < MinValidSignatureLen {
|
||||
return
|
||||
}
|
||||
|
||||
sc := getOrCreateSession(sessionID)
|
||||
textHash := hashText(text)
|
||||
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
// Evict expired entries if at capacity
|
||||
if len(sc.entries) >= MaxEntriesPerSession {
|
||||
now := time.Now()
|
||||
for key, entry := range sc.entries {
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, key)
|
||||
}
|
||||
}
|
||||
// If still at capacity, remove oldest entries
|
||||
if len(sc.entries) >= MaxEntriesPerSession {
|
||||
// Find and remove oldest quarter
|
||||
oldest := make([]struct {
|
||||
key string
|
||||
ts time.Time
|
||||
}, 0, len(sc.entries))
|
||||
for key, entry := range sc.entries {
|
||||
oldest = append(oldest, struct {
|
||||
key string
|
||||
ts time.Time
|
||||
}{key, entry.Timestamp})
|
||||
}
|
||||
// Sort by timestamp (oldest first) using sort.Slice
|
||||
sort.Slice(oldest, func(i, j int) bool {
|
||||
return oldest[i].ts.Before(oldest[j].ts)
|
||||
})
|
||||
|
||||
toRemove := len(oldest) / 4
|
||||
if toRemove < 1 {
|
||||
toRemove = 1
|
||||
}
|
||||
|
||||
for i := 0; i < toRemove; i++ {
|
||||
delete(sc.entries, oldest[i].key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sc.entries[textHash] = SignatureEntry{
|
||||
Signature: signature,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetCachedSignature retrieves a cached signature for a given session and text.
|
||||
// Returns empty string if not found or expired.
|
||||
func GetCachedSignature(sessionID, text string) string {
|
||||
if sessionID == "" || text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
val, ok := signatureCache.Load(sessionID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
sc := val.(*sessionCache)
|
||||
|
||||
textHash := hashText(text)
|
||||
|
||||
sc.mu.RLock()
|
||||
entry, exists := sc.entries[textHash]
|
||||
sc.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Since(entry.Timestamp) > SignatureCacheTTL {
|
||||
sc.mu.Lock()
|
||||
delete(sc.entries, textHash)
|
||||
sc.mu.Unlock()
|
||||
return ""
|
||||
}
|
||||
|
||||
return entry.Signature
|
||||
}
|
||||
|
||||
// ClearSignatureCache clears signature cache for a specific session or all sessions.
|
||||
func ClearSignatureCache(sessionID string) {
|
||||
if sessionID != "" {
|
||||
signatureCache.Delete(sessionID)
|
||||
} else {
|
||||
signatureCache.Range(func(key, _ any) bool {
|
||||
signatureCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
||||
func HasValidSignature(signature string) bool {
|
||||
return signature != "" && len(signature) >= MinValidSignatureLen
|
||||
}
|
||||
216
internal/cache/signature_cache_test.go
vendored
Normal file
216
internal/cache/signature_cache_test.go
vendored
Normal file
@@ -0,0 +1,216 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "test-session-1"
|
||||
text := "This is some thinking text content"
|
||||
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
|
||||
// Store signature
|
||||
CacheSignature(sessionID, text, signature)
|
||||
|
||||
// Retrieve signature
|
||||
retrieved := GetCachedSignature(sessionID, text)
|
||||
if retrieved != signature {
|
||||
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_DifferentSessions(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "Same text in different sessions"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature("session-a", text, sig1)
|
||||
CacheSignature("session-b", text, sig2)
|
||||
|
||||
if GetCachedSignature("session-a", text) != sig1 {
|
||||
t.Error("Session-a signature mismatch")
|
||||
}
|
||||
if GetCachedSignature("session-b", text) != sig2 {
|
||||
t.Error("Session-b signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_NotFound(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Non-existent session
|
||||
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
|
||||
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
||||
}
|
||||
|
||||
// Existing session but different text
|
||||
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature("session-x", "text-b"); got != "" {
|
||||
t.Errorf("Expected empty string for different text, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_EmptyInputs(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// All empty/invalid inputs should be no-ops
|
||||
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("session", "text", "")
|
||||
CacheSignature("session", "text", "short") // Too short
|
||||
|
||||
if got := GetCachedSignature("session", "text"); got != "" {
|
||||
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "test-short-sig"
|
||||
text := "Some text"
|
||||
shortSig := "abc123" // Less than 50 chars
|
||||
|
||||
CacheSignature(sessionID, text, shortSig)
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != "" {
|
||||
t.Errorf("Short signature should be rejected, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSignatureCache_SpecificSession(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("session-1", "text", sig)
|
||||
CacheSignature("session-2", "text", sig)
|
||||
|
||||
ClearSignatureCache("session-1")
|
||||
|
||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("session-2", "text"); got != sig {
|
||||
t.Error("session-2 should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSignatureCache_AllSessions(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("session-1", "text", sig)
|
||||
CacheSignature("session-2", "text", sig)
|
||||
|
||||
ClearSignatureCache("")
|
||||
|
||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("session-2", "text"); got != "" {
|
||||
t.Error("session-2 should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasValidSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
signature string
|
||||
expected bool
|
||||
}{
|
||||
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
|
||||
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
|
||||
{"empty string", "", false},
|
||||
{"short signature", "abc", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasValidSignature(tt.signature)
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "hash-test-session"
|
||||
|
||||
// Different texts should produce different hashes
|
||||
text1 := "First thinking text"
|
||||
text2 := "Second thinking text"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(sessionID, text1, sig1)
|
||||
CacheSignature(sessionID, text2, sig2)
|
||||
|
||||
if GetCachedSignature(sessionID, text1) != sig1 {
|
||||
t.Error("text1 signature mismatch")
|
||||
}
|
||||
if GetCachedSignature(sessionID, text2) != sig2 {
|
||||
t.Error("text2 signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "unicode-session"
|
||||
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
||||
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
||||
|
||||
CacheSignature(sessionID, text, sig)
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
||||
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_Overwrite(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "overwrite-session"
|
||||
text := "Same text"
|
||||
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
||||
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
||||
|
||||
CacheSignature(sessionID, text, sig1)
|
||||
CacheSignature(sessionID, text, sig2) // Overwrite
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != sig2 {
|
||||
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: TTL expiration test is tricky to test without mocking time
|
||||
// We test the logic path exists but actual expiration would require time manipulation
|
||||
func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// This test verifies the expiration check exists
|
||||
// In a real scenario, we'd mock time.Now()
|
||||
sessionID := "expiration-test"
|
||||
text := "text"
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(sessionID, text, sig)
|
||||
|
||||
// Fresh entry should be retrievable
|
||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
||||
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
||||
}
|
||||
|
||||
// We can't easily test actual expiration without time mocking
|
||||
// but the logic is verified by the implementation
|
||||
_ = time.Now() // Acknowledge we're not testing time passage
|
||||
}
|
||||
@@ -24,12 +24,17 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
|
||||
@@ -15,11 +15,16 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -37,6 +39,16 @@ func DoIFlowCookieAuth(cfg *config.Config, options *LoginOptions) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for duplicate BXAuth before authentication
|
||||
bxAuth := iflow.ExtractBXAuth(cookie)
|
||||
if existingFile, err := iflow.CheckDuplicateBXAuth(cfg.AuthDir, bxAuth); err != nil {
|
||||
fmt.Printf("Failed to check duplicate: %v\n", err)
|
||||
return
|
||||
} else if existingFile != "" {
|
||||
fmt.Printf("Duplicate BXAuth found, authentication already exists: %s\n", filepath.Base(existingFile))
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate with cookie
|
||||
auth := iflow.NewIFlowAuth(cfg)
|
||||
ctx := context.Background()
|
||||
@@ -82,5 +94,5 @@ func promptForCookie(promptFn func(string) (string, error)) (string, error) {
|
||||
// getAuthFilePath returns the auth file path for the given provider and email
|
||||
func getAuthFilePath(cfg *config.Config, provider, email string) string {
|
||||
fileName := iflow.SanitizeIFlowFileName(email)
|
||||
return fmt.Sprintf("%s/%s-%s.json", cfg.AuthDir, provider, fileName)
|
||||
return fmt.Sprintf("%s/%s-%s-%d.json", cfg.AuthDir, provider, fileName, time.Now().Unix())
|
||||
}
|
||||
|
||||
@@ -20,13 +20,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = func(prompt string) (string, error) {
|
||||
fmt.Println()
|
||||
fmt.Println(prompt)
|
||||
var value string
|
||||
_, err := fmt.Scanln(&value)
|
||||
return value, err
|
||||
}
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
|
||||
@@ -55,11 +55,22 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
trimmedProjectID := strings.TrimSpace(projectID)
|
||||
callbackPrompt := promptFn
|
||||
if trimmedProjectID == "" {
|
||||
callbackPrompt = nil
|
||||
}
|
||||
|
||||
loginOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
ProjectID: trimmedProjectID,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
Prompt: callbackPrompt,
|
||||
}
|
||||
|
||||
authenticator := sdkAuth.NewGeminiAuthenticator()
|
||||
@@ -76,7 +87,10 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
}
|
||||
|
||||
geminiAuth := gemini.NewGeminiAuth()
|
||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser)
|
||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Prompt: callbackPrompt,
|
||||
})
|
||||
if errClient != nil {
|
||||
log.Errorf("Gemini authentication failed: %v", errClient)
|
||||
return
|
||||
@@ -90,12 +104,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
return
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
|
||||
@@ -35,12 +35,17 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
|
||||
@@ -12,14 +12,15 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
config.SDKConfig `yaml:",inline"`
|
||||
SDKConfig `yaml:",inline"`
|
||||
// Host is the network host/interface on which the API server will bind.
|
||||
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
|
||||
Host string `yaml:"host" json:"-"`
|
||||
@@ -38,9 +39,16 @@ type Config struct {
|
||||
// Debug enables or disables debug-level logging and other debug features.
|
||||
Debug bool `yaml:"debug" json:"debug"`
|
||||
|
||||
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
|
||||
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
|
||||
|
||||
// LoggingToFile controls whether application logs are written to rotating files or stdout.
|
||||
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
|
||||
|
||||
// LogsMaxTotalSizeMB limits the total size (in MB) of log files under the logs directory.
|
||||
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
||||
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
||||
|
||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||
|
||||
@@ -55,6 +63,9 @@ type Config struct {
|
||||
// QuotaExceeded defines the behavior when a quota is exceeded.
|
||||
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
|
||||
|
||||
// Routing controls credential selection behavior.
|
||||
Routing RoutingConfig `yaml:"routing" json:"routing"`
|
||||
|
||||
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
||||
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
||||
|
||||
@@ -80,6 +91,14 @@ type Config struct {
|
||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||
|
||||
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
|
||||
// These mappings affect both model listing and model routing for supported channels:
|
||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
//
|
||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
|
||||
|
||||
// Payload defines default and override rules for provider payload parameters.
|
||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||
|
||||
@@ -104,6 +123,9 @@ type RemoteManagement struct {
|
||||
SecretKey string `yaml:"secret-key"`
|
||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||
}
|
||||
|
||||
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
||||
@@ -116,6 +138,20 @@ type QuotaExceeded struct {
|
||||
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
||||
}
|
||||
|
||||
// RoutingConfig configures how credentials are selected for requests.
|
||||
type RoutingConfig struct {
|
||||
// Strategy selects the credential selection strategy.
|
||||
// Supported values: "round-robin" (default), "fill-first".
|
||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||
}
|
||||
|
||||
// ModelNameMapping defines a model ID rename mapping for a specific channel.
|
||||
// It maps the original model name (Name) to the client-visible alias (Alias).
|
||||
type ModelNameMapping struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||
// When Amp requests a model that isn't available locally, this mapping
|
||||
// allows routing to an alternative model that IS available.
|
||||
@@ -126,6 +162,11 @@ type AmpModelMapping struct {
|
||||
// To is the target model name to route to (e.g., "claude-sonnet-4").
|
||||
// The target model must have available providers in the registry.
|
||||
To string `yaml:"to" json:"to"`
|
||||
|
||||
// Regex indicates whether the 'from' field should be interpreted as a regular
|
||||
// expression for matching model names. When true, this mapping is evaluated
|
||||
// after exact matches and in the order provided. Defaults to false (exact match).
|
||||
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
|
||||
}
|
||||
|
||||
// AmpCode groups Amp CLI integration settings including upstream routing,
|
||||
@@ -137,9 +178,14 @@ type AmpCode struct {
|
||||
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
|
||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||
|
||||
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
||||
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
||||
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||
|
||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
||||
// browser attacks and remote access to management endpoints. Default: true (recommended).
|
||||
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
||||
RestrictManagementToLocalhost bool `yaml:"restrict-management-to-localhost" json:"restrict-management-to-localhost"`
|
||||
|
||||
// ModelMappings defines model name mappings for Amp CLI requests.
|
||||
@@ -152,6 +198,17 @@ type AmpCode struct {
|
||||
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
||||
}
|
||||
|
||||
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
|
||||
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||
// is used for the upstream Amp request.
|
||||
type AmpUpstreamAPIKeyEntry struct {
|
||||
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
|
||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||
|
||||
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
}
|
||||
|
||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||
type PayloadConfig struct {
|
||||
// Default defines rules that only set parameters when they are missing in the payload.
|
||||
@@ -182,6 +239,9 @@ type ClaudeKey struct {
|
||||
// APIKey is the authentication key for accessing Claude API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Claude API endpoint.
|
||||
// If empty, the default Claude API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -208,12 +268,18 @@ type ClaudeModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m ClaudeModel) GetName() string { return m.Name }
|
||||
func (m ClaudeModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// CodexKey represents the configuration for a Codex API key,
|
||||
// including the API key itself and an optional base URL for the API endpoint.
|
||||
type CodexKey struct {
|
||||
// APIKey is the authentication key for accessing Codex API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Codex API endpoint.
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -221,6 +287,9 @@ type CodexKey struct {
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// Models defines upstream model names and aliases for request routing.
|
||||
Models []CodexModel `yaml:"models" json:"models"`
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
@@ -228,18 +297,36 @@ type CodexKey struct {
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// CodexModel describes a mapping between an alias and the actual upstream model name.
|
||||
type CodexModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the client-facing model name that maps to Name.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m CodexModel) GetName() string { return m.Name }
|
||||
func (m CodexModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// GeminiKey represents the configuration for a Gemini API key,
|
||||
// including optional overrides for upstream base URL, proxy routing, and headers.
|
||||
type GeminiKey struct {
|
||||
// APIKey is the authentication key for accessing Gemini API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL optionally overrides the Gemini API endpoint.
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
// ProxyURL optionally overrides the global proxy for this API key.
|
||||
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||
|
||||
// Models defines upstream model names and aliases for request routing.
|
||||
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
@@ -247,12 +334,27 @@ type GeminiKey struct {
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiModel describes a mapping between an alias and the actual upstream model name.
|
||||
type GeminiModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the client-facing model name that maps to Name.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m GeminiModel) GetName() string { return m.Name }
|
||||
func (m GeminiModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
||||
type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the external OpenAI-compatible API endpoint.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
@@ -325,9 +427,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Set defaults before unmarshal so that absent keys keep defaults.
|
||||
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||
cfg.LoggingToFile = false
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||
if optional {
|
||||
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
||||
@@ -363,6 +467,15 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
_ = SaveConfigPreserveCommentsUpdateNestedScalar(configFile, []string{"remote-management", "secret-key"}, hashed)
|
||||
}
|
||||
|
||||
cfg.RemoteManagement.PanelGitHubRepository = strings.TrimSpace(cfg.RemoteManagement.PanelGitHubRepository)
|
||||
if cfg.RemoteManagement.PanelGitHubRepository == "" {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
if cfg.LogsMaxTotalSizeMB < 0 {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
@@ -384,6 +497,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Normalize OAuth provider model exclusion map.
|
||||
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
||||
|
||||
// Normalize global OAuth model name mappings.
|
||||
cfg.SanitizeOAuthModelMappings()
|
||||
|
||||
if cfg.legacyMigrationPending {
|
||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
if !optional && configFile != "" {
|
||||
@@ -400,6 +516,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// and ensures (From, To) pairs are unique within each channel.
|
||||
func (cfg *Config) SanitizeOAuthModelMappings() {
|
||||
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
|
||||
return
|
||||
}
|
||||
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
|
||||
for rawChannel, mappings := range cfg.OAuthModelMappings {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(mappings) == 0 {
|
||||
continue
|
||||
}
|
||||
seenName := make(map[string]struct{}, len(mappings))
|
||||
seenAlias := make(map[string]struct{}, len(mappings))
|
||||
clean := make([]ModelNameMapping, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
name := strings.TrimSpace(mapping.Name)
|
||||
alias := strings.TrimSpace(mapping.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
nameKey := strings.ToLower(name)
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, ok := seenName[nameKey]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := seenAlias[aliasKey]; ok {
|
||||
continue
|
||||
}
|
||||
seenName[nameKey] = struct{}{}
|
||||
seenAlias[aliasKey] = struct{}{}
|
||||
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
|
||||
}
|
||||
if len(clean) > 0 {
|
||||
out[channel] = clean
|
||||
}
|
||||
}
|
||||
cfg.OAuthModelMappings = out
|
||||
}
|
||||
|
||||
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
|
||||
// not actionable, specifically those missing a BaseURL. It trims whitespace before
|
||||
// evaluation and preserves the relative order of remaining entries.
|
||||
@@ -411,6 +571,7 @@ func (cfg *Config) SanitizeOpenAICompatibility() {
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
e := cfg.OpenAICompatibility[i]
|
||||
e.Name = strings.TrimSpace(e.Name)
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
if e.BaseURL == "" {
|
||||
@@ -431,6 +592,7 @@ func (cfg *Config) SanitizeCodexKeys() {
|
||||
out := make([]CodexKey, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
e := cfg.CodexKey[i]
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels)
|
||||
@@ -449,6 +611,7 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
||||
}
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
}
|
||||
@@ -468,6 +631,7 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
@@ -481,6 +645,18 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
cfg.GeminiKey = out
|
||||
}
|
||||
|
||||
func normalizeModelPrefix(prefix string) string {
|
||||
trimmed := strings.TrimSpace(prefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(trimmed, "/") {
|
||||
return ""
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
@@ -653,7 +829,7 @@ func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = config.AccessConfig{}
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
@@ -752,8 +928,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
||||
}
|
||||
|
||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||
// key order and comments of existing keys in dst. Unknown keys from src are appended
|
||||
// to dst at the end, copying their node structure from src.
|
||||
// key order and comments of existing keys in dst. New keys are only added if their
|
||||
// value is non-zero to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
@@ -764,20 +940,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
copyNodeShallow(dst, src)
|
||||
return
|
||||
}
|
||||
// Build a lookup of existing keys in dst
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
sk := src.Content[i]
|
||||
sv := src.Content[i+1]
|
||||
idx := findMapKeyIndex(dst, sk.Value)
|
||||
if idx >= 0 {
|
||||
// Merge into existing value node
|
||||
// Merge into existing value node (always update, even to zero values)
|
||||
dv := dst.Content[idx+1]
|
||||
mergeNodePreserve(dv, sv)
|
||||
} else {
|
||||
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) {
|
||||
// New key: only add if value is non-zero to avoid polluting config with defaults
|
||||
if isZeroValueNode(sv) {
|
||||
continue
|
||||
}
|
||||
// Append new key/value pair by deep-copying from src
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
||||
}
|
||||
}
|
||||
@@ -860,32 +1035,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
|
||||
switch key {
|
||||
case "generative-language-api-key",
|
||||
"gemini-api-key",
|
||||
"vertex-api-key",
|
||||
"claude-api-key",
|
||||
"codex-api-key",
|
||||
"openai-compatibility":
|
||||
return isEmptyCollectionNode(node)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isEmptyCollectionNode(node *yaml.Node) bool {
|
||||
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||
// that should not be written as a new key to preserve config cleanliness.
|
||||
// For mappings and sequences, recursively checks if all children are zero values.
|
||||
func isZeroValueNode(node *yaml.Node) bool {
|
||||
if node == nil {
|
||||
return true
|
||||
}
|
||||
switch node.Kind {
|
||||
case yaml.SequenceNode:
|
||||
return len(node.Content) == 0
|
||||
case yaml.ScalarNode:
|
||||
return node.Tag == "!!null"
|
||||
default:
|
||||
return false
|
||||
switch node.Tag {
|
||||
case "!!bool":
|
||||
return node.Value == "false"
|
||||
case "!!int", "!!float":
|
||||
return node.Value == "0" || node.Value == "0.0"
|
||||
case "!!str":
|
||||
return node.Value == ""
|
||||
case "!!null":
|
||||
return true
|
||||
}
|
||||
case yaml.SequenceNode:
|
||||
if len(node.Content) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check if all elements are zero values
|
||||
for _, child := range node.Content {
|
||||
if !isZeroValueNode(child) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case yaml.MappingNode:
|
||||
if len(node.Content) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check if all values are zero values (values are at odd indices)
|
||||
for i := 1; i < len(node.Content); i += 2 {
|
||||
if !isZeroValueNode(node.Content[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// deepCopyNode creates a deep copy of a yaml.Node graph.
|
||||
|
||||
102
internal/config/sdk_config.go
Normal file
102
internal/config/sdk_config.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// Package config provides configuration management for the CLI Proxy API server.
|
||||
// It handles loading and parsing YAML configuration files, and provides structured
|
||||
// access to application settings including server port, authentication directory,
|
||||
// debug settings, proxy configuration, and API keys.
|
||||
package config
|
||||
|
||||
// SDKConfig represents the application's configuration, loaded from a YAML file.
|
||||
type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
}
|
||||
|
||||
// StreamingConfig holds server streaming behavior configuration.
|
||||
type StreamingConfig struct {
|
||||
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
|
||||
// <= 0 disables keep-alives. Default is 0.
|
||||
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
||||
|
||||
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
|
||||
// to allow auth rotation / transient recovery.
|
||||
// <= 0 disables bootstrap retries. Default is 0.
|
||||
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@@ -13,6 +13,9 @@ type VertexCompatKey struct {
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
@@ -39,6 +42,9 @@ type VertexCompatModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m VertexCompatModel) GetName() string { return m.Name }
|
||||
func (m VertexCompatModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
|
||||
func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if cfg == nil {
|
||||
@@ -53,6 +59,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,11 +15,24 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
|
||||
var aiAPIPrefixes = []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/v1/messages",
|
||||
"/v1/responses",
|
||||
"/v1beta/models/",
|
||||
"/api/provider/",
|
||||
}
|
||||
|
||||
const skipGinLogKey = "__gin_skip_request_logging__"
|
||||
|
||||
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
|
||||
// using logrus. It captures request details including method, path, status code, latency,
|
||||
// client IP, and any error messages, formatting them in a Gin-style log format.
|
||||
// client IP, and any error messages. Request ID is only added for AI API requests.
|
||||
//
|
||||
// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ...
|
||||
// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ...
|
||||
//
|
||||
// Returns:
|
||||
// - gin.HandlerFunc: A middleware handler for request logging
|
||||
@@ -28,6 +42,15 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
path := c.Request.URL.Path
|
||||
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
|
||||
// Only generate request ID for AI API paths
|
||||
var requestID string
|
||||
if isAIAPIPath(path) {
|
||||
requestID = GenerateRequestID()
|
||||
SetGinRequestID(c, requestID)
|
||||
ctx := WithRequestID(c.Request.Context(), requestID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
|
||||
if shouldSkipGinRequestLogging(c) {
|
||||
@@ -49,23 +72,38 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
|
||||
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path)
|
||||
|
||||
if requestID == "" {
|
||||
requestID = "--------"
|
||||
}
|
||||
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
|
||||
if errorMessage != "" {
|
||||
logLine = logLine + " | " + errorMessage
|
||||
}
|
||||
|
||||
entry := log.WithField("request_id", requestID)
|
||||
|
||||
switch {
|
||||
case statusCode >= http.StatusInternalServerError:
|
||||
log.Error(logLine)
|
||||
entry.Error(logLine)
|
||||
case statusCode >= http.StatusBadRequest:
|
||||
log.Warn(logLine)
|
||||
entry.Warn(logLine)
|
||||
default:
|
||||
log.Info(logLine)
|
||||
entry.Info(logLine)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
|
||||
func isAIAPIPath(path string) bool {
|
||||
for _, prefix := range aiAPIPrefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
|
||||
// them using logrus. When a panic occurs, it captures the panic value, stack trace,
|
||||
// and request path, then returns a 500 Internal Server Error response to the client.
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@@ -24,7 +25,8 @@ var (
|
||||
)
|
||||
|
||||
// LogFormatter defines a custom log format for logrus.
|
||||
// This formatter adds timestamp, level, and source location to each log entry.
|
||||
// This formatter adds timestamp, level, request ID, and source location to each log entry.
|
||||
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
|
||||
type LogFormatter struct{}
|
||||
|
||||
// Format renders a single log entry with custom formatting.
|
||||
@@ -39,11 +41,22 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||
message := strings.TrimRight(entry.Message, "\r\n")
|
||||
|
||||
reqID := "--------"
|
||||
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
|
||||
reqID = id
|
||||
}
|
||||
|
||||
level := entry.Level.String()
|
||||
if level == "warning" {
|
||||
level = "warn"
|
||||
}
|
||||
levelStr := fmt.Sprintf("%-5s", level)
|
||||
|
||||
var formatted string
|
||||
if entry.Caller != nil {
|
||||
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
|
||||
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
|
||||
} else {
|
||||
formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message)
|
||||
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
|
||||
}
|
||||
buffer.WriteString(formatted)
|
||||
|
||||
@@ -71,40 +84,68 @@ func SetupBaseLogger() {
|
||||
})
|
||||
}
|
||||
|
||||
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
|
||||
func isDirWritable(dir string) bool {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil || !info.IsDir() {
|
||||
return false
|
||||
}
|
||||
|
||||
testFile := filepath.Join(dir, ".perm_test")
|
||||
f, err := os.Create(testFile)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(testFile)
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||
func ConfigureLogOutput(loggingToFile bool) error {
|
||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||
// until the total size is within the limit.
|
||||
func ConfigureLogOutput(cfg *config.Config) error {
|
||||
SetupBaseLogger()
|
||||
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
if loggingToFile {
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
} else if !isDirWritable(logDir) {
|
||||
logDir = filepath.Join(cfg.AuthDir, "logs")
|
||||
}
|
||||
|
||||
protectedPath := ""
|
||||
if cfg.LoggingToFile {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
||||
}
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
}
|
||||
protectedPath = filepath.Join(logDir, "main.log")
|
||||
logWriter = &lumberjack.Logger{
|
||||
Filename: filepath.Join(logDir, "main.log"),
|
||||
Filename: protectedPath,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 0,
|
||||
MaxAge: 0,
|
||||
Compress: false,
|
||||
}
|
||||
log.SetOutput(logWriter)
|
||||
return nil
|
||||
} else {
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,6 +153,8 @@ func closeLogOutputs() {
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
|
||||
166
internal/logging/log_dir_cleaner.go
Normal file
166
internal/logging/log_dir_cleaner.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const logDirCleanerInterval = time.Minute
|
||||
|
||||
var logDirCleanerCancel context.CancelFunc
|
||||
|
||||
func configureLogDirCleanerLocked(logDir string, maxTotalSizeMB int, protectedPath string) {
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if maxTotalSizeMB <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
maxBytes := int64(maxTotalSizeMB) * 1024 * 1024
|
||||
if maxBytes <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logDirCleanerCancel = cancel
|
||||
go runLogDirCleaner(ctx, filepath.Clean(dir), maxBytes, strings.TrimSpace(protectedPath))
|
||||
}
|
||||
|
||||
func stopLogDirCleanerLocked() {
|
||||
if logDirCleanerCancel == nil {
|
||||
return
|
||||
}
|
||||
logDirCleanerCancel()
|
||||
logDirCleanerCancel = nil
|
||||
}
|
||||
|
||||
func runLogDirCleaner(ctx context.Context, logDir string, maxBytes int64, protectedPath string) {
|
||||
ticker := time.NewTicker(logDirCleanerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
cleanOnce := func() {
|
||||
deleted, errClean := enforceLogDirSizeLimit(logDir, maxBytes, protectedPath)
|
||||
if errClean != nil {
|
||||
log.WithError(errClean).Warn("logging: failed to enforce log directory size limit")
|
||||
return
|
||||
}
|
||||
if deleted > 0 {
|
||||
log.Debugf("logging: removed %d old log file(s) to enforce log directory size limit", deleted)
|
||||
}
|
||||
}
|
||||
|
||||
cleanOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cleanOnce()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enforceLogDirSizeLimit(logDir string, maxBytes int64, protectedPath string) (int, error) {
|
||||
if maxBytes <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return 0, nil
|
||||
}
|
||||
dir = filepath.Clean(dir)
|
||||
|
||||
entries, errRead := os.ReadDir(dir)
|
||||
if errRead != nil {
|
||||
if os.IsNotExist(errRead) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, errRead
|
||||
}
|
||||
|
||||
protected := strings.TrimSpace(protectedPath)
|
||||
if protected != "" {
|
||||
protected = filepath.Clean(protected)
|
||||
}
|
||||
|
||||
type logFile struct {
|
||||
path string
|
||||
size int64
|
||||
modTime time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
files []logFile
|
||||
total int64
|
||||
)
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !isLogFileName(name) {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
continue
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, name)
|
||||
files = append(files, logFile{
|
||||
path: path,
|
||||
size: info.Size(),
|
||||
modTime: info.ModTime(),
|
||||
})
|
||||
total += info.Size()
|
||||
}
|
||||
|
||||
if total <= maxBytes {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].modTime.Before(files[j].modTime)
|
||||
})
|
||||
|
||||
deleted := 0
|
||||
for _, file := range files {
|
||||
if total <= maxBytes {
|
||||
break
|
||||
}
|
||||
if protected != "" && filepath.Clean(file.path) == protected {
|
||||
continue
|
||||
}
|
||||
if errRemove := os.Remove(file.path); errRemove != nil {
|
||||
log.WithError(errRemove).Warnf("logging: failed to remove old log file: %s", filepath.Base(file.path))
|
||||
continue
|
||||
}
|
||||
total -= file.size
|
||||
deleted++
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func isLogFileName(name string) bool {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
return strings.HasSuffix(lower, ".log") || strings.HasSuffix(lower, ".log.gz")
|
||||
}
|
||||
70
internal/logging/log_dir_cleaner_test.go
Normal file
70
internal/logging/log_dir_cleaner_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEnforceLogDirSizeLimitDeletesOldest(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeLogFile(t, filepath.Join(dir, "old.log"), 60, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "mid.log"), 60, time.Unix(2, 0))
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 60, time.Unix(3, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 120, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "old.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected old.log to be removed, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "mid.log")); err != nil {
|
||||
t.Fatalf("expected mid.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceLogDirSizeLimitSkipsProtected(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 200, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "other.log"), 50, time.Unix(2, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 100, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "other.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected other.log to be removed, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeLogFile(t *testing.T, path string, size int, modTime time.Time) {
|
||||
t.Helper()
|
||||
|
||||
data := make([]byte, size)
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
if err := os.Chtimes(path, modTime, modTime); err != nil {
|
||||
t.Fatalf("set times: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
@@ -25,6 +26,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
var requestLogID atomic.Uint64
|
||||
|
||||
// RequestLogger defines the interface for logging HTTP requests and responses.
|
||||
// It provides methods for logging both regular and streaming HTTP request/response cycles.
|
||||
type RequestLogger interface {
|
||||
@@ -40,10 +43,11 @@ type RequestLogger interface {
|
||||
// - response: The raw response data
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
|
||||
|
||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||
//
|
||||
@@ -52,11 +56,12 @@ type RequestLogger interface {
|
||||
// - method: The HTTP method
|
||||
// - headers: The request headers
|
||||
// - body: The request body
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - StreamingLogWriter: A writer for streaming response chunks
|
||||
// - error: An error if logging initialization fails, nil otherwise
|
||||
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
|
||||
LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)
|
||||
|
||||
// IsEnabled returns whether request logging is currently enabled.
|
||||
//
|
||||
@@ -174,20 +179,21 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
||||
// - response: The raw response data
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
|
||||
}
|
||||
|
||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||
if !l.enabled && !force {
|
||||
return nil
|
||||
}
|
||||
@@ -197,26 +203,59 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := l.generateFilename(url)
|
||||
// Generate filename with request ID
|
||||
filename := l.generateFilename(url, requestID)
|
||||
if force && !l.enabled {
|
||||
filename = l.generateErrorFilename(url)
|
||||
filename = l.generateErrorFilename(url, requestID)
|
||||
}
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
// Decompress response if needed
|
||||
decompressedResponse, err := l.decompressResponse(responseHeaders, response)
|
||||
if err != nil {
|
||||
// If decompression fails, log the error but continue with original response
|
||||
decompressedResponse = append(response, []byte(fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", err))...)
|
||||
requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
|
||||
if errTemp != nil {
|
||||
log.WithError(errTemp).Warn("failed to create request body temp file, falling back to direct write")
|
||||
}
|
||||
if requestBodyPath != "" {
|
||||
defer func() {
|
||||
if errRemove := os.Remove(requestBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove request body temp file")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Create log content
|
||||
content := l.formatLogContent(url, method, requestHeaders, body, apiRequest, apiResponse, decompressedResponse, statusCode, responseHeaders, apiResponseErrors)
|
||||
responseToWrite, decompressErr := l.decompressResponse(responseHeaders, response)
|
||||
if decompressErr != nil {
|
||||
// If decompression fails, continue with original response and annotate the log output.
|
||||
responseToWrite = response
|
||||
}
|
||||
|
||||
// Write to file
|
||||
if err = os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write log file: %w", err)
|
||||
logFile, errOpen := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if errOpen != nil {
|
||||
return fmt.Errorf("failed to create log file: %w", errOpen)
|
||||
}
|
||||
|
||||
writeErr := l.writeNonStreamingLog(
|
||||
logFile,
|
||||
url,
|
||||
method,
|
||||
requestHeaders,
|
||||
body,
|
||||
requestBodyPath,
|
||||
apiRequest,
|
||||
apiResponse,
|
||||
apiResponseErrors,
|
||||
statusCode,
|
||||
responseHeaders,
|
||||
responseToWrite,
|
||||
decompressErr,
|
||||
)
|
||||
if errClose := logFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request log file")
|
||||
if writeErr == nil {
|
||||
return errClose
|
||||
}
|
||||
}
|
||||
if writeErr != nil {
|
||||
return fmt.Errorf("failed to write log file: %w", writeErr)
|
||||
}
|
||||
|
||||
if force && !l.enabled {
|
||||
@@ -235,11 +274,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
// - method: The HTTP method
|
||||
// - headers: The request headers
|
||||
// - body: The request body
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - StreamingLogWriter: A writer for streaming response chunks
|
||||
// - error: An error if logging initialization fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
|
||||
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
|
||||
if !l.enabled {
|
||||
return &NoOpStreamingLogWriter{}, nil
|
||||
}
|
||||
@@ -249,30 +289,42 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
return nil, fmt.Errorf("failed to create logs directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := l.generateFilename(url)
|
||||
// Generate filename with request ID
|
||||
filename := l.generateFilename(url, requestID)
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
// Create and open file
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create log file: %w", err)
|
||||
requestHeaders := make(map[string][]string, len(headers))
|
||||
for key, values := range headers {
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
requestHeaders[key] = headerValues
|
||||
}
|
||||
|
||||
// Write initial request information
|
||||
requestInfo := l.formatRequestInfo(url, method, headers, body)
|
||||
if _, err = file.WriteString(requestInfo); err != nil {
|
||||
_ = file.Close()
|
||||
return nil, fmt.Errorf("failed to write request info: %w", err)
|
||||
requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
|
||||
if errTemp != nil {
|
||||
return nil, fmt.Errorf("failed to create request body temp file: %w", errTemp)
|
||||
}
|
||||
|
||||
responseBodyFile, errCreate := os.CreateTemp(l.logsDir, "response-body-*.tmp")
|
||||
if errCreate != nil {
|
||||
_ = os.Remove(requestBodyPath)
|
||||
return nil, fmt.Errorf("failed to create response body temp file: %w", errCreate)
|
||||
}
|
||||
responseBodyPath := responseBodyFile.Name()
|
||||
|
||||
// Create streaming writer
|
||||
writer := &FileStreamingLogWriter{
|
||||
file: file,
|
||||
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||
closeChan: make(chan struct{}),
|
||||
errorChan: make(chan error, 1),
|
||||
bufferedChunks: &bytes.Buffer{},
|
||||
logFilePath: filePath,
|
||||
url: url,
|
||||
method: method,
|
||||
timestamp: time.Now(),
|
||||
requestHeaders: requestHeaders,
|
||||
requestBodyPath: requestBodyPath,
|
||||
responseBodyPath: responseBodyPath,
|
||||
responseBodyFile: responseBodyFile,
|
||||
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||
closeChan: make(chan struct{}),
|
||||
errorChan: make(chan error, 1),
|
||||
}
|
||||
|
||||
// Start async writer goroutine
|
||||
@@ -282,8 +334,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
}
|
||||
|
||||
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
|
||||
func (l *FileRequestLogger) generateErrorFilename(url string) string {
|
||||
return fmt.Sprintf("error-%s", l.generateFilename(url))
|
||||
func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
|
||||
return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...))
|
||||
}
|
||||
|
||||
// ensureLogsDir creates the logs directory if it doesn't exist.
|
||||
@@ -298,13 +350,15 @@ func (l *FileRequestLogger) ensureLogsDir() error {
|
||||
}
|
||||
|
||||
// generateFilename creates a sanitized filename from the URL path and current timestamp.
|
||||
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
|
||||
//
|
||||
// Parameters:
|
||||
// - url: The request URL
|
||||
// - requestID: Optional request ID to include in filename
|
||||
//
|
||||
// Returns:
|
||||
// - string: A sanitized filename for the log file
|
||||
func (l *FileRequestLogger) generateFilename(url string) string {
|
||||
func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
|
||||
// Extract path from URL
|
||||
path := url
|
||||
if strings.Contains(url, "?") {
|
||||
@@ -320,10 +374,18 @@ func (l *FileRequestLogger) generateFilename(url string) string {
|
||||
sanitized := l.sanitizeForFilename(path)
|
||||
|
||||
// Add timestamp
|
||||
timestamp := time.Now().Format("2006-01-02T150405-.000000000")
|
||||
timestamp = strings.Replace(timestamp, ".", "", -1)
|
||||
timestamp := time.Now().Format("2006-01-02T150405")
|
||||
|
||||
return fmt.Sprintf("%s-%s.log", sanitized, timestamp)
|
||||
// Use request ID if provided, otherwise use sequential ID
|
||||
var idPart string
|
||||
if len(requestID) > 0 && requestID[0] != "" {
|
||||
idPart = requestID[0]
|
||||
} else {
|
||||
id := requestLogID.Add(1)
|
||||
idPart = fmt.Sprintf("%d", id)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
|
||||
}
|
||||
|
||||
// sanitizeForFilename replaces characters that are not safe for filenames.
|
||||
@@ -405,6 +467,220 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) writeRequestBodyTempFile(body []byte) (string, error) {
|
||||
tmpFile, errCreate := os.CreateTemp(l.logsDir, "request-body-*.tmp")
|
||||
if errCreate != nil {
|
||||
return "", errCreate
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
if _, errCopy := io.Copy(tmpFile, bytes.NewReader(body)); errCopy != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", errCopy
|
||||
}
|
||||
if errClose := tmpFile.Close(); errClose != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", errClose
|
||||
}
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
w io.Writer,
|
||||
url, method string,
|
||||
requestHeaders map[string][]string,
|
||||
requestBody []byte,
|
||||
requestBodyPath string,
|
||||
apiRequest []byte,
|
||||
apiResponse []byte,
|
||||
apiResponseErrors []*interfaces.ErrorMessage,
|
||||
statusCode int,
|
||||
responseHeaders map[string][]string,
|
||||
response []byte,
|
||||
decompressErr error,
|
||||
) error {
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, time.Now()); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||
}
|
||||
|
||||
func writeRequestInfoWithBody(
|
||||
w io.Writer,
|
||||
url, method string,
|
||||
headers map[string][]string,
|
||||
body []byte,
|
||||
bodyPath string,
|
||||
timestamp time.Time,
|
||||
) error {
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Version: %s\n", buildinfo.Version)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("URL: %s\n", url)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "=== HEADERS ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
for key, values := range headers {
|
||||
for _, value := range values {
|
||||
masked := util.MaskSensitiveHeaderValue(key, value)
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, masked)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if bodyPath != "" {
|
||||
bodyFile, errOpen := os.Open(bodyPath)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
}
|
||||
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
|
||||
_ = bodyFile.Close()
|
||||
return errCopy
|
||||
}
|
||||
if errClose := bodyFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request body temp file")
|
||||
}
|
||||
} else if _, errWrite := w.Write(body); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(payload, []byte(sectionPrefix)) {
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if !bytes.HasSuffix(payload, []byte("\n")) {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMessage) error {
|
||||
for i := 0; i < len(apiResponseErrors); i++ {
|
||||
if apiResponseErrors[i] == nil {
|
||||
continue
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "=== API ERROR RESPONSE ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if apiResponseErrors[i].Error != nil {
|
||||
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, responseHeaders map[string][]string, responseReader io.Reader, decompressErr error, trailingNewline bool) error {
|
||||
if _, errWrite := io.WriteString(w, "=== RESPONSE ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if statusWritten {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Status: %d\n", statusCode)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if responseHeaders != nil {
|
||||
for key, values := range responseHeaders {
|
||||
for _, value := range values {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, value)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if responseReader != nil {
|
||||
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
|
||||
return errCopy
|
||||
}
|
||||
}
|
||||
if decompressErr != nil {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", decompressErr)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if trailingNewline {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatLogContent creates the complete log content for non-streaming requests.
|
||||
//
|
||||
// Parameters:
|
||||
@@ -648,13 +924,34 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
||||
}
|
||||
|
||||
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
|
||||
// It handles asynchronous writing of streaming response chunks to a file.
|
||||
// All data is buffered and written in the correct order when Close is called.
|
||||
// It spools streaming response chunks to a temporary file to avoid retaining large responses in memory.
|
||||
// The final log file is assembled when Close is called.
|
||||
type FileStreamingLogWriter struct {
|
||||
// file is the file where log data is written.
|
||||
file *os.File
|
||||
// logFilePath is the final log file path.
|
||||
logFilePath string
|
||||
|
||||
// chunkChan is a channel for receiving response chunks to buffer.
|
||||
// url is the request URL (masked upstream in middleware).
|
||||
url string
|
||||
|
||||
// method is the HTTP method.
|
||||
method string
|
||||
|
||||
// timestamp is captured when the streaming log is initialized.
|
||||
timestamp time.Time
|
||||
|
||||
// requestHeaders stores the request headers.
|
||||
requestHeaders map[string][]string
|
||||
|
||||
// requestBodyPath is a temporary file path holding the request body.
|
||||
requestBodyPath string
|
||||
|
||||
// responseBodyPath is a temporary file path holding the streaming response body.
|
||||
responseBodyPath string
|
||||
|
||||
// responseBodyFile is the temp file where chunks are appended by the async writer.
|
||||
responseBodyFile *os.File
|
||||
|
||||
// chunkChan is a channel for receiving response chunks to spool.
|
||||
chunkChan chan []byte
|
||||
|
||||
// closeChan is a channel for signaling when the writer is closed.
|
||||
@@ -663,9 +960,6 @@ type FileStreamingLogWriter struct {
|
||||
// errorChan is a channel for reporting errors during writing.
|
||||
errorChan chan error
|
||||
|
||||
// bufferedChunks stores the response chunks in order.
|
||||
bufferedChunks *bytes.Buffer
|
||||
|
||||
// responseStatus stores the HTTP status code.
|
||||
responseStatus int
|
||||
|
||||
@@ -770,85 +1064,115 @@ func (w *FileStreamingLogWriter) Close() error {
|
||||
close(w.chunkChan)
|
||||
}
|
||||
|
||||
// Wait for async writer to finish buffering chunks
|
||||
// Wait for async writer to finish spooling chunks
|
||||
if w.closeChan != nil {
|
||||
<-w.closeChan
|
||||
w.chunkChan = nil
|
||||
}
|
||||
|
||||
if w.file == nil {
|
||||
select {
|
||||
case errWrite := <-w.errorChan:
|
||||
w.cleanupTempFiles()
|
||||
return errWrite
|
||||
default:
|
||||
}
|
||||
|
||||
if w.logFilePath == "" {
|
||||
w.cleanupTempFiles()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write all content in the correct order
|
||||
var content strings.Builder
|
||||
|
||||
// 1. Write API REQUEST section
|
||||
if len(w.apiRequest) > 0 {
|
||||
if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) {
|
||||
content.Write(w.apiRequest)
|
||||
if !bytes.HasSuffix(w.apiRequest, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API REQUEST ===\n")
|
||||
content.Write(w.apiRequest)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
logFile, errOpen := os.OpenFile(w.logFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if errOpen != nil {
|
||||
w.cleanupTempFiles()
|
||||
return fmt.Errorf("failed to create log file: %w", errOpen)
|
||||
}
|
||||
|
||||
// 2. Write API RESPONSE section
|
||||
if len(w.apiResponse) > 0 {
|
||||
if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) {
|
||||
content.Write(w.apiResponse)
|
||||
if !bytes.HasSuffix(w.apiResponse, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API RESPONSE ===\n")
|
||||
content.Write(w.apiResponse)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// 3. Write RESPONSE section (status, headers, buffered chunks)
|
||||
content.WriteString("=== RESPONSE ===\n")
|
||||
if w.statusWritten {
|
||||
content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus))
|
||||
}
|
||||
|
||||
for key, values := range w.responseHeaders {
|
||||
for _, value := range values {
|
||||
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||
writeErr := w.writeFinalLog(logFile)
|
||||
if errClose := logFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request log file")
|
||||
if writeErr == nil {
|
||||
writeErr = errClose
|
||||
}
|
||||
}
|
||||
content.WriteString("\n")
|
||||
|
||||
// Write buffered response body chunks
|
||||
if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 {
|
||||
content.Write(w.bufferedChunks.Bytes())
|
||||
}
|
||||
|
||||
// Write the complete content to file
|
||||
if _, err := w.file.WriteString(content.String()); err != nil {
|
||||
_ = w.file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return w.file.Close()
|
||||
w.cleanupTempFiles()
|
||||
return writeErr
|
||||
}
|
||||
|
||||
// asyncWriter runs in a goroutine to buffer chunks from the channel.
|
||||
// It continuously reads chunks from the channel and buffers them for later writing.
|
||||
// It continuously reads chunks from the channel and appends them to a temp file for later assembly.
|
||||
func (w *FileStreamingLogWriter) asyncWriter() {
|
||||
defer close(w.closeChan)
|
||||
|
||||
for chunk := range w.chunkChan {
|
||||
if w.bufferedChunks != nil {
|
||||
w.bufferedChunks.Write(chunk)
|
||||
if w.responseBodyFile == nil {
|
||||
continue
|
||||
}
|
||||
if _, errWrite := w.responseBodyFile.Write(chunk); errWrite != nil {
|
||||
select {
|
||||
case w.errorChan <- errWrite:
|
||||
default:
|
||||
}
|
||||
if errClose := w.responseBodyFile.Close(); errClose != nil {
|
||||
select {
|
||||
case w.errorChan <- errClose:
|
||||
default:
|
||||
}
|
||||
}
|
||||
w.responseBodyFile = nil
|
||||
}
|
||||
}
|
||||
|
||||
if w.responseBodyFile == nil {
|
||||
return
|
||||
}
|
||||
if errClose := w.responseBodyFile.Close(); errClose != nil {
|
||||
select {
|
||||
case w.errorChan <- errClose:
|
||||
default:
|
||||
}
|
||||
}
|
||||
w.responseBodyFile = nil
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
responseBodyFile, errOpen := os.Open(w.responseBodyPath)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
}
|
||||
defer func() {
|
||||
if errClose := responseBodyFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close response body temp file")
|
||||
}
|
||||
}()
|
||||
|
||||
return writeResponseSection(logFile, w.responseStatus, w.statusWritten, w.responseHeaders, responseBodyFile, nil, false)
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) cleanupTempFiles() {
|
||||
if w.requestBodyPath != "" {
|
||||
if errRemove := os.Remove(w.requestBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove request body temp file")
|
||||
}
|
||||
w.requestBodyPath = ""
|
||||
}
|
||||
|
||||
if w.responseBodyPath != "" {
|
||||
if errRemove := os.Remove(w.responseBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove response body temp file")
|
||||
}
|
||||
w.responseBodyPath = ""
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
61
internal/logging/requestid.go
Normal file
61
internal/logging/requestid.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// requestIDKey is the context key for storing/retrieving request IDs.
|
||||
type requestIDKey struct{}
|
||||
|
||||
// ginRequestIDKey is the Gin context key for request IDs.
|
||||
const ginRequestIDKey = "__request_id__"
|
||||
|
||||
// GenerateRequestID creates a new 8-character hex request ID.
|
||||
func GenerateRequestID() string {
|
||||
b := make([]byte, 4)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "00000000"
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// WithRequestID returns a new context with the request ID attached.
|
||||
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
||||
return context.WithValue(ctx, requestIDKey{}, requestID)
|
||||
}
|
||||
|
||||
// GetRequestID retrieves the request ID from the context.
|
||||
// Returns empty string if not found.
|
||||
func GetRequestID(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetGinRequestID stores the request ID in the Gin context.
|
||||
func SetGinRequestID(c *gin.Context, requestID string) {
|
||||
if c != nil {
|
||||
c.Set(ginRequestIDKey, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetGinRequestID retrieves the request ID from the Gin context.
|
||||
func GetGinRequestID(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
if id, exists := c.Get(ginRequestIDKey); exists {
|
||||
if s, ok := id.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -23,10 +24,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
managementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -97,7 +99,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
|
||||
configPath, _ := schedulerConfigPath.Load().(string)
|
||||
staticDir := StaticDir(configPath)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
|
||||
runOnce()
|
||||
@@ -181,7 +183,7 @@ func FilePath(configFilePath string) string {
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string) {
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
@@ -197,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
return
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
@@ -209,14 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if err := os.MkdirAll(staticDir, 0o755); err != nil {
|
||||
log.WithError(err).Warn("failed to prepare static directory for management asset")
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
@@ -225,8 +237,15 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client)
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
@@ -238,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
@@ -254,8 +280,60 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client) (*releaseAsset, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, managementReleaseURL, nil)
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to download fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
|
||||
return true
|
||||
}
|
||||
|
||||
func resolveReleaseURL(repo string) string {
|
||||
repo = strings.TrimSpace(repo)
|
||||
if repo == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(repo)
|
||||
if err != nil || parsed.Host == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
host := strings.ToLower(parsed.Host)
|
||||
parsed.Path = strings.TrimSuffix(parsed.Path, "/")
|
||||
|
||||
if host == "api.github.com" {
|
||||
if !strings.HasSuffix(strings.ToLower(parsed.Path), "/releases/latest") {
|
||||
parsed.Path = parsed.Path + "/releases/latest"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
if host == "github.com" {
|
||||
parts := strings.Split(strings.Trim(parsed.Path, "/"), "/")
|
||||
if len(parts) >= 2 && parts[0] != "" && parts[1] != "" {
|
||||
repoName := strings.TrimSuffix(parts[1], ".git")
|
||||
return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", parts[0], repoName)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL string) (*releaseAsset, string, error) {
|
||||
if strings.TrimSpace(releaseURL) == "" {
|
||||
releaseURL = defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("create release request: %w", err)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
lastCodexMaxPrompt := ""
|
||||
last51Prompt := ""
|
||||
last52Prompt := ""
|
||||
last52CodexPrompt := ""
|
||||
// lastReviewPrompt := ""
|
||||
for _, entry := range entries {
|
||||
content, _ := codexInstructionsDir.ReadFile("codex_instructions/" + entry.Name())
|
||||
@@ -36,12 +37,16 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
last51Prompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "gpt_5_2_prompt.md") {
|
||||
last52Prompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "gpt-5.2-codex_prompt.md") {
|
||||
last52CodexPrompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "review_prompt.md") {
|
||||
// lastReviewPrompt = string(content)
|
||||
}
|
||||
}
|
||||
if strings.Contains(modelName, "codex-max") {
|
||||
return false, lastCodexMaxPrompt
|
||||
} else if strings.Contains(modelName, "5.2-codex") {
|
||||
return false, last52CodexPrompt
|
||||
} else if strings.Contains(modelName, "codex") {
|
||||
return false, lastCodexPrompt
|
||||
} else if strings.Contains(modelName, "5.1") {
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to "never", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `"require_escalated"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into "AI slop" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No "save/copy this file" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateRandomState generates a cryptographically secure random state parameter
|
||||
@@ -19,3 +21,83 @@ func GenerateRandomState() (string, error) {
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// OAuthCallback captures the parsed OAuth callback parameters.
|
||||
type OAuthCallback struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
ErrorDescription string
|
||||
}
|
||||
|
||||
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
||||
// It returns nil when the input is empty.
|
||||
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
||||
trimmed := strings.TrimSpace(input)
|
||||
if trimmed == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
candidate := trimmed
|
||||
if !strings.Contains(candidate, "://") {
|
||||
if strings.HasPrefix(candidate, "?") {
|
||||
candidate = "http://localhost" + candidate
|
||||
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
|
||||
candidate = "http://" + candidate
|
||||
} else if strings.Contains(candidate, "=") {
|
||||
candidate = "http://localhost/?" + candidate
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid callback URL")
|
||||
}
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(candidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
state := strings.TrimSpace(query.Get("state"))
|
||||
errCode := strings.TrimSpace(query.Get("error"))
|
||||
errDesc := strings.TrimSpace(query.Get("error_description"))
|
||||
|
||||
if parsedURL.Fragment != "" {
|
||||
if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil {
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(fragQuery.Get("code"))
|
||||
}
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(fragQuery.Get("state"))
|
||||
}
|
||||
if errCode == "" {
|
||||
errCode = strings.TrimSpace(fragQuery.Get("error"))
|
||||
}
|
||||
if errDesc == "" {
|
||||
errDesc = strings.TrimSpace(fragQuery.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if code != "" && state == "" && strings.Contains(code, "#") {
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
code = parts[0]
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
if errCode == "" && errDesc != "" {
|
||||
errCode = errDesc
|
||||
errDesc = ""
|
||||
}
|
||||
|
||||
if code == "" && errCode == "" {
|
||||
return nil, fmt.Errorf("callback URL missing code")
|
||||
}
|
||||
|
||||
return &OAuthCallback{
|
||||
Code: code,
|
||||
State: state,
|
||||
Error: errCode,
|
||||
ErrorDescription: errDesc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -160,7 +160,22 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Gemini 3 Flash Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -175,7 +190,7 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -240,7 +255,22 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -255,7 +285,7 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -317,11 +347,26 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
Name: "models/gemini-3-pro-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Pro Preview",
|
||||
Description: "Gemini 3 Pro Preview",
|
||||
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -387,7 +432,22 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-pro-latest",
|
||||
@@ -580,6 +640,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
Object: "model",
|
||||
Created: 1765440000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.2",
|
||||
DisplayName: "GPT 5.2 Codex",
|
||||
Description: "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
@@ -630,6 +704,13 @@ func GetQwenModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
|
||||
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
|
||||
// Uses level-based configuration so standard normalization flows apply before conversion.
|
||||
var iFlowThinkingSupport = &ThinkingSupport{
|
||||
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
|
||||
}
|
||||
|
||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
entries := []struct {
|
||||
@@ -645,19 +726,22 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2", Created: 1764576000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
||||
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -690,9 +774,36 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
|
||||
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
}
|
||||
}
|
||||
|
||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||
// Returns nil if no matching model is found.
|
||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
allModels := [][]*ModelInfo{
|
||||
GetClaudeModels(),
|
||||
GetGeminiModels(),
|
||||
GetGeminiVertexModels(),
|
||||
GetGeminiCLIModels(),
|
||||
GetAIStudioModels(),
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
if m != nil && m.ID == modelID {
|
||||
return m
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,6 +90,9 @@ type ModelRegistry struct {
|
||||
models map[string]*ModelRegistration
|
||||
// clientModels maps client ID to the models it provides
|
||||
clientModels map[string][]string
|
||||
// clientModelInfos maps client ID to a map of model ID -> ModelInfo
|
||||
// This preserves the original model info provided by each client
|
||||
clientModelInfos map[string]map[string]*ModelInfo
|
||||
// clientProviders maps client ID to its provider identifier
|
||||
clientProviders map[string]string
|
||||
// mutex ensures thread-safe access to the registry
|
||||
@@ -104,10 +107,11 @@ var registryOnce sync.Once
|
||||
func GetGlobalRegistry() *ModelRegistry {
|
||||
registryOnce.Do(func() {
|
||||
globalRegistry = &ModelRegistry{
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
})
|
||||
return globalRegistry
|
||||
@@ -144,6 +148,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
// No models supplied; unregister existing client state if present.
|
||||
r.unregisterClientInternal(clientID)
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
delete(r.clientProviders, clientID)
|
||||
misc.LogCredentialSeparator()
|
||||
return
|
||||
@@ -152,7 +157,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
now := time.Now()
|
||||
|
||||
oldModels, hadExisting := r.clientModels[clientID]
|
||||
oldProvider, _ := r.clientProviders[clientID]
|
||||
oldProvider := r.clientProviders[clientID]
|
||||
providerChanged := oldProvider != provider
|
||||
if !hadExisting {
|
||||
// Pure addition path.
|
||||
@@ -161,6 +166,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
r.addModelRegistration(modelID, provider, model, now)
|
||||
}
|
||||
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||
// Store client's own model infos
|
||||
clientInfos := make(map[string]*ModelInfo, len(newModels))
|
||||
for id, m := range newModels {
|
||||
clientInfos[id] = cloneModelInfo(m)
|
||||
}
|
||||
r.clientModelInfos[clientID] = clientInfos
|
||||
if provider != "" {
|
||||
r.clientProviders[clientID] = provider
|
||||
} else {
|
||||
@@ -287,6 +298,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
if len(rawModelIDs) > 0 {
|
||||
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||
}
|
||||
// Update client's own model infos
|
||||
clientInfos := make(map[string]*ModelInfo, len(newModels))
|
||||
for id, m := range newModels {
|
||||
clientInfos[id] = cloneModelInfo(m)
|
||||
}
|
||||
r.clientModelInfos[clientID] = clientInfos
|
||||
if provider != "" {
|
||||
r.clientProviders[clientID] = provider
|
||||
} else {
|
||||
@@ -436,6 +453,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
}
|
||||
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
if hasProvider {
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
@@ -871,3 +889,44 @@ func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, erro
|
||||
|
||||
return "", fmt.Errorf("no available clients for any model in handler type: %s", handlerType)
|
||||
}
|
||||
|
||||
// GetModelsForClient returns the models registered for a specific client.
|
||||
// Parameters:
|
||||
// - clientID: The client identifier (typically auth file name or auth ID)
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: List of models registered for this client, nil if client not found
|
||||
func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
modelIDs, exists := r.clientModels[clientID]
|
||||
if !exists || len(modelIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to use client-specific model infos first
|
||||
clientInfos := r.clientModelInfos[clientID]
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
result := make([]*ModelInfo, 0, len(modelIDs))
|
||||
for _, modelID := range modelIDs {
|
||||
if _, dup := seen[modelID]; dup {
|
||||
continue
|
||||
}
|
||||
seen[modelID] = struct{}{}
|
||||
|
||||
// Prefer client's own model info to preserve original type/owned_by
|
||||
if clientInfos != nil {
|
||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||
result = append(result, info)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Fallback to global registry (for backwards compatibility)
|
||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||
result = append(result, reg.Info)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: http.MethodPost,
|
||||
@@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: http.MethodPost,
|
||||
@@ -322,10 +324,11 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
|
||||
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
||||
@@ -384,8 +387,16 @@ func ensureColonSpacedJSON(payload []byte) []byte {
|
||||
|
||||
for i := 0; i < len(indented); i++ {
|
||||
ch := indented[i]
|
||||
if ch == '"' && (i == 0 || indented[i-1] != '\\') {
|
||||
inString = !inString
|
||||
if ch == '"' {
|
||||
// A quote is escaped only when preceded by an odd number of consecutive backslashes.
|
||||
// For example: "\\\"" keeps the quote inside the string, but "\\\\" closes the string.
|
||||
backslashes := 0
|
||||
for j := i - 1; j >= 0 && indented[j] == '\\'; j-- {
|
||||
backslashes++
|
||||
}
|
||||
if backslashes%2 == 0 {
|
||||
inString = !inString
|
||||
}
|
||||
}
|
||||
|
||||
if !inString {
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,6 +17,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -30,20 +33,24 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
|
||||
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityCountTokensPath = "/v1internal:countTokens"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
)
|
||||
|
||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
)
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
type AntigravityExecutor struct {
|
||||
@@ -69,6 +76,11 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
|
||||
|
||||
// Execute performs a non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
if isClaude {
|
||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
@@ -85,8 +97,10 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -160,6 +174,338 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// executeClaudeNonStream performs a claude non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, true)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return resp, err
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errDo
|
||||
return resp, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errRead
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(nil, streamScannerBuffer)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
if payload == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
|
||||
var buffer bytes.Buffer
|
||||
for chunk := range out {
|
||||
if chunk.Err != nil {
|
||||
return resp, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
_, _ = buffer.Write(chunk.Payload)
|
||||
_, _ = buffer.Write([]byte("\n"))
|
||||
}
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
responseTemplate := ""
|
||||
var traceID string
|
||||
var finishReason string
|
||||
var modelVersion string
|
||||
var responseID string
|
||||
var role string
|
||||
var usageRaw string
|
||||
parts := make([]map[string]interface{}, 0)
|
||||
var pendingKind string
|
||||
var pendingText strings.Builder
|
||||
var pendingThoughtSig string
|
||||
|
||||
flushPending := func() {
|
||||
if pendingKind == "" {
|
||||
return
|
||||
}
|
||||
text := pendingText.String()
|
||||
switch pendingKind {
|
||||
case "text":
|
||||
if strings.TrimSpace(text) == "" {
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
return
|
||||
}
|
||||
parts = append(parts, map[string]interface{}{"text": text})
|
||||
case "thought":
|
||||
if strings.TrimSpace(text) == "" && pendingThoughtSig == "" {
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
return
|
||||
}
|
||||
part := map[string]interface{}{"thought": true}
|
||||
part["text"] = text
|
||||
if pendingThoughtSig != "" {
|
||||
part["thoughtSignature"] = pendingThoughtSig
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
}
|
||||
|
||||
normalizePart := func(partResult gjson.Result) map[string]interface{} {
|
||||
var m map[string]interface{}
|
||||
_ = json.Unmarshal([]byte(partResult.Raw), &m)
|
||||
if m == nil {
|
||||
m = map[string]interface{}{}
|
||||
}
|
||||
sig := partResult.Get("thoughtSignature").String()
|
||||
if sig == "" {
|
||||
sig = partResult.Get("thought_signature").String()
|
||||
}
|
||||
if sig != "" {
|
||||
m["thoughtSignature"] = sig
|
||||
delete(m, "thought_signature")
|
||||
}
|
||||
if inlineData, ok := m["inline_data"]; ok {
|
||||
m["inlineData"] = inlineData
|
||||
delete(m, "inline_data")
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
for _, line := range bytes.Split(stream, []byte("\n")) {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 || !gjson.ValidBytes(trimmed) {
|
||||
continue
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(trimmed)
|
||||
responseNode := root.Get("response")
|
||||
if !responseNode.Exists() {
|
||||
if root.Get("candidates").Exists() {
|
||||
responseNode = root
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
responseTemplate = responseNode.Raw
|
||||
|
||||
if traceResult := root.Get("traceId"); traceResult.Exists() && traceResult.String() != "" {
|
||||
traceID = traceResult.String()
|
||||
}
|
||||
|
||||
if roleResult := responseNode.Get("candidates.0.content.role"); roleResult.Exists() {
|
||||
role = roleResult.String()
|
||||
}
|
||||
|
||||
if finishResult := responseNode.Get("candidates.0.finishReason"); finishResult.Exists() && finishResult.String() != "" {
|
||||
finishReason = finishResult.String()
|
||||
}
|
||||
|
||||
if modelResult := responseNode.Get("modelVersion"); modelResult.Exists() && modelResult.String() != "" {
|
||||
modelVersion = modelResult.String()
|
||||
}
|
||||
if responseIDResult := responseNode.Get("responseId"); responseIDResult.Exists() && responseIDResult.String() != "" {
|
||||
responseID = responseIDResult.String()
|
||||
}
|
||||
if usageResult := responseNode.Get("usageMetadata"); usageResult.Exists() {
|
||||
usageRaw = usageResult.Raw
|
||||
} else if usageResult := root.Get("usageMetadata"); usageResult.Exists() {
|
||||
usageRaw = usageResult.Raw
|
||||
}
|
||||
|
||||
if partsResult := responseNode.Get("candidates.0.content.parts"); partsResult.IsArray() {
|
||||
for _, part := range partsResult.Array() {
|
||||
hasFunctionCall := part.Get("functionCall").Exists()
|
||||
hasInlineData := part.Get("inlineData").Exists() || part.Get("inline_data").Exists()
|
||||
sig := part.Get("thoughtSignature").String()
|
||||
if sig == "" {
|
||||
sig = part.Get("thought_signature").String()
|
||||
}
|
||||
text := part.Get("text").String()
|
||||
thought := part.Get("thought").Bool()
|
||||
|
||||
if hasFunctionCall || hasInlineData {
|
||||
flushPending()
|
||||
parts = append(parts, normalizePart(part))
|
||||
continue
|
||||
}
|
||||
|
||||
if thought || part.Get("text").Exists() {
|
||||
kind := "text"
|
||||
if thought {
|
||||
kind = "thought"
|
||||
}
|
||||
if pendingKind != "" && pendingKind != kind {
|
||||
flushPending()
|
||||
}
|
||||
pendingKind = kind
|
||||
pendingText.WriteString(text)
|
||||
if kind == "thought" && sig != "" {
|
||||
pendingThoughtSig = sig
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
flushPending()
|
||||
parts = append(parts, normalizePart(part))
|
||||
}
|
||||
}
|
||||
}
|
||||
flushPending()
|
||||
|
||||
if responseTemplate == "" {
|
||||
responseTemplate = `{"candidates":[{"content":{"role":"model","parts":[]}}]}`
|
||||
}
|
||||
|
||||
partsJSON, _ := json.Marshal(parts)
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
|
||||
if role != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
|
||||
}
|
||||
if finishReason != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
|
||||
}
|
||||
if modelVersion != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
|
||||
}
|
||||
if responseID != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
|
||||
}
|
||||
if usageRaw != "" {
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
|
||||
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
|
||||
}
|
||||
|
||||
output := `{"response":{},"traceId":""}`
|
||||
output, _ = sjson.SetRaw(output, "response", responseTemplate)
|
||||
if traceID != "" {
|
||||
output, _ = sjson.Set(output, "traceId", traceID)
|
||||
}
|
||||
return []byte(output)
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
ctx = context.WithValue(ctx, "alt", "")
|
||||
@@ -175,13 +521,17 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -312,9 +662,133 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request (not supported for Antigravity).
|
||||
func (e *AntigravityExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported"}
|
||||
// CountTokens counts tokens for the given request using the Antigravity API.
|
||||
func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return cliproxyexecutor.Response{}, errToken
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
base := strings.TrimSuffix(baseURL, "/")
|
||||
if base == "" {
|
||||
base = buildBaseURL(auth)
|
||||
}
|
||||
|
||||
var requestURL strings.Builder
|
||||
requestURL.WriteString(base)
|
||||
requestURL.WriteString(antigravityCountTokensPath)
|
||||
if opts.Alt != "" {
|
||||
requestURL.WriteString("?$alt=")
|
||||
requestURL.WriteString(url.QueryEscape(opts.Alt))
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return cliproxyexecutor.Response{}, errReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
default:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
}
|
||||
|
||||
// FetchAntigravityModels retrieves available models using the supplied auth.
|
||||
@@ -545,27 +1019,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
strJSON = util.DeleteKey(strJSON, "$schema")
|
||||
strJSON = util.DeleteKey(strJSON, "maxItems")
|
||||
strJSON = util.DeleteKey(strJSON, "minItems")
|
||||
strJSON = util.DeleteKey(strJSON, "minLength")
|
||||
strJSON = util.DeleteKey(strJSON, "maxLength")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMaximum")
|
||||
strJSON = util.DeleteKey(strJSON, "$ref")
|
||||
strJSON = util.DeleteKey(strJSON, "$defs")
|
||||
|
||||
paths = make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
|
||||
for _, p := range paths {
|
||||
anyOf := gjson.Get(strJSON, p)
|
||||
if anyOf.IsArray() {
|
||||
anyOfItems := anyOf.Array()
|
||||
if len(anyOfItems) > 0 {
|
||||
strJSON, _ = sjson.SetRaw(strJSON, p[:len(p)-len(".anyOf")], anyOfItems[0].Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
||||
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
@@ -705,7 +1161,7 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
||||
}
|
||||
return []string{
|
||||
antigravityBaseURLDaily,
|
||||
// antigravityBaseURLAutopush,
|
||||
antigravitySandboxBaseURLDaily,
|
||||
antigravityBaseURLProd,
|
||||
}
|
||||
}
|
||||
@@ -741,7 +1197,7 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateSessionID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
@@ -777,15 +1233,36 @@ func generateRequestID() string {
|
||||
}
|
||||
|
||||
func generateSessionID() string {
|
||||
randSourceMutex.Lock()
|
||||
n := randSource.Int63n(9_000_000_000_000_000_000)
|
||||
randSourceMutex.Unlock()
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
func generateStableSessionID(payload []byte) string {
|
||||
contents := gjson.GetBytes(payload, "request.contents")
|
||||
if contents.IsArray() {
|
||||
for _, content := range contents.Array() {
|
||||
if content.Get("role").String() == "user" {
|
||||
text := content.Get("parts.0.text").String()
|
||||
if text != "" {
|
||||
h := sha256.Sum256([]byte(text))
|
||||
n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return generateSessionID()
|
||||
}
|
||||
|
||||
func generateProjectID() string {
|
||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||
randSourceMutex.Lock()
|
||||
adj := adjectives[randSource.Intn(len(adjectives))]
|
||||
noun := nouns[randSource.Intn(len(nouns))]
|
||||
randSourceMutex.Unlock()
|
||||
randomPart := strings.ToLower(uuid.NewString())[:5]
|
||||
return adj + "-" + noun + "-" + randomPart
|
||||
}
|
||||
@@ -798,6 +1275,8 @@ func modelName2Alias(modelName string) string {
|
||||
return "gemini-3-pro-image-preview"
|
||||
case "gemini-3-pro-high":
|
||||
return "gemini-3-pro-preview"
|
||||
case "gemini-3-flash":
|
||||
return "gemini-3-flash-preview"
|
||||
case "claude-sonnet-4-5":
|
||||
return "gemini-claude-sonnet-4-5"
|
||||
case "claude-sonnet-4-5-thinking":
|
||||
@@ -819,6 +1298,8 @@ func alias2ModelName(modelName string) string {
|
||||
return "gemini-3-pro-image"
|
||||
case "gemini-3-pro-preview":
|
||||
return "gemini-3-pro-high"
|
||||
case "gemini-3-flash-preview":
|
||||
return "gemini-3-flash"
|
||||
case "gemini-claude-sonnet-4-5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "gemini-claude-sonnet-4-5-thinking":
|
||||
@@ -832,7 +1313,7 @@ func alias2ModelName(modelName string) string {
|
||||
|
||||
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
|
||||
// For Claude models, it additionally ensures thinking budget < max_tokens.
|
||||
func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
||||
func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
|
||||
payload = util.StripThinkingConfigIfUnsupported(model, payload)
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return payload
|
||||
@@ -844,7 +1325,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
||||
raw := int(budget.Int())
|
||||
normalized := util.NormalizeThinkingBudget(model, raw)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(model), "claude")
|
||||
if isClaude {
|
||||
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
|
||||
if effectiveMax > 0 && normalized >= effectiveMax {
|
||||
|
||||
@@ -49,33 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
// Inject thinking config based on model metadata for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
|
||||
body = e.injectThinkingConfig(model, req.Metadata, body)
|
||||
|
||||
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(req.Model, body)
|
||||
body = ensureMaxTokensForThinking(model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
@@ -167,26 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
// Inject thinking config based on model metadata for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
|
||||
body = e.injectThinkingConfig(model, req.Metadata, body)
|
||||
body = checkSystemInstructions(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(req.Model, body)
|
||||
body = ensureMaxTokensForThinking(model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
@@ -310,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
|
||||
@@ -461,6 +446,19 @@ func (e *ClaudeExecutor) injectThinkingConfig(modelName string, metadata map[str
|
||||
return util.ApplyClaudeThinkingConfig(body, budget)
|
||||
}
|
||||
|
||||
// disableThinkingIfToolChoiceForced checks if tool_choice forces tool use and disables thinking.
|
||||
// Anthropic API does not allow thinking when tool_choice is set to "any" or a specific tool.
|
||||
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
|
||||
func disableThinkingIfToolChoiceForced(body []byte) []byte {
|
||||
toolChoiceType := gjson.GetBytes(body, "tool_choice.type").String()
|
||||
// "auto" is allowed with thinking, but "any" or "tool" (specific tool) are not
|
||||
if toolChoiceType == "any" || toolChoiceType == "tool" {
|
||||
// Remove thinking configuration entirely to avoid API error
|
||||
body, _ = sjson.DeleteBytes(body, "thinking")
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
|
||||
// Anthropic API requires this constraint; violating it returns a 400 error.
|
||||
// This function should be called after all thinking configuration is finalized.
|
||||
@@ -662,7 +660,14 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
|
||||
if isAnthropicBase && useAPIKey {
|
||||
r.Header.Del("Authorization")
|
||||
r.Header.Set("x-api-key", apiKey)
|
||||
} else {
|
||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
var ginHeaders http.Header
|
||||
|
||||
@@ -49,18 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
|
||||
@@ -146,20 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -246,20 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
|
||||
modelForCounting := req.Model
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort")
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
enc, err := tokenizerForCodexModel(modelForCounting)
|
||||
enc, err := tokenizerForCodexModel(model)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
@@ -520,3 +527,87 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
trimmed := strings.TrimSpace(alias)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entry := e.resolveCodexConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
||||
|
||||
// Candidate names to match against configured aliases/names.
|
||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
||||
candidates = append(candidates, original)
|
||||
}
|
||||
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.CodexKey {
|
||||
entry := &e.cfg.CodexKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.CodexKey {
|
||||
entry := &e.cfg.CodexKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -77,6 +79,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
@@ -215,6 +218,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
@@ -314,7 +318,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response, reqBody []byte, attempt string) {
|
||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
@@ -332,14 +336,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -361,12 +365,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -413,14 +417,17 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
|
||||
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
||||
// Gemini CLI endpoint when iterating fallback variants.
|
||||
for _, attemptModel := range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload)
|
||||
payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
|
||||
|
||||
tok, errTok := tokenSource.Token()
|
||||
if errTok != nil {
|
||||
@@ -784,20 +791,45 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
// Try to parse the retryDelay from the error response
|
||||
// Format: error.details[].retryDelay where @type == "type.googleapis.com/google.rpc.RetryInfo"
|
||||
details := gjson.GetBytes(errorBody, "error.details")
|
||||
if !details.Exists() || !details.IsArray() {
|
||||
return nil, fmt.Errorf("no error.details found")
|
||||
if details.Exists() && details.IsArray() {
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
||||
retryDelay := detail.Get("retryDelay").String()
|
||||
if retryDelay != "" {
|
||||
// Parse duration string like "0.847655010s"
|
||||
duration, err := time.ParseDuration(retryDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration")
|
||||
}
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try ErrorInfo.metadata.quotaResetDelay (e.g., "373.801628ms")
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.ErrorInfo" {
|
||||
quotaResetDelay := detail.Get("metadata.quotaResetDelay").String()
|
||||
if quotaResetDelay != "" {
|
||||
duration, err := time.ParseDuration(quotaResetDelay)
|
||||
if err == nil {
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
||||
retryDelay := detail.Get("retryDelay").String()
|
||||
if retryDelay != "" {
|
||||
// Parse duration string like "0.847655010s"
|
||||
duration, err := time.ParseDuration(retryDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration")
|
||||
}
|
||||
// Fallback: parse from error.message "Your quota will reset after Xs."
|
||||
message := gjson.GetBytes(errorBody, "error.message").String()
|
||||
if message != "" {
|
||||
re := regexp.MustCompile(`after\s+(\d+)s\.?`)
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,19 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -98,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
}
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -173,21 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -287,19 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
apiKey, bearer := geminiCreds(auth)
|
||||
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = applyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
|
||||
|
||||
requestBody := bytes.NewReader(translatedReq)
|
||||
|
||||
@@ -398,6 +410,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
|
||||
return base
|
||||
}
|
||||
|
||||
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
trimmed := strings.TrimSpace(alias)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entry := e.resolveGeminiConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
||||
|
||||
// Candidate names to match against configured aliases/names.
|
||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
||||
candidates = append(candidates, original)
|
||||
}
|
||||
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.GeminiKey {
|
||||
entry := &e.cfg.GeminiKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.GeminiKey {
|
||||
entry := &e.cfg.GeminiKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
|
||||
@@ -120,8 +120,6 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
@@ -137,7 +135,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -146,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
}
|
||||
}
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -220,24 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -250,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -321,8 +322,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -338,10 +337,10 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -438,30 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -552,8 +554,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
|
||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
@@ -566,14 +566,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
@@ -641,21 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
@@ -665,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
@@ -808,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
|
||||
}
|
||||
return tok.AccessToken, nil
|
||||
}
|
||||
|
||||
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
|
||||
// It matches the requested model alias against configured models and returns the actual upstream name.
|
||||
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
trimmed := strings.TrimSpace(alias)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entry := e.resolveVertexConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
||||
|
||||
// Candidate names to match against configured aliases/names.
|
||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
||||
candidates = append(candidates, original)
|
||||
}
|
||||
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
|
||||
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.VertexCompatAPIKey {
|
||||
entry := &e.cfg.VertexCompatAPIKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.VertexCompatAPIKey {
|
||||
entry := &e.cfg.VertexCompatAPIKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -57,15 +57,14 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
@@ -148,15 +147,14 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
@@ -219,7 +217,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
@@ -442,3 +440,99 @@ func ensureToolsArray(body []byte) []byte {
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// preserveReasoningContentInMessages ensures reasoning_content from assistant messages in the
|
||||
// conversation history is preserved when sending to iFlow models that support thinking.
|
||||
// This is critical for multi-turn conversations where the model needs to see its previous
|
||||
// reasoning to maintain coherent thought chains across tool calls and conversation turns.
|
||||
//
|
||||
// For GLM-4.7 and MiniMax-M2.1, the full assistant response (including reasoning) must be
|
||||
// appended back into message history before the next call.
|
||||
func preserveReasoningContentInMessages(body []byte) []byte {
|
||||
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||
|
||||
// Only apply to models that support thinking with history preservation
|
||||
needsPreservation := strings.HasPrefix(model, "glm-4.7") ||
|
||||
strings.HasPrefix(model, "glm-4-7") ||
|
||||
strings.HasPrefix(model, "minimax-m2.1") ||
|
||||
strings.HasPrefix(model, "minimax-m2-1")
|
||||
|
||||
if !needsPreservation {
|
||||
return body
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
// Check if any assistant message already has reasoning_content preserved
|
||||
hasReasoningContent := false
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
role := msg.Get("role").String()
|
||||
if role == "assistant" {
|
||||
rc := msg.Get("reasoning_content")
|
||||
if rc.Exists() && rc.String() != "" {
|
||||
hasReasoningContent = true
|
||||
return false // stop iteration
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// If reasoning content is already present, the messages are properly formatted
|
||||
// No need to modify - the client has correctly preserved reasoning in history
|
||||
if hasReasoningContent {
|
||||
log.Debugf("iflow executor: reasoning_content found in message history for %s", model)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
|
||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
||||
//
|
||||
// Model-specific handling:
|
||||
// - GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
|
||||
// - MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
|
||||
// - Other iFlow models: Uses chat_template_kwargs.enable_thinking (boolean)
|
||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
||||
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
||||
|
||||
// Check if thinking should be enabled
|
||||
val := ""
|
||||
if effort.Exists() {
|
||||
val = strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
}
|
||||
enableThinking := effort.Exists() && val != "none" && val != ""
|
||||
|
||||
// Remove reasoning_effort as we'll convert to model-specific format
|
||||
if effort.Exists() {
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
}
|
||||
|
||||
// GLM-4.7: Use extra_body with thinking config and clear_thinking: false
|
||||
if strings.HasPrefix(model, "glm-4.7") || strings.HasPrefix(model, "glm-4-7") {
|
||||
if enableThinking {
|
||||
body, _ = sjson.SetBytes(body, "extra_body.thinking.type", "enabled")
|
||||
body, _ = sjson.SetBytes(body, "extra_body.clear_thinking", false)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// MiniMax-M2.1: Use reasoning_split=true for interleaved thinking
|
||||
if strings.HasPrefix(model, "minimax-m2.1") || strings.HasPrefix(model, "minimax-m2-1") {
|
||||
if enableThinking {
|
||||
body, _ = sjson.SetBytes(body, "reasoning_split", true)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// Other iFlow models (including GLM-4.6): Use chat_template_kwargs.enable_thinking
|
||||
if effort.Exists() {
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -54,17 +54,15 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||
if modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = normalizeThinkingConfig(translated, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
|
||||
@@ -148,17 +146,15 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
||||
if modelOverride != "" {
|
||||
translated = e.overrideModel(translated, modelOverride)
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort")
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = normalizeThinkingConfig(translated, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
|
||||
@@ -323,6 +319,27 @@ func (e *OpenAICompatExecutor) resolveUpstreamModel(alias string, auth *cliproxy
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) allowCompatReasoningEffort(model string, auth *cliproxyauth.Auth) bool {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" || e == nil || e.cfg == nil {
|
||||
return false
|
||||
}
|
||||
compat := e.resolveCompatConfig(auth)
|
||||
if compat == nil || len(compat.Models) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := range compat.Models {
|
||||
entry := compat.Models[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Alias), trimmed) {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(entry.Name), trimmed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// applyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
|
||||
func applyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
|
||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
||||
return payload
|
||||
@@ -45,22 +45,38 @@ func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model str
|
||||
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
|
||||
// applyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
||||
// ApplyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
||||
// Metadata values take precedence over any existing field when the model supports thinking, intentionally
|
||||
// overwriting caller-provided values to honor suffix/default metadata priority.
|
||||
func applyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string) []byte {
|
||||
func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string, allowCompat bool) []byte {
|
||||
if len(metadata) == 0 {
|
||||
return payload
|
||||
}
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return payload
|
||||
}
|
||||
if field == "" {
|
||||
return payload
|
||||
}
|
||||
baseModel := util.ResolveOriginalModel(model, metadata)
|
||||
if baseModel == "" {
|
||||
baseModel = model
|
||||
}
|
||||
if !util.ModelSupportsThinking(baseModel) && !allowCompat {
|
||||
return payload
|
||||
}
|
||||
if effort, ok := util.ReasoningEffortFromMetadata(metadata); ok && effort != "" {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: numeric thinking_budget suffix for level-based (OpenAI-style) models.
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return payload
|
||||
@@ -216,34 +232,43 @@ func matchModelPattern(pattern, model string) bool {
|
||||
return pi == len(pattern)
|
||||
}
|
||||
|
||||
// normalizeThinkingConfig normalizes thinking-related fields in the payload
|
||||
// NormalizeThinkingConfig normalizes thinking-related fields in the payload
|
||||
// based on model capabilities. For models without thinking support, it strips
|
||||
// reasoning fields. For models with level-based thinking, it validates and
|
||||
// normalizes the reasoning effort level.
|
||||
func normalizeThinkingConfig(payload []byte, model string) []byte {
|
||||
// normalizes the reasoning effort level. For models with numeric budget thinking,
|
||||
// it strips the effort string fields.
|
||||
func NormalizeThinkingConfig(payload []byte, model string, allowCompat bool) []byte {
|
||||
if len(payload) == 0 || model == "" {
|
||||
return payload
|
||||
}
|
||||
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return stripThinkingFields(payload)
|
||||
if allowCompat {
|
||||
return payload
|
||||
}
|
||||
return StripThinkingFields(payload, false)
|
||||
}
|
||||
|
||||
if util.ModelUsesThinkingLevels(model) {
|
||||
return normalizeReasoningEffortLevel(payload, model)
|
||||
return NormalizeReasoningEffortLevel(payload, model)
|
||||
}
|
||||
|
||||
return payload
|
||||
// Model supports thinking but uses numeric budgets, not levels.
|
||||
// Strip effort string fields since they are not applicable.
|
||||
return StripThinkingFields(payload, true)
|
||||
}
|
||||
|
||||
// stripThinkingFields removes thinking-related fields from the payload for
|
||||
// models that do not support thinking.
|
||||
func stripThinkingFields(payload []byte) []byte {
|
||||
// StripThinkingFields removes thinking-related fields from the payload for
|
||||
// models that do not support thinking. If effortOnly is true, only removes
|
||||
// effort string fields (for models using numeric budgets).
|
||||
func StripThinkingFields(payload []byte, effortOnly bool) []byte {
|
||||
fieldsToRemove := []string{
|
||||
"reasoning",
|
||||
"reasoning_effort",
|
||||
"reasoning.effort",
|
||||
}
|
||||
if !effortOnly {
|
||||
fieldsToRemove = append([]string{"reasoning", "thinking"}, fieldsToRemove...)
|
||||
}
|
||||
out := payload
|
||||
for _, field := range fieldsToRemove {
|
||||
if gjson.GetBytes(out, field).Exists() {
|
||||
@@ -253,9 +278,9 @@ func stripThinkingFields(payload []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
||||
// NormalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
||||
// or reasoning.effort field for level-based thinking models.
|
||||
func normalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
func NormalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
out := payload
|
||||
|
||||
if effort := gjson.GetBytes(out, "reasoning_effort"); effort.Exists() {
|
||||
@@ -273,10 +298,10 @@ func normalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// validateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
||||
// ValidateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
||||
// Returns a statusErr with 400 when an unsupported level is supplied to avoid silently
|
||||
// downgrading requests.
|
||||
func validateThinkingConfig(payload []byte, model string) error {
|
||||
func ValidateThinkingConfig(payload []byte, model string) error {
|
||||
if len(payload) == 0 || model == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -51,13 +50,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -131,13 +127,10 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort")
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
|
||||
@@ -19,7 +19,7 @@ type usageReporter struct {
|
||||
provider string
|
||||
model string
|
||||
authID string
|
||||
authIndex uint64
|
||||
authIndex string
|
||||
apiKey string
|
||||
source string
|
||||
requestedAt time.Time
|
||||
@@ -275,6 +275,20 @@ func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return detail, true
|
||||
}
|
||||
|
||||
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
CachedTokens: node.Get("cachedContentTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
func parseGeminiCLIUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data)
|
||||
node := usageNode.Get("response.usageMetadata")
|
||||
@@ -284,16 +298,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseGeminiUsage(data []byte) usage.Detail {
|
||||
@@ -305,16 +310,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
@@ -329,16 +325,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail, true
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
@@ -353,16 +340,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail, true
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
func parseAntigravityUsage(data []byte) usage.Detail {
|
||||
@@ -377,16 +355,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
@@ -404,16 +373,7 @@ func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail, true
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
var stopChunkWithoutUsage sync.Map
|
||||
@@ -522,12 +482,16 @@ func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
|
||||
cleaned := jsonBytes
|
||||
var changed bool
|
||||
|
||||
if gjson.GetBytes(cleaned, "usageMetadata").Exists() {
|
||||
if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
|
||||
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
|
||||
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
|
||||
changed = true
|
||||
}
|
||||
|
||||
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() {
|
||||
if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
|
||||
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
|
||||
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
|
||||
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
|
||||
changed = true
|
||||
}
|
||||
|
||||
@@ -7,17 +7,40 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// deriveSessionID generates a stable session ID from the request.
|
||||
// Uses the hash of the first user message to identify the conversation.
|
||||
func deriveSessionID(rawJSON []byte) string {
|
||||
messages := gjson.GetBytes(rawJSON, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "user" {
|
||||
content := msg.Get("content").String()
|
||||
if content == "" {
|
||||
// Try to get text from content array
|
||||
content = msg.Get("content.0.text").String()
|
||||
}
|
||||
if content != "" {
|
||||
h := sha256.Sum256([]byte(content))
|
||||
return hex.EncodeToString(h[:16])
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
@@ -39,149 +62,329 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Derive session ID for signature caching
|
||||
sessionID := deriveSessionID(rawJSON)
|
||||
|
||||
// system instruction
|
||||
var systemInstruction *client.Content
|
||||
systemInstructionJSON := ""
|
||||
hasSystemInstruction := false
|
||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
systemPart := client.Part{Text: systemPrompt}
|
||||
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
|
||||
partJSON := `{}`
|
||||
if systemPrompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
|
||||
}
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
if len(systemInstruction.Parts) == 0 {
|
||||
systemInstruction = nil
|
||||
}
|
||||
} else if systemResult.Type == gjson.String {
|
||||
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
|
||||
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
|
||||
// contents
|
||||
contents := make([]client.Content, 0)
|
||||
contentsJSON := "[]"
|
||||
hasContents := false
|
||||
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
for i := 0; i < len(messageResults); i++ {
|
||||
numMessages := len(messageResults)
|
||||
for i := 0; i < numMessages; i++ {
|
||||
messageResult := messageResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
role := roleResult.String()
|
||||
originalRole := roleResult.String()
|
||||
role := originalRole
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContent := client.Content{Role: role, Parts: []client.Part{}}
|
||||
clientContentJSON := `{"role":"","parts":[]}`
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
for j := 0; j < len(contentResults); j++ {
|
||||
numContents := len(contentResults)
|
||||
var currentMessageThinkingSignature string
|
||||
for j := 0; j < numContents; j++ {
|
||||
contentResult := contentResults[j]
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
||||
prompt := contentResult.Get("thinking").String()
|
||||
// Use GetThinkingText to handle wrapped thinking objects
|
||||
thinkingText := util.GetThinkingText(contentResult)
|
||||
signatureResult := contentResult.Get("signature")
|
||||
signature := geminiCLIClaudeThoughtSignature
|
||||
if signatureResult.Exists() {
|
||||
signature = signatureResult.String()
|
||||
clientSignature := ""
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
clientSignature = signatureResult.String()
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt, Thought: true, ThoughtSignature: signature})
|
||||
|
||||
// Always try cached signature first (more reliable than client-provided)
|
||||
// Client may send stale or invalid signatures from different sessions
|
||||
signature := ""
|
||||
if sessionID != "" && thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
|
||||
signature = cachedSig
|
||||
log.Debugf("Using cached signature for thinking block")
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to client signature only if cache miss and client signature is valid
|
||||
if signature == "" && cache.HasValidSignature(clientSignature) {
|
||||
signature = clientSignature
|
||||
log.Debugf("Using client-provided signature for thinking block")
|
||||
}
|
||||
|
||||
// Store for subsequent tool_use in the same message
|
||||
if cache.HasValidSignature(signature) {
|
||||
currentMessageThinkingSignature = signature
|
||||
}
|
||||
|
||||
// Skip trailing unsigned thinking blocks on last assistant message
|
||||
isUnsigned := !cache.HasValidSignature(signature)
|
||||
|
||||
// If unsigned, skip entirely (don't convert to text)
|
||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||
// Converting to text would break this requirement
|
||||
if isUnsigned {
|
||||
// TypeScript plugin approach: drop unsigned thinking blocks entirely
|
||||
log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||
continue
|
||||
}
|
||||
|
||||
// Valid signature, send as thought block
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "thought", true)
|
||||
if thinkingText != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
|
||||
}
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||
partJSON := `{}`
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||
// Antigravity API validates signatures, so dummy values are rejected.
|
||||
// The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies.
|
||||
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := contentResult.Get("input")
|
||||
functionID := contentResult.Get("id").String()
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
if strings.Contains(modelName, "claude") {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
})
|
||||
} else {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
ThoughtSignature: geminiCLIClaudeThoughtSignature,
|
||||
})
|
||||
|
||||
// Handle both object and string input formats
|
||||
var argsRaw string
|
||||
if argsResult.IsObject() {
|
||||
argsRaw = argsResult.Raw
|
||||
} else if argsResult.Type == gjson.String {
|
||||
// Input is a JSON string, parse and validate it
|
||||
parsed := gjson.Parse(argsResult.String())
|
||||
if parsed.IsObject() {
|
||||
argsRaw = parsed.Raw
|
||||
}
|
||||
}
|
||||
|
||||
if argsRaw != "" {
|
||||
partJSON := `{}`
|
||||
|
||||
// Use skip_thought_signature_validator for tool calls without valid thinking signature
|
||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||
// and also works for Claude through Antigravity API
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
if cache.HasValidSignature(currentMessageThinkingSignature) {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||
} else {
|
||||
// No valid signature - use skip sentinel to bypass validation
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
|
||||
}
|
||||
|
||||
if functionID != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
|
||||
}
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
functionResponseJSON := `{}`
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
responseData = functionResponseResult.String()
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
if len(frResults) == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
||||
sourceResult := contentResult.Get("source")
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineData := &client.InlineData{
|
||||
MimeType: sourceResult.Get("media_type").String(),
|
||||
Data: sourceResult.Get("data").String(),
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{InlineData: inlineData})
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, clientContent)
|
||||
|
||||
// Reorder parts for 'model' role to ensure thinking block is first
|
||||
if role == "model" {
|
||||
partsResult := gjson.Get(clientContentJSON, "parts")
|
||||
if partsResult.IsArray() {
|
||||
parts := partsResult.Array()
|
||||
var thinkingParts []gjson.Result
|
||||
var otherParts []gjson.Result
|
||||
for _, part := range parts {
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingParts = append(thinkingParts, part)
|
||||
} else {
|
||||
otherParts = append(otherParts, part)
|
||||
}
|
||||
}
|
||||
if len(thinkingParts) > 0 {
|
||||
firstPartIsThinking := parts[0].Get("thought").Bool()
|
||||
if !firstPartIsThinking || len(thinkingParts) > 1 {
|
||||
var newParts []interface{}
|
||||
for _, p := range thinkingParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
for _, p := range otherParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||
partJSON := `{}`
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tools
|
||||
var tools []client.ToolDeclaration
|
||||
toolsJSON := ""
|
||||
toolDeclCount := 0
|
||||
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsJSON = `[{"functionDeclarations":[]}]`
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
inputSchemaResult := toolResult.Get("input_schema")
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
inputSchema := inputSchemaResult.Raw
|
||||
// Sanitize the input schema for Antigravity API compatibility
|
||||
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
var toolDeclaration any
|
||||
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
for toolKey := range gjson.Parse(tool).Map() {
|
||||
if util.InArray(allowedToolKeys, toolKey) {
|
||||
continue
|
||||
}
|
||||
tool, _ = sjson.Delete(tool, toolKey)
|
||||
}
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolDeclCount++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
if systemInstruction != nil {
|
||||
b, _ := json.Marshal(systemInstruction)
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
|
||||
|
||||
// Inject interleaved thinking hint when both tools and thinking are active
|
||||
hasTools := toolDeclCount > 0
|
||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
|
||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||
|
||||
if hasTools && hasThinking && isClaudeThinking {
|
||||
interleavedHint := "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them."
|
||||
|
||||
if hasSystemInstruction {
|
||||
// Append hint as a new part to existing system instruction
|
||||
hintPart := `{"text":""}`
|
||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
||||
} else {
|
||||
// Create new system instruction with hint
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
hintPart := `{"text":""}`
|
||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
b, _ := json.Marshal(contents)
|
||||
out, _ = sjson.SetRaw(out, "request.contents", string(b))
|
||||
|
||||
if hasSystemInstruction {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
|
||||
}
|
||||
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
|
||||
b, _ := json.Marshal(tools)
|
||||
out, _ = sjson.SetRaw(out, "request.tools", string(b))
|
||||
if hasContents {
|
||||
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
|
||||
}
|
||||
if toolDeclCount > 0 {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
|
||||
@@ -0,0 +1,658 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_BasicStructure(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"system": [
|
||||
{"type": "text", "text": "You are helpful"}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check model
|
||||
if gjson.Get(outputStr, "model").String() != "claude-sonnet-4-5" {
|
||||
t.Errorf("Expected model 'claude-sonnet-4-5', got '%s'", gjson.Get(outputStr, "model").String())
|
||||
}
|
||||
|
||||
// Check contents exist
|
||||
contents := gjson.Get(outputStr, "request.contents")
|
||||
if !contents.Exists() || !contents.IsArray() {
|
||||
t.Error("request.contents should exist and be an array")
|
||||
}
|
||||
|
||||
// Check role mapping (assistant -> model)
|
||||
firstContent := gjson.Get(outputStr, "request.contents.0")
|
||||
if firstContent.Get("role").String() != "user" {
|
||||
t.Errorf("Expected role 'user', got '%s'", firstContent.Get("role").String())
|
||||
}
|
||||
|
||||
// Check systemInstruction
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if !sysInstruction.Exists() {
|
||||
t.Error("systemInstruction should exist")
|
||||
}
|
||||
if sysInstruction.Get("parts.0.text").String() != "You are helpful" {
|
||||
t.Error("systemInstruction text mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hi"}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// assistant should be mapped to model
|
||||
secondContent := gjson.Get(outputStr, "request.contents.1")
|
||||
if secondContent.Get("role").String() != "model" {
|
||||
t.Errorf("Expected role 'model' (mapped from 'assistant'), got '%s'", secondContent.Get("role").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||
// Valid signature must be at least 50 characters
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
||||
{"type": "text", "text": "Answer"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check thinking block conversion
|
||||
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
|
||||
if !firstPart.Get("thought").Bool() {
|
||||
t.Error("thinking block should have thought: true")
|
||||
}
|
||||
if firstPart.Get("text").String() != "Let me think..." {
|
||||
t.Error("thinking text mismatch")
|
||||
}
|
||||
if firstPart.Get("thoughtSignature").String() != validSignature {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, firstPart.Get("thoughtSignature").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
// Unsigned thinking blocks should be removed entirely (not converted to text)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me think..."},
|
||||
{"type": "text", "text": "Answer"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Without signature, thinking block should be removed (not converted to text)
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Only text part should remain
|
||||
if parts[0].Get("thought").Bool() {
|
||||
t.Error("Thinking block should be removed, not preserved")
|
||||
}
|
||||
if parts[0].Get("text").String() != "Answer" {
|
||||
t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [],
|
||||
"tools": [
|
||||
{
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("gemini-1.5-pro", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check tools structure
|
||||
tools := gjson.Get(outputStr, "request.tools")
|
||||
if !tools.Exists() {
|
||||
t.Error("Tools should exist in output")
|
||||
}
|
||||
|
||||
funcDecl := gjson.Get(outputStr, "request.tools.0.functionDeclarations.0")
|
||||
if funcDecl.Get("name").String() != "test_tool" {
|
||||
t.Errorf("Expected tool name 'test_tool', got '%s'", funcDecl.Get("name").String())
|
||||
}
|
||||
|
||||
// Check input_schema renamed to parametersJsonSchema
|
||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
||||
t.Log("parametersJsonSchema exists (expected)")
|
||||
}
|
||||
if funcDecl.Get("input_schema").Exists() {
|
||||
t.Error("input_schema should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": "{\"location\": \"Paris\"}"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Now we expect only 1 part (tool_use), no dummy thinking block injected
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (tool only, no dummy injection), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Check function call conversion at parts[0]
|
||||
funcCall := parts[0].Get("functionCall")
|
||||
if !funcCall.Exists() {
|
||||
t.Error("functionCall should exist at parts[0]")
|
||||
}
|
||||
if funcCall.Get("name").String() != "get_weather" {
|
||||
t.Errorf("Expected function name 'get_weather', got '%s'", funcCall.Get("name").String())
|
||||
}
|
||||
if funcCall.Get("id").String() != "call_123" {
|
||||
t.Errorf("Expected function id 'call_123', got '%s'", funcCall.Get("id").String())
|
||||
}
|
||||
// Verify skip_thought_signature_validator is added (bypass for tools without valid thinking)
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
actualSig := parts[0].Get("thoughtSignature").String()
|
||||
if actualSig != expectedSig {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, actualSig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": "{\"location\": \"Paris\"}"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function call has the signature from the preceding thinking block
|
||||
part := gjson.Get(outputStr, "request.contents.0.parts.1")
|
||||
if part.Get("functionCall.name").String() != "get_weather" {
|
||||
t.Errorf("Expected functionCall, got %s", part.Raw)
|
||||
}
|
||||
if part.Get("thoughtSignature").String() != validSignature {
|
||||
t.Errorf("Expected thoughtSignature '%s' on tool_use, got '%s'", validSignature, part.Get("thoughtSignature").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||
// Case: text block followed by thinking block -> should be reordered to thinking first
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here is the plan."},
|
||||
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify order: Thinking block MUST be first
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
if !parts[0].Get("thought").Bool() {
|
||||
t.Error("First part should be thinking block after reordering")
|
||||
}
|
||||
if parts[1].Get("text").String() != "Here is the plan." {
|
||||
t.Error("Second part should be text block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "get_weather-call-123",
|
||||
"content": "22C sunny"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function response conversion
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Error("functionResponse should exist")
|
||||
}
|
||||
if funcResp.Get("id").String() != "get_weather-call-123" {
|
||||
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
||||
// Note: This test requires the model to be registered in the registry
|
||||
// with Thinking metadata. If the registry is not populated in test environment,
|
||||
// thinkingConfig won't be added. We'll test the basic structure only.
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [],
|
||||
"thinking": {
|
||||
"type": "enabled",
|
||||
"budget_tokens": 8000
|
||||
}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check thinking config conversion (only if model supports thinking in registry)
|
||||
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
||||
if thinkingConfig.Exists() {
|
||||
if thinkingConfig.Get("thinkingBudget").Int() != 8000 {
|
||||
t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int())
|
||||
}
|
||||
if !thinkingConfig.Get("include_thoughts").Bool() {
|
||||
t.Error("include_thoughts should be true")
|
||||
}
|
||||
} else {
|
||||
t.Log("thinkingConfig not present - model may not be registered in test registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check inline data conversion
|
||||
inlineData := gjson.Get(outputStr, "request.contents.0.parts.0.inlineData")
|
||||
if !inlineData.Exists() {
|
||||
t.Error("inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mime_type").String() != "image/png" {
|
||||
t.Error("mime_type mismatch")
|
||||
}
|
||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||
t.Error("data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_GenerationConfig(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [],
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"max_tokens": 2000
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
genConfig := gjson.Get(outputStr, "request.generationConfig")
|
||||
if genConfig.Get("temperature").Float() != 0.7 {
|
||||
t.Errorf("Expected temperature 0.7, got %f", genConfig.Get("temperature").Float())
|
||||
}
|
||||
if genConfig.Get("topP").Float() != 0.9 {
|
||||
t.Errorf("Expected topP 0.9, got %f", genConfig.Get("topP").Float())
|
||||
}
|
||||
if genConfig.Get("topK").Float() != 40 {
|
||||
t.Errorf("Expected topK 40, got %f", genConfig.Get("topK").Float())
|
||||
}
|
||||
if genConfig.Get("maxOutputTokens").Float() != 2000 {
|
||||
t.Errorf("Expected maxOutputTokens 2000, got %f", genConfig.Get("maxOutputTokens").Float())
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Trailing Unsigned Thinking Block Removal
|
||||
// ============================================================================
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *testing.T) {
|
||||
// Last assistant message ends with unsigned thinking block - should be removed
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here is my answer"},
|
||||
{"type": "thinking", "thinking": "I should think more..."}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// The last part of the last assistant message should NOT be a thinking block
|
||||
lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts")
|
||||
if !lastMessageParts.IsArray() {
|
||||
t.Fatal("Last message should have parts array")
|
||||
}
|
||||
parts := lastMessageParts.Array()
|
||||
if len(parts) == 0 {
|
||||
t.Fatal("Last message should have at least one part")
|
||||
}
|
||||
|
||||
// The unsigned thinking should be removed, leaving only the text
|
||||
lastPart := parts[len(parts)-1]
|
||||
if lastPart.Get("thought").Bool() {
|
||||
t.Error("Trailing unsigned thinking block should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
||||
// Last assistant message ends with signed thinking block - should be kept
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here is my answer"},
|
||||
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// The signed thinking block should be preserved
|
||||
lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts")
|
||||
parts := lastMessageParts.Array()
|
||||
if len(parts) < 2 {
|
||||
t.Error("Signed thinking block should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_Removed(t *testing.T) {
|
||||
// Middle message has unsigned thinking - should be removed entirely
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Middle thinking..."},
|
||||
{"type": "text", "text": "Answer"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Follow up"}]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Unsigned thinking should be removed entirely
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Only text part should remain
|
||||
if parts[0].Get("thought").Bool() {
|
||||
t.Error("Thinking block should be removed, not preserved")
|
||||
}
|
||||
if parts[0].Get("text").String() != "Answer" {
|
||||
t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tool + Thinking System Hint Injection
|
||||
// ============================================================================
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_HintInjected(t *testing.T) {
|
||||
// When both tools and thinking are enabled, hint should be injected into system instruction
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"system": [{"type": "text", "text": "You are helpful."}],
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
|
||||
}
|
||||
],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 8000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// System instruction should contain the interleaved thinking hint
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if !sysInstruction.Exists() {
|
||||
t.Fatal("systemInstruction should exist")
|
||||
}
|
||||
|
||||
// Check if hint is appended
|
||||
sysText := sysInstruction.Get("parts").Array()
|
||||
found := false
|
||||
for _, part := range sysText {
|
||||
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Interleaved thinking hint should be injected when tools and thinking are both active, got: %v", sysInstruction.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolsOnly_NoHint(t *testing.T) {
|
||||
// When only tools are present (no thinking), hint should NOT be injected
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"system": [{"type": "text", "text": "You are helpful."}],
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// System instruction should NOT contain the hint
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if sysInstruction.Exists() {
|
||||
for _, part := range sysInstruction.Get("parts").Array() {
|
||||
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
|
||||
t.Error("Hint should NOT be injected when only tools are present (no thinking)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
|
||||
// When only thinking is enabled (no tools), hint should NOT be injected
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"system": [{"type": "text", "text": "You are helpful."}],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 8000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// System instruction should NOT contain the hint (no tools)
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if sysInstruction.Exists() {
|
||||
for _, part := range sysInstruction.Get("parts").Array() {
|
||||
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
|
||||
t.Error("Hint should NOT be injected when only thinking is present (no tools)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||
// When tools + thinking but no system instruction, should create one with hint
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
|
||||
}
|
||||
],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 8000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// System instruction should be created with hint
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if !sysInstruction.Exists() {
|
||||
t.Fatal("systemInstruction should be created when tools + thinking are active")
|
||||
}
|
||||
|
||||
sysText := sysInstruction.Get("parts").Array()
|
||||
found := false
|
||||
for _, part := range sysText {
|
||||
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
|
||||
}
|
||||
}
|
||||
@@ -9,12 +9,14 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -33,9 +35,14 @@ type Params struct {
|
||||
CandidatesTokenCount int64 // Cached candidate token count from usage metadata
|
||||
ThoughtsTokenCount int64 // Cached thinking token count from usage metadata
|
||||
TotalTokenCount int64 // Cached total token count from usage metadata
|
||||
CachedTokenCount int64 // Cached content token count (indicates prompt caching)
|
||||
HasSentFinalEvents bool // Indicates if final content/message events have been sent
|
||||
HasToolUse bool // Indicates if tool use was observed in the stream
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
|
||||
// Signature caching support
|
||||
SessionID string // Session ID derived from request for signature caching
|
||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -63,6 +70,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
SessionID: deriveSessionID(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +99,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// This follows the Claude Code API specification for streaming message initialization
|
||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||
|
||||
// Use cpaUsageMetadata within the message_start event for Claude.
|
||||
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
|
||||
}
|
||||
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
|
||||
}
|
||||
|
||||
// Override default values with actual response metadata if available from the Gemini CLI response
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
@@ -120,11 +136,20 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Process thinking content (internal reasoning)
|
||||
if partResult.Get("thought").Bool() {
|
||||
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||
log.Debug("Branch: signature_delta")
|
||||
|
||||
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
|
||||
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||
log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
||||
params.CurrentThinkingText.Reset()
|
||||
}
|
||||
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.HasContent = true
|
||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||
params.CurrentThinkingText.WriteString(partTextResult.String())
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
@@ -153,6 +178,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.ResponseType = 2 // Set state to thinking
|
||||
params.HasContent = true
|
||||
// Start accumulating thinking text for signature caching
|
||||
params.CurrentThinkingText.Reset()
|
||||
params.CurrentThinkingText.WriteString(partTextResult.String())
|
||||
}
|
||||
} else {
|
||||
finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason")
|
||||
@@ -251,7 +279,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
params.HasUsageMetadata = true
|
||||
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int()
|
||||
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
|
||||
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
|
||||
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
|
||||
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
|
||||
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
|
||||
@@ -303,6 +332,14 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
||||
*output = *output + "event: message_delta\n"
|
||||
*output = *output + "data: "
|
||||
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
|
||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||
if params.CachedTokenCount > 0 {
|
||||
var err error
|
||||
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
*output = *output + delta + "\n\n\n"
|
||||
|
||||
params.HasSentFinalEvents = true
|
||||
@@ -342,6 +379,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int()
|
||||
thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int()
|
||||
totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int()
|
||||
cachedTokens := root.Get("response.usageMetadata.cachedContentTokenCount").Int()
|
||||
outputTokens := candidateTokens + thoughtTokens
|
||||
if outputTokens == 0 && totalTokens > 0 {
|
||||
outputTokens = totalTokens - promptTokens
|
||||
@@ -350,24 +388,33 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": root.Get("response.responseId").String(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": root.Get("response.modelVersion").String(),
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": promptTokens,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||
if cachedTokens > 0 {
|
||||
var err error
|
||||
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
contentArrayInitialized := false
|
||||
ensureContentArray := func() {
|
||||
if contentArrayInitialized {
|
||||
return
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
|
||||
contentArrayInitialized = true
|
||||
}
|
||||
|
||||
parts := root.Get("response.candidates.0.content.parts")
|
||||
var contentBlocks []interface{}
|
||||
textBuilder := strings.Builder{}
|
||||
thinkingBuilder := strings.Builder{}
|
||||
thinkingSignature := ""
|
||||
toolIDCounter := 0
|
||||
hasToolCall := false
|
||||
|
||||
@@ -375,28 +422,43 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
if textBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": textBuilder.String(),
|
||||
})
|
||||
ensureContentArray()
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", textBuilder.String())
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
textBuilder.Reset()
|
||||
}
|
||||
|
||||
flushThinking := func() {
|
||||
if thinkingBuilder.Len() == 0 {
|
||||
if thinkingBuilder.Len() == 0 && thinkingSignature == "" {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkingBuilder.String(),
|
||||
})
|
||||
ensureContentArray()
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
if thinkingSignature != "" {
|
||||
block, _ = sjson.Set(block, "signature", thinkingSignature)
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
thinkingSignature = ""
|
||||
}
|
||||
|
||||
if parts.IsArray() {
|
||||
for _, part := range parts.Array() {
|
||||
isThought := part.Get("thought").Bool()
|
||||
if isThought {
|
||||
sig := part.Get("thoughtSignature")
|
||||
if !sig.Exists() {
|
||||
sig = part.Get("thought_signature")
|
||||
}
|
||||
if sig.Exists() && sig.String() != "" {
|
||||
thinkingSignature = sig.String()
|
||||
}
|
||||
}
|
||||
|
||||
if text := part.Get("text"); text.Exists() && text.String() != "" {
|
||||
if part.Get("thought").Bool() {
|
||||
if isThought {
|
||||
flushText()
|
||||
thinkingBuilder.WriteString(text.String())
|
||||
continue
|
||||
@@ -413,21 +475,16 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
toolIDCounter++
|
||||
toolBlock := map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("tool_%d", toolIDCounter),
|
||||
"name": name,
|
||||
"input": map[string]interface{}{},
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
|
||||
}
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() {
|
||||
var parsed interface{}
|
||||
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
|
||||
toolBlock["input"] = parsed
|
||||
}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, toolBlock)
|
||||
ensureContentArray()
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -436,8 +493,6 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
flushThinking()
|
||||
flushText()
|
||||
|
||||
response["content"] = contentBlocks
|
||||
|
||||
stopReason := "end_turn"
|
||||
if hasToolCall {
|
||||
stopReason = "tool_use"
|
||||
@@ -453,19 +508,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
}
|
||||
response["stop_reason"] = stopReason
|
||||
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
|
||||
|
||||
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) {
|
||||
if promptTokens == 0 && outputTokens == 0 {
|
||||
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
||||
delete(response, "usage")
|
||||
responseJSON, _ = sjson.Delete(responseJSON, "usage")
|
||||
}
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(encoded)
|
||||
return responseJSON
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// Signature Caching Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Request with user message - should derive session ID
|
||||
requestJSON := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
// First response chunk with thinking
|
||||
responseJSON := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "Let me think...", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
ctx := context.Background()
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m)
|
||||
|
||||
// Verify session ID was set
|
||||
params := param.(*Params)
|
||||
if params.SessionID == "" {
|
||||
t.Error("SessionID should be derived from request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_ThinkingTextAccumulated(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Test"}]}]
|
||||
}`)
|
||||
|
||||
// First thinking chunk
|
||||
chunk1 := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "First part of thinking...", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
// Second thinking chunk (continuation)
|
||||
chunk2 := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": " Second part of thinking...", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
ctx := context.Background()
|
||||
|
||||
// Process first chunk - starts new thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk1, ¶m)
|
||||
params := param.(*Params)
|
||||
|
||||
if params.CurrentThinkingText.Len() == 0 {
|
||||
t.Error("Thinking text should be accumulated after first chunk")
|
||||
}
|
||||
|
||||
// Process second chunk - continues thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk2, ¶m)
|
||||
|
||||
text := params.CurrentThinkingText.String()
|
||||
if !strings.Contains(text, "First part") || !strings.Contains(text, "Second part") {
|
||||
t.Errorf("Thinking text should accumulate both parts, got: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
|
||||
}`)
|
||||
|
||||
// Thinking chunk
|
||||
thinkingChunk := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "My thinking process here", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
// Signature chunk
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
signatureChunk := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSignature + `"}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
ctx := context.Background()
|
||||
|
||||
// Process thinking chunk
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m)
|
||||
params := param.(*Params)
|
||||
sessionID := params.SessionID
|
||||
thinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
if sessionID == "" {
|
||||
t.Fatal("SessionID should be set")
|
||||
}
|
||||
if thinkingText == "" {
|
||||
t.Fatal("Thinking text should be accumulated")
|
||||
}
|
||||
|
||||
// Process signature chunk - should cache the signature
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, ¶m)
|
||||
|
||||
// Verify signature was cached
|
||||
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
|
||||
if cachedSig != validSignature {
|
||||
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
|
||||
}
|
||||
|
||||
// Verify thinking text was reset after caching
|
||||
if params.CurrentThinkingText.Len() != 0 {
|
||||
t.Error("Thinking text should be reset after signature is cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
|
||||
}`)
|
||||
|
||||
validSig1 := "signature1_12345678901234567890123456789012345678901234567"
|
||||
validSig2 := "signature2_12345678901234567890123456789012345678901234567"
|
||||
|
||||
// First thinking block with signature
|
||||
block1Thinking := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "First thinking block", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
block1Sig := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig1 + `"}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
// Text content (breaks thinking)
|
||||
textBlock := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "Regular text output"}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
// Second thinking block with signature
|
||||
block2Thinking := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "Second thinking block", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
block2Sig := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig2 + `"}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
ctx := context.Background()
|
||||
|
||||
// Process first thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, ¶m)
|
||||
params := param.(*Params)
|
||||
sessionID := params.SessionID
|
||||
firstThinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, ¶m)
|
||||
|
||||
// Verify first signature cached
|
||||
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
|
||||
t.Error("First thinking block signature should be cached")
|
||||
}
|
||||
|
||||
// Process text (transitions out of thinking)
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, textBlock, ¶m)
|
||||
|
||||
// Process second thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Thinking, ¶m)
|
||||
secondThinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, ¶m)
|
||||
|
||||
// Verify second signature cached
|
||||
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
|
||||
t.Error("Second thinking block signature should be cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "valid user message",
|
||||
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "user message with content array",
|
||||
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "no user message",
|
||||
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty messages",
|
||||
input: []byte(`{"messages": []}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "no messages field",
|
||||
input: []byte(`{}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := deriveSessionID(tt.input)
|
||||
if tt.wantEmpty && result != "" {
|
||||
t.Errorf("Expected empty session ID, got '%s'", result)
|
||||
}
|
||||
if !tt.wantEmpty && result == "" {
|
||||
t.Error("Expected non-empty session ID")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
|
||||
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
|
||||
|
||||
id1 := deriveSessionID(input)
|
||||
id2 := deriveSessionID(input)
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
|
||||
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
|
||||
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
|
||||
|
||||
id1 := deriveSessionID(input1)
|
||||
id2 := deriveSessionID(input2)
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("Different messages should produce different session IDs")
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,6 @@ package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -98,16 +97,34 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
}
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
|
||||
// Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
|
||||
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
// First pass: collect indices of thinking parts to remove
|
||||
var thinkingIndicesToRemove []int64
|
||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||
// Mark thinking blocks for removal
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
|
||||
}
|
||||
// Add skip sentinel to functionCall parts
|
||||
if part.Get("functionCall").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
existingSig := part.Get("thoughtSignature").String()
|
||||
if existingSig == "" || len(existingSig) < 50 {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Remove thinking blocks in reverse order to preserve indices
|
||||
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
|
||||
idx := thinkingIndicesToRemove[i]
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -117,11 +134,33 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
|
||||
|
||||
// FunctionCallGroup represents a group of function calls and their responses
|
||||
type FunctionCallGroup struct {
|
||||
ModelContent map[string]interface{}
|
||||
FunctionCalls []gjson.Result
|
||||
ResponsesNeeded int
|
||||
}
|
||||
|
||||
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
|
||||
// Falls back to a minimal "functionResponse" object when parsing fails.
|
||||
func parseFunctionResponseRaw(response gjson.Result) string {
|
||||
if response.IsObject() && gjson.Valid(response.Raw) {
|
||||
return response.Raw
|
||||
}
|
||||
|
||||
log.Debugf("parse function response failed, using fallback")
|
||||
funcResp := response.Get("functionResponse")
|
||||
if funcResp.Exists() {
|
||||
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
|
||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
||||
if id := funcResp.Get("id").String(); id != "" {
|
||||
fr, _ = sjson.Set(fr, "functionResponse.id", id)
|
||||
}
|
||||
return fr
|
||||
}
|
||||
|
||||
fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
|
||||
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
|
||||
return fr
|
||||
}
|
||||
|
||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
// This function transforms the CLI tool response format by intelligently grouping function calls
|
||||
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
|
||||
@@ -146,7 +185,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
}
|
||||
|
||||
// Initialize data structures for processing and grouping
|
||||
var newContents []interface{} // Final processed contents array
|
||||
contentsWrapper := `{"contents":[]}`
|
||||
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||
|
||||
@@ -178,23 +217,16 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
var responseParts []interface{}
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
partRaw := parseFunctionResponseRaw(response)
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
functionResponseContent := map[string]interface{}{
|
||||
"parts": responseParts,
|
||||
"role": "function",
|
||||
}
|
||||
newContents = append(newContents, functionResponseContent)
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
|
||||
// Remove this group as it's been satisfied
|
||||
@@ -208,50 +240,42 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
// If this is a model with function calls, create a new group
|
||||
if role == "model" {
|
||||
var functionCallsInThisModel []gjson.Result
|
||||
functionCallsCount := 0
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
functionCallsInThisModel = append(functionCallsInThisModel, part)
|
||||
functionCallsCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(functionCallsInThisModel) > 0 {
|
||||
if functionCallsCount > 0 {
|
||||
// Add the model content
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse model content")
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
ModelContent: contentMap,
|
||||
FunctionCalls: functionCallsInThisModel,
|
||||
ResponsesNeeded: len(functionCallsInThisModel),
|
||||
ResponsesNeeded: functionCallsCount,
|
||||
}
|
||||
pendingGroups = append(pendingGroups, group)
|
||||
} else {
|
||||
// Regular model content without function calls
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
}
|
||||
} else {
|
||||
// Non-model content (user, etc.)
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -263,31 +287,23 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
var responseParts []interface{}
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
partRaw := parseFunctionResponseRaw(response)
|
||||
if partRaw != "" {
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
functionResponseContent := map[string]interface{}{
|
||||
"parts": responseParts,
|
||||
"role": "function",
|
||||
}
|
||||
newContents = append(newContents, functionResponseContent)
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the original JSON with the new contents
|
||||
result := input
|
||||
newContentsJSON, _ := json.Marshal(newContents)
|
||||
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
|
||||
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T) {
|
||||
// Valid signature on functionCall should be preserved
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(fmt.Sprintf(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "test_tool", "args": {}}, "thoughtSignature": "%s"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`, validSignature))
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that valid thoughtSignature is preserved
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part, got %d", len(parts))
|
||||
}
|
||||
|
||||
sig := parts[0].Get("thoughtSignature").String()
|
||||
if sig != validSignature {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *testing.T) {
|
||||
// functionCall without signature should get skip_thought_signature_validator
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "test_tool", "args": {}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that skip_thought_signature_validator is added to functionCall
|
||||
sig := gjson.Get(outputStr, "request.contents.0.parts.0.thoughtSignature").String()
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
if sig != expectedSig {
|
||||
t.Errorf("Expected skip sentinel '%s', got '%s'", expectedSig, sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) {
|
||||
// Thinking blocks should be removed entirely for Gemini
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(fmt.Sprintf(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"thought": true, "text": "Thinking...", "thoughtSignature": "%s"},
|
||||
{"text": "Here is my response"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`, validSignature))
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that thinking block is removed
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Only text part should remain
|
||||
if parts[0].Get("thought").Bool() {
|
||||
t.Error("Thinking block should be removed for Gemini")
|
||||
}
|
||||
if parts[0].Get("text").String() != "Here is my response" {
|
||||
t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
||||
// Multiple functionCalls should all get skip_thought_signature_validator
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "tool_one", "args": {"a": "1"}}},
|
||||
{"functionCall": {"name": "tool_two", "args": {"b": "2"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
for i, part := range parts {
|
||||
sig := part.Get("thoughtSignature").String()
|
||||
if sig != expectedSig {
|
||||
t.Errorf("Part %d: Expected '%s', got '%s'", i, expectedSig, sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -40,30 +40,27 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
effort := strings.ToLower(strings.TrimSpace(re.String()))
|
||||
if util.IsGemini3Model(modelName) {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig")
|
||||
case "auto":
|
||||
includeThoughts := true
|
||||
out = util.ApplyGeminiCLIThinkingLevel(out, "", &includeThoughts)
|
||||
default:
|
||||
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
|
||||
out = util.ApplyGeminiCLIThinkingLevel(out, level, nil)
|
||||
}
|
||||
}
|
||||
} else if !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, effort)
|
||||
}
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
@@ -195,6 +192,14 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else if content.IsObject() && content.Get("type").String() == "text" {
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
|
||||
} else if content.IsArray() {
|
||||
contents := content.Array()
|
||||
if len(contents) > 0 {
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
|
||||
for j := 0; j < len(contents); j++ {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if role == "user" || (role == "system" && len(arr) == 1) {
|
||||
// Build single user content node to avoid splitting into multiple contents
|
||||
@@ -240,64 +245,89 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
if content.Type == gjson.String && content.String() != "" {
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||
for _, item := range content.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
case "image_url":
|
||||
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
||||
imageURL := item.Get("image_url.url").String()
|
||||
if len(imageURL) > 5 { // expect data:...
|
||||
pieces := strings.SplitN(imageURL[5:], ";", 2)
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
p++
|
||||
}
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
if gjson.Valid(fargs) {
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
} else {
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", []byte(fargs))
|
||||
}
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -323,7 +353,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
@@ -338,7 +368,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
@@ -379,18 +409,3 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -8,12 +8,13 @@ package chat_completions
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -86,18 +87,27 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
@@ -171,21 +181,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagePayload, err := json.Marshal(map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]string{
|
||||
"url": imageURL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload))
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,14 +114,16 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
}
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() {
|
||||
if includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", thinkingBudget.Int())
|
||||
}
|
||||
}
|
||||
// Only apply for models that support thinking and use numeric budgets, not discrete levels.
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
// Check for thinkingBudget first - if present, enable thinking with budget
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() && thinkingBudget.Int() > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
normalizedBudget := util.NormalizeThinkingBudget(modelName, int(thinkingBudget.Int()))
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", normalizedBudget)
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
// Fallback to include_thoughts if no budget specified
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -192,7 +194,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if name := fc.Get("name"); name.Exists() {
|
||||
toolUse, _ = sjson.Set(toolUse, "name", name.String())
|
||||
}
|
||||
if args := fc.Get("args"); args.Exists() {
|
||||
if args := fc.Get("args"); args.Exists() && args.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
|
||||
@@ -312,11 +314,11 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
if mode := funcCalling.Get("mode"); mode.Exists() {
|
||||
switch mode.String() {
|
||||
case "AUTO":
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"})
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
case "NONE":
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "none"})
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
|
||||
case "ANY":
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"})
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,51 +263,6 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
|
||||
}
|
||||
}
|
||||
|
||||
// convertArrayToJSON converts []interface{} to JSON array string
|
||||
func convertArrayToJSON(arr []interface{}) string {
|
||||
result := "[]"
|
||||
for _, item := range arr {
|
||||
switch itemData := item.(type) {
|
||||
case map[string]interface{}:
|
||||
itemJSON := convertMapToJSON(itemData)
|
||||
result, _ = sjson.SetRaw(result, "-1", itemJSON)
|
||||
case string:
|
||||
result, _ = sjson.Set(result, "-1", itemData)
|
||||
case bool:
|
||||
result, _ = sjson.Set(result, "-1", itemData)
|
||||
case float64, int, int64:
|
||||
result, _ = sjson.Set(result, "-1", itemData)
|
||||
default:
|
||||
result, _ = sjson.Set(result, "-1", itemData)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// convertMapToJSON converts map[string]interface{} to JSON object string
|
||||
func convertMapToJSON(m map[string]interface{}) string {
|
||||
result := "{}"
|
||||
for key, value := range m {
|
||||
switch val := value.(type) {
|
||||
case map[string]interface{}:
|
||||
nestedJSON := convertMapToJSON(val)
|
||||
result, _ = sjson.SetRaw(result, key, nestedJSON)
|
||||
case []interface{}:
|
||||
arrayJSON := convertArrayToJSON(val)
|
||||
result, _ = sjson.SetRaw(result, key, arrayJSON)
|
||||
case string:
|
||||
result, _ = sjson.Set(result, key, val)
|
||||
case bool:
|
||||
result, _ = sjson.Set(result, key, val)
|
||||
case float64, int, int64:
|
||||
result, _ = sjson.Set(result, key, val)
|
||||
default:
|
||||
result, _ = sjson.Set(result, key, val)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConvertClaudeResponseToGeminiNonStream converts a non-streaming Claude Code response to a non-streaming Gemini response.
|
||||
// This function processes the complete Claude Code response and transforms it into a single Gemini-compatible
|
||||
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
|
||||
@@ -356,8 +311,8 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
}
|
||||
|
||||
// Process each streaming event and collect parts
|
||||
var allParts []interface{}
|
||||
var finalUsage map[string]interface{}
|
||||
var allParts []string
|
||||
var finalUsageJSON string
|
||||
var responseID string
|
||||
var createdAt int64
|
||||
|
||||
@@ -407,16 +362,14 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
if text := delta.Get("text"); text.Exists() && text.String() != "" {
|
||||
partJSON := `{"text":""}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", text.String())
|
||||
part := gjson.Parse(partJSON).Value().(map[string]interface{})
|
||||
allParts = append(allParts, part)
|
||||
allParts = append(allParts, partJSON)
|
||||
}
|
||||
case "thinking_delta":
|
||||
// Process reasoning/thinking content
|
||||
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
|
||||
partJSON := `{"thought":true,"text":""}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", text.String())
|
||||
part := gjson.Parse(partJSON).Value().(map[string]interface{})
|
||||
allParts = append(allParts, part)
|
||||
allParts = append(allParts, partJSON)
|
||||
}
|
||||
case "input_json_delta":
|
||||
// accumulate args partial_json for this index
|
||||
@@ -456,9 +409,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
if argsTrim != "" {
|
||||
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
|
||||
}
|
||||
// Parse back to interface{} for allParts
|
||||
functionCall := gjson.Parse(functionCallJSON).Value().(map[string]interface{})
|
||||
allParts = append(allParts, functionCall)
|
||||
allParts = append(allParts, functionCallJSON)
|
||||
// cleanup used state for this index
|
||||
if newParam.ToolUseArgs != nil {
|
||||
delete(newParam.ToolUseArgs, idx)
|
||||
@@ -501,8 +452,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
// Set traffic type (required by Gemini API)
|
||||
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
|
||||
|
||||
// Convert to map[string]interface{} using gjson
|
||||
finalUsage = gjson.Parse(usageJSON).Value().(map[string]interface{})
|
||||
finalUsageJSON = usageJSON
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -520,12 +470,16 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Set the consolidated parts array
|
||||
if len(consolidatedParts) > 0 {
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", convertToJSONString(consolidatedParts))
|
||||
partsJSON := "[]"
|
||||
for _, partJSON := range consolidatedParts {
|
||||
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
|
||||
}
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
|
||||
}
|
||||
|
||||
// Set usage metadata
|
||||
if finalUsage != nil {
|
||||
template, _ = sjson.SetRaw(template, "usageMetadata", convertToJSONString(finalUsage))
|
||||
if finalUsageJSON != "" {
|
||||
template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
|
||||
}
|
||||
|
||||
return template
|
||||
@@ -539,12 +493,12 @@ func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
// This function processes the parts array to combine adjacent text elements and thinking elements
|
||||
// into single consolidated parts, which results in a more readable and efficient response structure.
|
||||
// Tool calls and other non-text parts are preserved as separate elements.
|
||||
func consolidateParts(parts []interface{}) []interface{} {
|
||||
func consolidateParts(parts []string) []string {
|
||||
if len(parts) == 0 {
|
||||
return parts
|
||||
}
|
||||
|
||||
var consolidated []interface{}
|
||||
var consolidated []string
|
||||
var currentTextPart strings.Builder
|
||||
var currentThoughtPart strings.Builder
|
||||
var hasText, hasThought bool
|
||||
@@ -554,8 +508,7 @@ func consolidateParts(parts []interface{}) []interface{} {
|
||||
if hasText && currentTextPart.Len() > 0 {
|
||||
textPartJSON := `{"text":""}`
|
||||
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
|
||||
textPart := gjson.Parse(textPartJSON).Value().(map[string]interface{})
|
||||
consolidated = append(consolidated, textPart)
|
||||
consolidated = append(consolidated, textPartJSON)
|
||||
currentTextPart.Reset()
|
||||
hasText = false
|
||||
}
|
||||
@@ -566,42 +519,42 @@ func consolidateParts(parts []interface{}) []interface{} {
|
||||
if hasThought && currentThoughtPart.Len() > 0 {
|
||||
thoughtPartJSON := `{"thought":true,"text":""}`
|
||||
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
|
||||
thoughtPart := gjson.Parse(thoughtPartJSON).Value().(map[string]interface{})
|
||||
consolidated = append(consolidated, thoughtPart)
|
||||
consolidated = append(consolidated, thoughtPartJSON)
|
||||
currentThoughtPart.Reset()
|
||||
hasThought = false
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range parts {
|
||||
partMap, ok := part.(map[string]interface{})
|
||||
if !ok {
|
||||
for _, partJSON := range parts {
|
||||
part := gjson.Parse(partJSON)
|
||||
if !part.Exists() || !part.IsObject() {
|
||||
// Flush any pending parts and add this non-text part
|
||||
flushText()
|
||||
flushThought()
|
||||
consolidated = append(consolidated, part)
|
||||
consolidated = append(consolidated, partJSON)
|
||||
continue
|
||||
}
|
||||
|
||||
if thought, isThought := partMap["thought"]; isThought && thought == true {
|
||||
thought := part.Get("thought")
|
||||
if thought.Exists() && thought.Type == gjson.True {
|
||||
// This is a thinking part - flush any pending text first
|
||||
flushText() // Flush any pending text first
|
||||
|
||||
if text, hasTextContent := partMap["text"].(string); hasTextContent {
|
||||
currentThoughtPart.WriteString(text)
|
||||
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
currentThoughtPart.WriteString(text.String())
|
||||
hasThought = true
|
||||
}
|
||||
} else if text, hasTextContent := partMap["text"].(string); hasTextContent {
|
||||
} else if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
// This is a regular text part - flush any pending thought first
|
||||
flushThought() // Flush any pending thought first
|
||||
|
||||
currentTextPart.WriteString(text)
|
||||
currentTextPart.WriteString(text.String())
|
||||
hasText = true
|
||||
} else {
|
||||
// This is some other type of part (like function call) - flush both text and thought
|
||||
flushText()
|
||||
flushThought()
|
||||
consolidated = append(consolidated, part)
|
||||
consolidated = append(consolidated, partJSON)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -611,20 +564,3 @@ func consolidateParts(parts []interface{}) []interface{} {
|
||||
|
||||
return consolidated
|
||||
}
|
||||
|
||||
// convertToJSONString converts interface{} to JSON string using sjson/gjson.
|
||||
// This function provides a consistent way to serialize different data types to JSON strings
|
||||
// for inclusion in the Gemini API response structure.
|
||||
func convertToJSONString(v interface{}) string {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
return convertArrayToJSON(val)
|
||||
case map[string]interface{}:
|
||||
return convertMapToJSON(val)
|
||||
default:
|
||||
// For simple types, create a temporary JSON and extract the value
|
||||
temp := `{"temp":null}`
|
||||
temp, _ = sjson.Set(temp, "temp", val)
|
||||
return gjson.Get(temp, "temp").Raw
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,12 +10,12 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -65,18 +65,23 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
if v := root.Get("reasoning_effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
switch v.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 1024)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 8192)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 24576)
|
||||
if v := root.Get("reasoning_effort"); v.Exists() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := util.ThinkingEffortToBudget(modelName, effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,9 +136,6 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
|
||||
// Process messages and transform them to Claude Code format
|
||||
var anthropicMessages []interface{}
|
||||
var toolCallIDs []string // Track tool call IDs for matching with tool results
|
||||
|
||||
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
role := message.Get("role").String()
|
||||
@@ -146,33 +148,23 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
role = "user"
|
||||
}
|
||||
|
||||
msg := map[string]interface{}{
|
||||
"role": role,
|
||||
"content": []interface{}{},
|
||||
}
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
|
||||
// Handle content based on its type (string or array)
|
||||
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
|
||||
// Simple text content conversion
|
||||
msg["content"] = []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": contentResult.String(),
|
||||
},
|
||||
}
|
||||
part := `{"type":"text","text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentResult.String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
} else if contentResult.Exists() && contentResult.IsArray() {
|
||||
// Array of content parts processing
|
||||
var contentParts []interface{}
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "text":
|
||||
// Text part conversion
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": part.Get("text").String(),
|
||||
})
|
||||
textPart := `{"type":"text","text":""}`
|
||||
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
|
||||
|
||||
case "image_url":
|
||||
// Convert OpenAI image format to Claude Code format
|
||||
@@ -185,132 +177,95 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
|
||||
data := parts[1]
|
||||
|
||||
contentParts = append(contentParts, map[string]interface{}{
|
||||
"type": "image",
|
||||
"source": map[string]interface{}{
|
||||
"type": "base64",
|
||||
"media_type": mediaType,
|
||||
"data": data,
|
||||
},
|
||||
})
|
||||
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
|
||||
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
|
||||
imagePart, _ = sjson.Set(imagePart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if len(contentParts) > 0 {
|
||||
msg["content"] = contentParts
|
||||
}
|
||||
} else {
|
||||
// Initialize empty content array for tool calls
|
||||
msg["content"] = []interface{}{}
|
||||
}
|
||||
|
||||
// Handle tool calls (for assistant messages)
|
||||
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() && role == "assistant" {
|
||||
var contentParts []interface{}
|
||||
|
||||
// Add existing text content if any
|
||||
if existingContent, ok := msg["content"].([]interface{}); ok {
|
||||
contentParts = existingContent
|
||||
}
|
||||
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
if toolCall.Get("type").String() == "function" {
|
||||
toolCallID := toolCall.Get("id").String()
|
||||
if toolCallID == "" {
|
||||
toolCallID = genToolCallID()
|
||||
}
|
||||
toolCallIDs = append(toolCallIDs, toolCallID)
|
||||
|
||||
function := toolCall.Get("function")
|
||||
toolUse := map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolCallID,
|
||||
"name": function.Get("name").String(),
|
||||
}
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
|
||||
|
||||
// Parse arguments for the tool call
|
||||
if args := function.Get("arguments"); args.Exists() {
|
||||
argsStr := args.String()
|
||||
if argsStr != "" {
|
||||
var argsMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
|
||||
toolUse["input"] = argsMap
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
if argsJSON.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
|
||||
} else {
|
||||
toolUse["input"] = map[string]interface{}{}
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
}
|
||||
} else {
|
||||
toolUse["input"] = map[string]interface{}{}
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
}
|
||||
} else {
|
||||
toolUse["input"] = map[string]interface{}{}
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
|
||||
}
|
||||
|
||||
contentParts = append(contentParts, toolUse)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
|
||||
}
|
||||
return true
|
||||
})
|
||||
msg["content"] = contentParts
|
||||
}
|
||||
|
||||
anthropicMessages = append(anthropicMessages, msg)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
|
||||
case "tool":
|
||||
// Handle tool result messages conversion
|
||||
toolCallID := message.Get("tool_call_id").String()
|
||||
content := message.Get("content").String()
|
||||
|
||||
// Create tool result message in Claude Code format
|
||||
msg := map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolCallID,
|
||||
"content": content,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
anthropicMessages = append(anthropicMessages, msg)
|
||||
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
|
||||
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
|
||||
msg, _ = sjson.Set(msg, "content.0.content", content)
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", msg)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Set messages in the output template
|
||||
if len(anthropicMessages) > 0 {
|
||||
messagesJSON, _ := json.Marshal(anthropicMessages)
|
||||
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
|
||||
}
|
||||
|
||||
// Tools mapping: OpenAI tools -> Claude Code tools
|
||||
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
|
||||
var anthropicTools []interface{}
|
||||
hasAnthropicTools := false
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
if tool.Get("type").String() == "function" {
|
||||
function := tool.Get("function")
|
||||
anthropicTool := map[string]interface{}{
|
||||
"name": function.Get("name").String(),
|
||||
"description": function.Get("description").String(),
|
||||
}
|
||||
anthropicTool := `{"name":"","description":""}`
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
|
||||
anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
|
||||
|
||||
// Convert parameters schema for the tool
|
||||
if parameters := function.Get("parameters"); parameters.Exists() {
|
||||
anthropicTool["input_schema"] = parameters.Value()
|
||||
} else if parameters = function.Get("parametersJsonSchema"); parameters.Exists() {
|
||||
anthropicTool["input_schema"] = parameters.Value()
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
|
||||
} else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
|
||||
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
|
||||
}
|
||||
|
||||
anthropicTools = append(anthropicTools, anthropicTool)
|
||||
out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
|
||||
hasAnthropicTools = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(anthropicTools) > 0 {
|
||||
toolsJSON, _ := json.Marshal(anthropicTools)
|
||||
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
|
||||
if !hasAnthropicTools {
|
||||
out, _ = sjson.Delete(out, "tools")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,18 +278,17 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
case "none":
|
||||
// Don't set tool_choice, Claude Code will not use tools
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"})
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
case "required":
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"})
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
}
|
||||
case gjson.JSON:
|
||||
// Specific tool choice mapping
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
functionName := toolChoice.Get("function.name").String()
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{
|
||||
"type": "tool",
|
||||
"name": functionName,
|
||||
})
|
||||
toolChoiceJSON := `{"type":"tool","name":""}`
|
||||
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ package chat_completions
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -178,18 +178,11 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
if arguments == "" {
|
||||
arguments = "{}"
|
||||
}
|
||||
|
||||
toolCall := map[string]interface{}{
|
||||
"index": index,
|
||||
"id": accumulator.ID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": accumulator.Name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
}
|
||||
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls", []interface{}{toolCall})
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
|
||||
|
||||
// Clean up the accumulator for this index
|
||||
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
|
||||
@@ -210,12 +203,14 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
|
||||
// Handle usage information for token counts
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
usageObj := map[string]interface{}{
|
||||
"prompt_tokens": usage.Get("input_tokens").Int(),
|
||||
"completion_tokens": usage.Get("output_tokens").Int(),
|
||||
"total_tokens": usage.Get("input_tokens").Int() + usage.Get("output_tokens").Int(),
|
||||
}
|
||||
template, _ = sjson.Set(template, "usage", usageObj)
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
}
|
||||
return []string{template}
|
||||
|
||||
@@ -230,14 +225,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
case "error":
|
||||
// Error event - format and return error response
|
||||
if errorData := root.Get("error"); errorData.Exists() {
|
||||
errorResponse := map[string]interface{}{
|
||||
"error": map[string]interface{}{
|
||||
"message": errorData.Get("message").String(),
|
||||
"type": errorData.Get("type").String(),
|
||||
},
|
||||
}
|
||||
errorJSON, _ := json.Marshal(errorResponse)
|
||||
return []string{string(errorJSON)}
|
||||
errorJSON := `{"error":{"message":"","type":""}}`
|
||||
errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
|
||||
errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
|
||||
return []string{errorJSON}
|
||||
}
|
||||
return []string{}
|
||||
|
||||
@@ -293,15 +284,10 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
var messageID string
|
||||
var model string
|
||||
var createdAt int64
|
||||
var inputTokens, outputTokens int64
|
||||
var reasoningTokens int64
|
||||
var stopReason string
|
||||
var contentParts []string
|
||||
var reasoningParts []string
|
||||
// Use map to track tool calls by index for proper merging
|
||||
toolCallsMap := make(map[int]map[string]interface{})
|
||||
// Track tool call arguments accumulation
|
||||
toolCallArgsMap := make(map[int]strings.Builder)
|
||||
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
root := gjson.ParseBytes(chunk)
|
||||
@@ -314,9 +300,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
messageID = message.Get("id").String()
|
||||
model = message.Get("model").String()
|
||||
createdAt = time.Now().Unix()
|
||||
if usage := message.Get("usage"); usage.Exists() {
|
||||
inputTokens = usage.Get("input_tokens").Int()
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
@@ -327,18 +310,12 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
// Start of thinking/reasoning content - skip for now as it's handled in delta
|
||||
continue
|
||||
} else if blockType == "tool_use" {
|
||||
// Initialize tool call tracking for this index
|
||||
// Initialize tool call accumulator for this index
|
||||
index := int(root.Get("index").Int())
|
||||
toolCallsMap[index] = map[string]interface{}{
|
||||
"id": contentBlock.Get("id").String(),
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": contentBlock.Get("name").String(),
|
||||
"arguments": "",
|
||||
},
|
||||
toolCallsAccumulator[index] = &ToolCallAccumulator{
|
||||
ID: contentBlock.Get("id").String(),
|
||||
Name: contentBlock.Get("name").String(),
|
||||
}
|
||||
// Initialize arguments builder for this tool call
|
||||
toolCallArgsMap[index] = strings.Builder{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -361,9 +338,8 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
// Accumulate tool call arguments
|
||||
if partialJSON := delta.Get("partial_json"); partialJSON.Exists() {
|
||||
index := int(root.Get("index").Int())
|
||||
if builder, exists := toolCallArgsMap[index]; exists {
|
||||
builder.WriteString(partialJSON.String())
|
||||
toolCallArgsMap[index] = builder
|
||||
if accumulator, exists := toolCallsAccumulator[index]; exists {
|
||||
accumulator.Arguments.WriteString(partialJSON.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -372,14 +348,9 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
case "content_block_stop":
|
||||
// Finalize tool call arguments for this index when content block ends
|
||||
index := int(root.Get("index").Int())
|
||||
if toolCall, exists := toolCallsMap[index]; exists {
|
||||
if builder, argsExists := toolCallArgsMap[index]; argsExists {
|
||||
// Set the accumulated arguments for the tool call
|
||||
arguments := builder.String()
|
||||
if arguments == "" {
|
||||
arguments = "{}"
|
||||
}
|
||||
toolCall["function"].(map[string]interface{})["arguments"] = arguments
|
||||
if accumulator, exists := toolCallsAccumulator[index]; exists {
|
||||
if accumulator.Arguments.Len() == 0 {
|
||||
accumulator.Arguments.WriteString("{}")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,11 +362,14 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
}
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
outputTokens = usage.Get("output_tokens").Int()
|
||||
// Estimate reasoning tokens from accumulated thinking content
|
||||
if len(reasoningParts) > 0 {
|
||||
reasoningTokens = int64(len(strings.Join(reasoningParts, "")) / 4) // Rough estimation
|
||||
}
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
|
||||
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -417,24 +391,35 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
|
||||
// Set tool calls if any were accumulated during processing
|
||||
if len(toolCallsMap) > 0 {
|
||||
// Convert tool calls map to array, preserving order by index
|
||||
var toolCallsArray []interface{}
|
||||
// Find the maximum index to determine the range
|
||||
if len(toolCallsAccumulator) > 0 {
|
||||
toolCallsCount := 0
|
||||
maxIndex := -1
|
||||
for index := range toolCallsMap {
|
||||
for index := range toolCallsAccumulator {
|
||||
if index > maxIndex {
|
||||
maxIndex = index
|
||||
}
|
||||
}
|
||||
// Iterate through all possible indices up to maxIndex
|
||||
|
||||
for i := 0; i <= maxIndex; i++ {
|
||||
if toolCall, exists := toolCallsMap[i]; exists {
|
||||
toolCallsArray = append(toolCallsArray, toolCall)
|
||||
accumulator, exists := toolCallsAccumulator[i]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
arguments := accumulator.Arguments.String()
|
||||
|
||||
idPath := fmt.Sprintf("choices.0.message.tool_calls.%d.id", toolCallsCount)
|
||||
typePath := fmt.Sprintf("choices.0.message.tool_calls.%d.type", toolCallsCount)
|
||||
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
|
||||
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
|
||||
|
||||
out, _ = sjson.Set(out, idPath, accumulator.ID)
|
||||
out, _ = sjson.Set(out, typePath, "function")
|
||||
out, _ = sjson.Set(out, namePath, accumulator.Name)
|
||||
out, _ = sjson.Set(out, argumentsPath, arguments)
|
||||
toolCallsCount++
|
||||
}
|
||||
if len(toolCallsArray) > 0 {
|
||||
out, _ = sjson.Set(out, "choices.0.message.tool_calls", toolCallsArray)
|
||||
if toolCallsCount > 0 {
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
@@ -443,16 +428,5 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
|
||||
}
|
||||
|
||||
// Set usage information including prompt tokens, completion tokens, and total tokens
|
||||
totalTokens := inputTokens + outputTokens
|
||||
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
|
||||
out, _ = sjson.Set(out, "usage.total_tokens", totalTokens)
|
||||
|
||||
// Add reasoning tokens to usage details if any reasoning content was processed
|
||||
if reasoningTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.completion_tokens_details.reasoning_tokens", reasoningTokens)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -52,20 +53,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
if v := root.Get("reasoning.effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
switch v.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 1024)
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 4096)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 8192)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 24576)
|
||||
if v := root.Get("reasoning.effort"); v.Exists() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := util.ThinkingEffortToBudget(modelName, effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,13 +114,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
var builder strings.Builder
|
||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
text := part.Get("text").String()
|
||||
textResult := part.Get("text")
|
||||
text := textResult.String()
|
||||
if builder.Len() > 0 && text != "" {
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
builder.WriteString(text)
|
||||
return true
|
||||
})
|
||||
} else if parts.Type == gjson.String {
|
||||
builder.WriteString(parts.String())
|
||||
}
|
||||
instructionsText = builder.String()
|
||||
if instructionsText != "" {
|
||||
@@ -203,6 +210,8 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if parts.Type == gjson.String {
|
||||
textAggregate.WriteString(parts.String())
|
||||
}
|
||||
|
||||
// Fallback to given role if content types not decisive
|
||||
@@ -250,7 +259,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
toolUse, _ = sjson.Set(toolUse, "id", callID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", name)
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argsStr)
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
if argsJSON.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
asst := `{"role":"assistant","content":[]}`
|
||||
@@ -305,16 +317,18 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
case gjson.String:
|
||||
switch toolChoice.String() {
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"})
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
|
||||
case "none":
|
||||
// Leave unset; implies no tools
|
||||
case "required":
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"})
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
|
||||
}
|
||||
case gjson.JSON:
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
fn := toolChoice.Get("function.name").String()
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "tool", "name": fn})
|
||||
toolChoiceJSON := `{"name":"","type":"tool"}`
|
||||
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
|
||||
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
|
||||
}
|
||||
default:
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
}
|
||||
// response.created
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"instructions":""}}`
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
|
||||
@@ -197,11 +197,11 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
if st.ReasoningActive {
|
||||
if t := d.Get("thinking"); t.Exists() {
|
||||
st.ReasoningBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "text", t.String())
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
}
|
||||
@@ -344,31 +344,20 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
|
||||
// Build response.output from aggregated state
|
||||
var outputs []interface{}
|
||||
outputsWrapper := `{"arr":[]}`
|
||||
// reasoning item (if any)
|
||||
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
|
||||
r := map[string]interface{}{
|
||||
"id": st.ReasoningItemID,
|
||||
"type": "reasoning",
|
||||
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": st.ReasoningBuf.String()}},
|
||||
}
|
||||
outputs = append(outputs, r)
|
||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
// assistant message item (if any text)
|
||||
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
|
||||
m := map[string]interface{}{
|
||||
"id": st.CurrentMsgID,
|
||||
"type": "message",
|
||||
"status": "completed",
|
||||
"content": []interface{}{map[string]interface{}{
|
||||
"type": "output_text",
|
||||
"annotations": []interface{}{},
|
||||
"logprobs": []interface{}{},
|
||||
"text": st.TextBuf.String(),
|
||||
}},
|
||||
"role": "assistant",
|
||||
}
|
||||
outputs = append(outputs, m)
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
// function_call items (in ascending index order for determinism)
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
@@ -395,19 +384,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
if callID == "" && st.CurrentFCID != "" {
|
||||
callID = st.CurrentFCID
|
||||
}
|
||||
item := map[string]interface{}{
|
||||
"id": fmt.Sprintf("fc_%s", callID),
|
||||
"type": "function_call",
|
||||
"status": "completed",
|
||||
"arguments": args,
|
||||
"call_id": callID,
|
||||
"name": name,
|
||||
}
|
||||
outputs = append(outputs, item)
|
||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.Set(item, "arguments", args)
|
||||
item, _ = sjson.Set(item, "call_id", callID)
|
||||
item, _ = sjson.Set(item, "name", name)
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
}
|
||||
if len(outputs) > 0 {
|
||||
completed, _ = sjson.Set(completed, "response.output", outputs)
|
||||
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
|
||||
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
|
||||
}
|
||||
|
||||
reasoningTokens := int64(0)
|
||||
@@ -628,27 +614,18 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
}
|
||||
|
||||
// Build output array
|
||||
var outputs []interface{}
|
||||
outputsWrapper := `{"arr":[]}`
|
||||
if reasoningBuf.Len() > 0 {
|
||||
outputs = append(outputs, map[string]interface{}{
|
||||
"id": reasoningItemID,
|
||||
"type": "reasoning",
|
||||
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": reasoningBuf.String()}},
|
||||
})
|
||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", reasoningItemID)
|
||||
item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if currentMsgID != "" || textBuf.Len() > 0 {
|
||||
outputs = append(outputs, map[string]interface{}{
|
||||
"id": currentMsgID,
|
||||
"type": "message",
|
||||
"status": "completed",
|
||||
"content": []interface{}{map[string]interface{}{
|
||||
"type": "output_text",
|
||||
"annotations": []interface{}{},
|
||||
"logprobs": []interface{}{},
|
||||
"text": textBuf.String(),
|
||||
}},
|
||||
"role": "assistant",
|
||||
})
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", currentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", textBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
// Preserve index order
|
||||
@@ -669,18 +646,16 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
outputs = append(outputs, map[string]interface{}{
|
||||
"id": fmt.Sprintf("fc_%s", st.id),
|
||||
"type": "function_call",
|
||||
"status": "completed",
|
||||
"arguments": args,
|
||||
"call_id": st.id,
|
||||
"name": st.name,
|
||||
})
|
||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
|
||||
item, _ = sjson.Set(item, "arguments", args)
|
||||
item, _ = sjson.Set(item, "call_id", st.id)
|
||||
item, _ = sjson.Set(item, "name", st.name)
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
}
|
||||
if len(outputs) > 0 {
|
||||
out, _ = sjson.Set(out, "output", outputs)
|
||||
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
|
||||
out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
|
||||
}
|
||||
|
||||
// Usage
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -214,7 +215,27 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Add additional configuration parameters for the Codex API.
|
||||
template, _ = sjson.Set(template, "parallel_tool_calls", true)
|
||||
template, _ = sjson.Set(template, "reasoning.effort", "low")
|
||||
|
||||
// Convert thinking.budget_tokens to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if thinking := rootResult.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
switch thinking.Get("type").String() {
|
||||
case "enabled":
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
|
||||
template, _ = sjson.Set(template, "reasoning.summary", "auto")
|
||||
template, _ = sjson.Set(template, "stream", true)
|
||||
template, _ = sjson.Set(template, "store", false)
|
||||
|
||||
@@ -9,7 +9,6 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -191,21 +190,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
return ""
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": responseData.Get("id").String(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": responseData.Get("model").String(),
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": responseData.Get("usage.input_tokens").Int(),
|
||||
"output_tokens": responseData.Get("usage.output_tokens").Int(),
|
||||
},
|
||||
}
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
|
||||
|
||||
var contentBlocks []interface{}
|
||||
hasToolCall := false
|
||||
|
||||
if output := responseData.Get("output"); output.Exists() && output.IsArray() {
|
||||
@@ -244,10 +234,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
}
|
||||
}
|
||||
if thinkingBuilder.Len() > 0 {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkingBuilder.String(),
|
||||
})
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
}
|
||||
case "message":
|
||||
if content := item.Get("content"); content.Exists() {
|
||||
@@ -256,10 +245,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
if part.Get("type").String() == "output_text" {
|
||||
text := part.Get("text").String()
|
||||
if text != "" {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
})
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -267,10 +255,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
} else {
|
||||
text := content.String()
|
||||
if text != "" {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
})
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -281,54 +268,41 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
name = original
|
||||
}
|
||||
|
||||
toolBlock := map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": item.Get("call_id").String(),
|
||||
"name": name,
|
||||
"input": map[string]interface{}{},
|
||||
}
|
||||
|
||||
if argsStr := item.Get("arguments").String(); argsStr != "" {
|
||||
var args interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
|
||||
toolBlock["input"] = args
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
inputRaw := "{}"
|
||||
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
|
||||
argsJSON := gjson.Parse(argsStr)
|
||||
if argsJSON.IsObject() {
|
||||
inputRaw = argsJSON.Raw
|
||||
}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, toolBlock)
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if len(contentBlocks) > 0 {
|
||||
response["content"] = contentBlocks
|
||||
}
|
||||
|
||||
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
|
||||
response["stop_reason"] = stopReason.String()
|
||||
out, _ = sjson.Set(out, "stop_reason", stopReason.String())
|
||||
} else if hasToolCall {
|
||||
response["stop_reason"] = "tool_use"
|
||||
out, _ = sjson.Set(out, "stop_reason", "tool_use")
|
||||
} else {
|
||||
response["stop_reason"] = "end_turn"
|
||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||
}
|
||||
|
||||
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
|
||||
response["stop_sequence"] = stopSequence.Value()
|
||||
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
|
||||
}
|
||||
|
||||
if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() {
|
||||
response["usage"] = map[string]interface{}{
|
||||
"input_tokens": responseData.Get("usage.input_tokens").Int(),
|
||||
"output_tokens": responseData.Get("usage.output_tokens").Int(),
|
||||
}
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
|
||||
}
|
||||
|
||||
responseJSON, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(responseJSON)
|
||||
return out
|
||||
}
|
||||
|
||||
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.
|
||||
|
||||
@@ -245,7 +245,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Fixed flags aligning with Codex expectations
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "low")
|
||||
|
||||
// Convert thinkingBudget to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out, _ = sjson.Set(out, "reasoning.effort", reasoningEffort)
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.Set(out, "stream", true)
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
|
||||
@@ -7,7 +7,6 @@ package gemini
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -190,19 +189,19 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
}
|
||||
|
||||
// Process output content to build parts array
|
||||
var parts []interface{}
|
||||
hasToolCall := false
|
||||
var pendingFunctionCalls []interface{}
|
||||
var pendingFunctionCalls []string
|
||||
|
||||
flushPendingFunctionCalls := func() {
|
||||
if len(pendingFunctionCalls) > 0 {
|
||||
// Add all pending function calls as individual parts
|
||||
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
|
||||
for _, fc := range pendingFunctionCalls {
|
||||
parts = append(parts, fc)
|
||||
}
|
||||
pendingFunctionCalls = nil
|
||||
if len(pendingFunctionCalls) == 0 {
|
||||
return
|
||||
}
|
||||
// Add all pending function calls as individual parts
|
||||
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
|
||||
for _, fc := range pendingFunctionCalls {
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc)
|
||||
}
|
||||
pendingFunctionCalls = nil
|
||||
}
|
||||
|
||||
if output := responseData.Get("output"); output.Exists() && output.IsArray() {
|
||||
@@ -216,11 +215,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
|
||||
// Add thinking content
|
||||
if content := value.Get("content"); content.Exists() {
|
||||
part := map[string]interface{}{
|
||||
"thought": true,
|
||||
"text": content.String(),
|
||||
}
|
||||
parts = append(parts, part)
|
||||
part := `{"text":"","thought":true}`
|
||||
part, _ = sjson.Set(part, "text", content.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
}
|
||||
|
||||
case "message":
|
||||
@@ -232,10 +229,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
content.ForEach(func(_, contentItem gjson.Result) bool {
|
||||
if contentItem.Get("type").String() == "output_text" {
|
||||
if text := contentItem.Get("text"); text.Exists() {
|
||||
part := map[string]interface{}{
|
||||
"text": text.String(),
|
||||
}
|
||||
parts = append(parts, part)
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", text.String())
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -245,28 +241,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
case "function_call":
|
||||
// Collect function call for potential merging with consecutive ones
|
||||
hasToolCall = true
|
||||
functionCall := map[string]interface{}{
|
||||
"functionCall": map[string]interface{}{
|
||||
"name": func() string {
|
||||
n := value.Get("name").String()
|
||||
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
|
||||
if orig, ok := rev[n]; ok {
|
||||
return orig
|
||||
}
|
||||
return n
|
||||
}(),
|
||||
"args": map[string]interface{}{},
|
||||
},
|
||||
functionCall := `{"functionCall":{"args":{},"name":""}}`
|
||||
{
|
||||
n := value.Get("name").String()
|
||||
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
|
||||
if orig, ok := rev[n]; ok {
|
||||
n = orig
|
||||
}
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
|
||||
}
|
||||
|
||||
// Parse and set arguments
|
||||
if argsStr := value.Get("arguments").String(); argsStr != "" {
|
||||
argsResult := gjson.Parse(argsStr)
|
||||
if argsResult.IsObject() {
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
|
||||
functionCall["functionCall"].(map[string]interface{})["args"] = args
|
||||
}
|
||||
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,11 +268,6 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
|
||||
flushPendingFunctionCalls()
|
||||
}
|
||||
|
||||
// Set the parts array
|
||||
if len(parts) > 0 {
|
||||
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", mustMarshalJSON(parts))
|
||||
}
|
||||
|
||||
// Set finish reason based on whether there were tool calls
|
||||
if hasToolCall {
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||
@@ -323,15 +307,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
|
||||
return rev
|
||||
}
|
||||
|
||||
// mustMarshalJSON marshals a value to JSON, panicking on error.
|
||||
func mustMarshalJSON(v interface{}) string {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", v.Value())
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "low")
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "medium")
|
||||
}
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
|
||||
@@ -7,10 +7,8 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -41,92 +39,102 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
|
||||
// system instruction
|
||||
var systemInstruction *client.Content
|
||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
systemPart := client.Part{Text: systemPrompt}
|
||||
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
|
||||
if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
|
||||
systemInstruction := `{"role":"user","parts":[]}`
|
||||
hasSystemParts := false
|
||||
systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
|
||||
if systemPromptResult.Get("type").String() == "text" {
|
||||
textResult := systemPromptResult.Get("text")
|
||||
if textResult.Type == gjson.String {
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", textResult.String())
|
||||
systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
|
||||
hasSystemParts = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if hasSystemParts {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
|
||||
}
|
||||
if len(systemInstruction.Parts) == 0 {
|
||||
systemInstruction = nil
|
||||
}
|
||||
} else if systemResult.Type == gjson.String {
|
||||
out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String())
|
||||
}
|
||||
|
||||
// contents
|
||||
contents := make([]client.Content, 0)
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
for i := 0; i < len(messageResults); i++ {
|
||||
messageResult := messageResults[i]
|
||||
if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
|
||||
messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
|
||||
roleResult := messageResult.Get("role")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
return true
|
||||
}
|
||||
role := roleResult.String()
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContent := client.Content{Role: role, Parts: []client.Part{}}
|
||||
|
||||
contentJSON := `{"role":"","parts":[]}`
|
||||
contentJSON, _ = sjson.Set(contentJSON, "role", role)
|
||||
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
for j := 0; j < len(contentResults); j++ {
|
||||
contentResult := contentResults[j]
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
|
||||
switch contentResult.Get("type").String() {
|
||||
case "text":
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
|
||||
case "tool_use":
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{Name: functionName, Args: args},
|
||||
ThoughtSignature: geminiCLIClaudeThoughtSignature,
|
||||
})
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() && gjson.Valid(functionArgs) {
|
||||
part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
|
||||
part, _ = sjson.Set(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
|
||||
part, _ = sjson.Set(part, "functionCall.name", functionName)
|
||||
part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
|
||||
case "tool_result":
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
if toolCallID == "" {
|
||||
return true
|
||||
}
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
part, _ = sjson.Set(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
}
|
||||
}
|
||||
contents = append(contents, clientContent)
|
||||
return true
|
||||
})
|
||||
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentsResult.String())
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// tools
|
||||
var tools []client.ToolDeclaration
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
|
||||
hasTools := false
|
||||
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
|
||||
inputSchemaResult := toolResult.Get("input_schema")
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
inputSchema := inputSchemaResult.Raw
|
||||
@@ -134,30 +142,21 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
var toolDeclaration any
|
||||
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
tool, _ = sjson.Delete(tool, "type")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
|
||||
if !hasTools {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
|
||||
hasTools = true
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "request.tools.0.functionDeclarations.-1", tool)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !hasTools {
|
||||
out, _ = sjson.Delete(out, "request.tools")
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
if systemInstruction != nil {
|
||||
b, _ := json.Marshal(systemInstruction)
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
b, _ := json.Marshal(contents)
|
||||
out, _ = sjson.SetRaw(out, "request.contents", string(b))
|
||||
}
|
||||
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
|
||||
b, _ := json.Marshal(tools)
|
||||
out, _ = sjson.SetRaw(out, "request.tools", string(b))
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
|
||||
@@ -9,7 +9,6 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -276,22 +275,16 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": root.Get("response.responseId").String(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": root.Get("response.modelVersion").String(),
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": root.Get("response.usageMetadata.promptTokenCount").Int(),
|
||||
"output_tokens": root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int(),
|
||||
},
|
||||
}
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("response.responseId").String())
|
||||
out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String())
|
||||
|
||||
inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
|
||||
outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int()
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
|
||||
parts := root.Get("response.candidates.0.content.parts")
|
||||
var contentBlocks []interface{}
|
||||
textBuilder := strings.Builder{}
|
||||
thinkingBuilder := strings.Builder{}
|
||||
toolIDCounter := 0
|
||||
@@ -301,10 +294,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
if textBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": textBuilder.String(),
|
||||
})
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", textBuilder.String())
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
textBuilder.Reset()
|
||||
}
|
||||
|
||||
@@ -312,10 +304,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
if thinkingBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkingBuilder.String(),
|
||||
})
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
}
|
||||
|
||||
@@ -339,21 +330,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
toolIDCounter++
|
||||
toolBlock := map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("tool_%d", toolIDCounter),
|
||||
"name": name,
|
||||
"input": map[string]interface{}{},
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
inputRaw := "{}"
|
||||
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
|
||||
inputRaw = args.Raw
|
||||
}
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() {
|
||||
var parsed interface{}
|
||||
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
|
||||
toolBlock["input"] = parsed
|
||||
}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, toolBlock)
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -362,8 +347,6 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
flushThinking()
|
||||
flushText()
|
||||
|
||||
response["content"] = contentBlocks
|
||||
|
||||
stopReason := "end_turn"
|
||||
if hasToolCall {
|
||||
stopReason = "tool_use"
|
||||
@@ -379,19 +362,13 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
|
||||
}
|
||||
}
|
||||
}
|
||||
response["stop_reason"] = stopReason
|
||||
out, _ = sjson.Set(out, "stop_reason", stopReason)
|
||||
|
||||
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) {
|
||||
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
||||
delete(response, "usage")
|
||||
}
|
||||
if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() {
|
||||
out, _ = sjson.Delete(out, "usage")
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(encoded)
|
||||
return out
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
|
||||
@@ -7,7 +7,6 @@ package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -117,8 +116,6 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
|
||||
|
||||
// FunctionCallGroup represents a group of function calls and their responses
|
||||
type FunctionCallGroup struct {
|
||||
ModelContent map[string]interface{}
|
||||
FunctionCalls []gjson.Result
|
||||
ResponsesNeeded int
|
||||
}
|
||||
|
||||
@@ -146,7 +143,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
}
|
||||
|
||||
// Initialize data structures for processing and grouping
|
||||
var newContents []interface{} // Final processed contents array
|
||||
contentsWrapper := `{"contents":[]}`
|
||||
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||
|
||||
@@ -178,23 +175,17 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
var responseParts []interface{}
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
if !response.IsObject() {
|
||||
log.Warnf("failed to parse function response")
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
functionResponseContent := map[string]interface{}{
|
||||
"parts": responseParts,
|
||||
"role": "function",
|
||||
}
|
||||
newContents = append(newContents, functionResponseContent)
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
|
||||
// Remove this group as it's been satisfied
|
||||
@@ -208,50 +199,42 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
// If this is a model with function calls, create a new group
|
||||
if role == "model" {
|
||||
var functionCallsInThisModel []gjson.Result
|
||||
functionCallsCount := 0
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
functionCallsInThisModel = append(functionCallsInThisModel, part)
|
||||
functionCallsCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(functionCallsInThisModel) > 0 {
|
||||
if functionCallsCount > 0 {
|
||||
// Add the model content
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse model content")
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
ModelContent: contentMap,
|
||||
FunctionCalls: functionCallsInThisModel,
|
||||
ResponsesNeeded: len(functionCallsInThisModel),
|
||||
ResponsesNeeded: functionCallsCount,
|
||||
}
|
||||
pendingGroups = append(pendingGroups, group)
|
||||
} else {
|
||||
// Regular model content without function calls
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
}
|
||||
} else {
|
||||
// Non-model content (user, etc.)
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||
if !value.IsObject() {
|
||||
log.Warnf("failed to parse content")
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -263,31 +246,24 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
var responseParts []interface{}
|
||||
functionResponseContent := `{"parts":[],"role":"function"}`
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
if !response.IsObject() {
|
||||
log.Warnf("failed to parse function response")
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
functionResponseContent := map[string]interface{}{
|
||||
"parts": responseParts,
|
||||
"role": "function",
|
||||
}
|
||||
newContents = append(newContents, functionResponseContent)
|
||||
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
|
||||
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the original JSON with the new contents
|
||||
result := input
|
||||
newContentsJSON, _ := json.Marshal(newContents)
|
||||
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
|
||||
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -39,31 +39,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String())
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
@@ -178,6 +160,14 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
} else if content.IsObject() && content.Get("type").String() == "text" {
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
|
||||
} else if content.IsArray() {
|
||||
contents := content.Array()
|
||||
if len(contents) > 0 {
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
|
||||
for j := 0; j < len(contents); j++ {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if role == "user" || (role == "system" && len(arr) == 1) {
|
||||
// Build single user content node to avoid splitting into multiple contents
|
||||
@@ -223,54 +213,77 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
p := 0
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||
for _, item := range content.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
case "image_url":
|
||||
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
||||
imageURL := item.Get("image_url.url").String()
|
||||
if len(imageURL) > 5 { // expect data:...
|
||||
pieces := strings.SplitN(imageURL[5:], ";", 2)
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
p++
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -296,7 +309,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
@@ -311,7 +324,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
|
||||
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
@@ -352,18 +365,3 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user