mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Compare commits
332 Commits
v6.6.76
...
c82d8e250a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c82d8e250a | ||
|
|
73db4e64f6 | ||
|
|
69ca0a8fac | ||
|
|
3b04e11544 | ||
|
|
e0927afa40 | ||
|
|
f97d9f3e11 | ||
|
|
6d8609e457 | ||
|
|
d216adeffc | ||
|
|
bb09708c02 | ||
|
|
1150d972a1 | ||
|
|
13bb7cf704 | ||
|
|
8bce696a7c | ||
|
|
6db8d2a28e | ||
|
|
2854e04bbb | ||
|
|
f99cddf97f | ||
|
|
f887f9985d | ||
|
|
550da0cee8 | ||
|
|
7ff3936efe | ||
|
|
f36a5f5654 | ||
|
|
c1facdff67 | ||
|
|
4ee46bc9f2 | ||
|
|
c3e94a8277 | ||
|
|
6b6d030ed3 | ||
|
|
538039f583 | ||
|
|
ca796510e9 | ||
|
|
d0d66cdcb7 | ||
|
|
d7d54fa2cc | ||
|
|
31649325f0 | ||
|
|
3a43ecb19b | ||
|
|
a709e5a12d | ||
|
|
f0ac77197b | ||
|
|
da0bbf2a3f | ||
|
|
295f34d7f0 | ||
|
|
c41ce77eea | ||
|
|
4eb1e6093f | ||
|
|
189a066807 | ||
|
|
d0bada7a43 | ||
|
|
9dc0e6d08b | ||
|
|
8510fc313e | ||
|
|
2666708c30 | ||
|
|
9e5b1d24e8 | ||
|
|
a7dae6ad52 | ||
|
|
e93e05ae25 | ||
|
|
c8c27325dc | ||
|
|
c3b6f3918c | ||
|
|
bbb55a8ab4 | ||
|
|
04b2290927 | ||
|
|
53920b0399 | ||
|
|
7583193c2a | ||
|
|
7cc3bd4ba0 | ||
|
|
88a0f095e8 | ||
|
|
c65f64dce0 | ||
|
|
d18cd217e1 | ||
|
|
ba4a1ab433 | ||
|
|
decddb521e | ||
|
|
95096bc3fc | ||
|
|
70897247b2 | ||
|
|
9c341f5aa5 | ||
|
|
2af4a8dc12 | ||
|
|
0f53b952b2 | ||
|
|
f30ffd5f5e | ||
|
|
bc9a24d705 | ||
|
|
2c879f13ef | ||
|
|
07b4a08979 | ||
|
|
7f612bb069 | ||
|
|
5743b78694 | ||
|
|
2e6a2b655c | ||
|
|
cb47ac21bf | ||
|
|
a1394b4596 | ||
|
|
9e97948f03 | ||
|
|
f7bfa8a05c | ||
|
|
46c6fb1e7a | ||
|
|
9f9fec5d4c | ||
|
|
e95be10485 | ||
|
|
f3d58fa0ce | ||
|
|
8c0eaa1f71 | ||
|
|
405df58f72 | ||
|
|
e7f13aa008 | ||
|
|
7cb6a9b89a | ||
|
|
9aa5344c29 | ||
|
|
8ba0ebbd2a | ||
|
|
c65407ab9f | ||
|
|
9e59685212 | ||
|
|
4a4dfaa910 | ||
|
|
0d6ecb0191 | ||
|
|
f16461bfe7 | ||
|
|
c32e2a8196 | ||
|
|
873d41582f | ||
|
|
6fb7d85558 | ||
|
|
6da7ed53f2 | ||
|
|
d5e3e32d58 | ||
|
|
f353a54555 | ||
|
|
1d6e2e751d | ||
|
|
cc50b63422 | ||
|
|
15ae83a15b | ||
|
|
81b369aed9 | ||
|
|
c8620d1633 | ||
|
|
ecc850bfb7 | ||
|
|
19b4ef33e0 | ||
|
|
7ca045d8b9 | ||
|
|
abfca6aab2 | ||
|
|
3c71c075db | ||
|
|
9c2992bfb2 | ||
|
|
269a1c5452 | ||
|
|
22ce65ac72 | ||
|
|
a2f8f59192 | ||
|
|
8c7c446f33 | ||
|
|
30a59168d7 | ||
|
|
c8884f5e25 | ||
|
|
d9c6317c84 | ||
|
|
d29ec95526 | ||
|
|
ef4508dbc8 | ||
|
|
f775e46fe2 | ||
|
|
65ad5c0c9d | ||
|
|
88bf4e77ec | ||
|
|
a4f8015caa | ||
|
|
ffd129909e | ||
|
|
9332316383 | ||
|
|
6dcbbf64c3 | ||
|
|
2ce3553612 | ||
|
|
2e14f787d4 | ||
|
|
523b41ccd2 | ||
|
|
09970dc7af | ||
|
|
d81abd401c | ||
|
|
a6cba25bc1 | ||
|
|
c6fa1d0e67 | ||
|
|
ac56e1e88b | ||
|
|
9b72ea9efa | ||
|
|
9f364441e8 | ||
|
|
e49a1c07bf | ||
|
|
8d9f4edf9b | ||
|
|
020e61d0da | ||
|
|
6184c43319 | ||
|
|
2cbe4a790c | ||
|
|
68b3565d7b | ||
|
|
3f385a8572 | ||
|
|
9823dc35e1 | ||
|
|
059bfee91b | ||
|
|
7beaf0eaa2 | ||
|
|
1fef90ff58 | ||
|
|
8447fd27a0 | ||
|
|
7831cba9f6 | ||
|
|
e02b2d58d5 | ||
|
|
28726632a9 | ||
|
|
3b26129c82 | ||
|
|
d4bb4e6624 | ||
|
|
0766c49f93 | ||
|
|
a7ffc77e3d | ||
|
|
e641fde25c | ||
|
|
5717c7f2f4 | ||
|
|
8734d4cb90 | ||
|
|
2f6004d74a | ||
|
|
5baa753539 | ||
|
|
ead98e4bca | ||
|
|
a1634909e8 | ||
|
|
1d2fe55310 | ||
|
|
c175821cc4 | ||
|
|
239a28793c | ||
|
|
c421d653e7 | ||
|
|
2542c2920d | ||
|
|
52e46ced1b | ||
|
|
cf9daf470c | ||
|
|
140d6211cc | ||
|
|
60f9a1442c | ||
|
|
cb6caf3f87 | ||
|
|
99c7abbbf1 | ||
|
|
8f511ac33c | ||
|
|
1046152119 | ||
|
|
f88228f1c5 | ||
|
|
62e2b672d9 | ||
|
|
03005b5d29 | ||
|
|
c7e8830a56 | ||
|
|
d5ef4a6d15 | ||
|
|
97b67e0e49 | ||
|
|
dd6d78cb31 | ||
|
|
46433a25f8 | ||
|
|
c8843edb81 | ||
|
|
f89feb881c | ||
|
|
dbba71028e | ||
|
|
8549a92e9a | ||
|
|
109cffc010 | ||
|
|
f8f3ad84fc | ||
|
|
bc7167e9fe | ||
|
|
384578a88c | ||
|
|
65b4e1ec6c | ||
|
|
6600d58ba2 | ||
|
|
4dc7af5a5d | ||
|
|
902bea24b4 | ||
|
|
c3ef46f409 | ||
|
|
aa0b63e214 | ||
|
|
ea3d22831e | ||
|
|
3b4d6d359b | ||
|
|
48cba39a12 | ||
|
|
cec4e251bd | ||
|
|
526dd866ba | ||
|
|
b31ddc7bf1 | ||
|
|
22e1ad3d8a | ||
|
|
f571b1deb0 | ||
|
|
67f8732683 | ||
|
|
2b387e169b | ||
|
|
199cf480b0 | ||
|
|
4ad6189487 | ||
|
|
fe5b3c80cb | ||
|
|
e0ffec885c | ||
|
|
ff4ff6bc2f | ||
|
|
7248f65c36 | ||
|
|
5c40a2db21 | ||
|
|
086eb3df7a | ||
|
|
ee2976cca0 | ||
|
|
8bc6df329f | ||
|
|
bcd4d9595f | ||
|
|
5a77b7728e | ||
|
|
1fbbba6f59 | ||
|
|
847be0e99d | ||
|
|
f6a2d072e6 | ||
|
|
ed8b0f25ee | ||
|
|
6e4a602c60 | ||
|
|
2262479365 | ||
|
|
33d66959e9 | ||
|
|
7f1b2b3f6e | ||
|
|
40ee065eff | ||
|
|
a75fb6af90 | ||
|
|
72f2125668 | ||
|
|
e8f5888d8e | ||
|
|
0b06d637e7 | ||
|
|
5a7e5bd870 | ||
|
|
6f8a8f8136 | ||
|
|
5df195ea82 | ||
|
|
b163f8ed9e | ||
|
|
a1da6ff5ac | ||
|
|
5977af96a0 | ||
|
|
43652d044c | ||
|
|
b1b379ea18 | ||
|
|
21ac161b21 | ||
|
|
94e979865e | ||
|
|
6c324f2c8b | ||
|
|
543dfd67e0 | ||
|
|
28bd1323a2 | ||
|
|
220ca45f74 | ||
|
|
70a82d80ac | ||
|
|
ac626111ac | ||
|
|
5bb9c2a2bd | ||
|
|
0b5bbe9234 | ||
|
|
14c74e5e84 | ||
|
|
6448d0ee7c | ||
|
|
b0c17af2cf | ||
|
|
8cfe26f10c | ||
|
|
80db2dc254 | ||
|
|
e8e3bc8616 | ||
|
|
bc3195c8d8 | ||
|
|
6494330c6b | ||
|
|
4d7f389b69 | ||
|
|
95f87d5669 | ||
|
|
c83365a349 | ||
|
|
6b3604cf2b | ||
|
|
af6bdca14f | ||
|
|
1c773c428f | ||
|
|
e785bfcd12 | ||
|
|
47dacce6ea | ||
|
|
dcac3407ab | ||
|
|
7004295e1d | ||
|
|
ee62ef4745 | ||
|
|
ef6bafbf7e | ||
|
|
ed28b71e87 | ||
|
|
d47b7dc79a | ||
|
|
49b9709ce5 | ||
|
|
a2eba2cdf5 | ||
|
|
3d01b3cfe8 | ||
|
|
af2efa6f7e | ||
|
|
d73b61d367 | ||
|
|
59a448b645 | ||
|
|
4adb9eed77 | ||
|
|
b6a0f7a07f | ||
|
|
1b2f907671 | ||
|
|
bda04eed8a | ||
|
|
67985d8226 | ||
|
|
cbcb061812 | ||
|
|
9fc2e1b3c8 | ||
|
|
3b484aea9e | ||
|
|
963a0950fa | ||
|
|
f4ba1ab910 | ||
|
|
2662f91082 | ||
|
|
c1db2c7d7c | ||
|
|
5e5d8142f9 | ||
|
|
b01619b441 | ||
|
|
f861bd6a94 | ||
|
|
6dbfdd140d | ||
|
|
aa8526edc0 | ||
|
|
ac3ca0ad8e | ||
|
|
fe6043aec7 | ||
|
|
386ccffed4 | ||
|
|
08d21b76e2 | ||
|
|
ffddd1c90a | ||
|
|
33aa665555 | ||
|
|
00280b6fe8 | ||
|
|
8f8dfd081b | ||
|
|
9f1b445c7c | ||
|
|
ae933dfe14 | ||
|
|
e124db723b | ||
|
|
05444cf32d | ||
|
|
8edbda57cf | ||
|
|
52760a4eaa | ||
|
|
bc32096e9c | ||
|
|
821249a5ed | ||
|
|
ee33863b47 | ||
|
|
cd22c849e2 | ||
|
|
f0e73efda2 | ||
|
|
3156109c71 | ||
|
|
6762e081f3 | ||
|
|
7815ee338d | ||
|
|
44b6c872e2 | ||
|
|
7a77b23f2d | ||
|
|
672e8549c0 | ||
|
|
66f5269a23 | ||
|
|
ebec293497 | ||
|
|
e02ceecd35 | ||
|
|
c8b33a8cc3 | ||
|
|
dca8d5ded8 | ||
|
|
2a7fd1e897 | ||
|
|
b9d1e70ac2 | ||
|
|
fdf5720217 | ||
|
|
f40bd0cd51 | ||
|
|
e33676bb87 | ||
|
|
2a663d5cba | ||
|
|
414db44c00 | ||
|
|
cb3bdffb43 | ||
|
|
48f19aab51 | ||
|
|
48f6d7abdf | ||
|
|
79fbcb3ec4 | ||
|
|
0e4148b229 | ||
|
|
31bd90c748 | ||
|
|
0b834fcb54 |
111
.github/workflows/docker-image.yml
vendored
111
.github/workflows/docker-image.yml
vendored
@@ -10,13 +10,11 @@ env:
|
|||||||
DOCKERHUB_REPO: eceasy/cli-proxy-api
|
DOCKERHUB_REPO: eceasy/cli-proxy-api
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
docker:
|
docker_amd64:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
- name: Set up QEMU
|
|
||||||
uses: docker/setup-qemu-action@v3
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -29,18 +27,113 @@ jobs:
|
|||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- name: Build and push
|
- name: Build and push (amd64)
|
||||||
uses: docker/build-push-action@v6
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
platforms: |
|
platforms: linux/amd64
|
||||||
linux/amd64
|
|
||||||
linux/arm64
|
|
||||||
push: true
|
push: true
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ env.VERSION }}
|
VERSION=${{ env.VERSION }}
|
||||||
COMMIT=${{ env.COMMIT }}
|
COMMIT=${{ env.COMMIT }}
|
||||||
BUILD_DATE=${{ env.BUILD_DATE }}
|
BUILD_DATE=${{ env.BUILD_DATE }}
|
||||||
tags: |
|
tags: |
|
||||||
${{ env.DOCKERHUB_REPO }}:latest
|
${{ env.DOCKERHUB_REPO }}:latest-amd64
|
||||||
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}
|
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-amd64
|
||||||
|
|
||||||
|
docker_arm64:
|
||||||
|
runs-on: ubuntu-24.04-arm
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
- name: Login to DockerHub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
- name: Generate Build Metadata
|
||||||
|
run: |
|
||||||
|
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||||
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
|
- name: Build and push (arm64)
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: linux/arm64
|
||||||
|
push: true
|
||||||
|
build-args: |
|
||||||
|
VERSION=${{ env.VERSION }}
|
||||||
|
COMMIT=${{ env.COMMIT }}
|
||||||
|
BUILD_DATE=${{ env.BUILD_DATE }}
|
||||||
|
tags: |
|
||||||
|
${{ env.DOCKERHUB_REPO }}:latest-arm64
|
||||||
|
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-arm64
|
||||||
|
|
||||||
|
docker_manifest:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs:
|
||||||
|
- docker_amd64
|
||||||
|
- docker_arm64
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
- name: Login to DockerHub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
- name: Generate Build Metadata
|
||||||
|
run: |
|
||||||
|
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||||
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
|
- name: Create and push multi-arch manifests
|
||||||
|
run: |
|
||||||
|
docker buildx imagetools create \
|
||||||
|
--tag "${DOCKERHUB_REPO}:latest" \
|
||||||
|
"${DOCKERHUB_REPO}:latest-amd64" \
|
||||||
|
"${DOCKERHUB_REPO}:latest-arm64"
|
||||||
|
docker buildx imagetools create \
|
||||||
|
--tag "${DOCKERHUB_REPO}:${VERSION}" \
|
||||||
|
"${DOCKERHUB_REPO}:${VERSION}-amd64" \
|
||||||
|
"${DOCKERHUB_REPO}:${VERSION}-arm64"
|
||||||
|
- name: Cleanup temporary tags
|
||||||
|
continue-on-error: true
|
||||||
|
env:
|
||||||
|
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
namespace="${DOCKERHUB_REPO%%/*}"
|
||||||
|
repo_name="${DOCKERHUB_REPO#*/}"
|
||||||
|
|
||||||
|
token="$(
|
||||||
|
curl -fsSL \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d "{\"username\":\"${DOCKERHUB_USERNAME}\",\"password\":\"${DOCKERHUB_TOKEN}\"}" \
|
||||||
|
'https://hub.docker.com/v2/users/login/' \
|
||||||
|
| python3 -c 'import json,sys; print(json.load(sys.stdin)["token"])'
|
||||||
|
)"
|
||||||
|
|
||||||
|
delete_tag() {
|
||||||
|
local tag="$1"
|
||||||
|
local url="https://hub.docker.com/v2/repositories/${namespace}/${repo_name}/tags/${tag}/"
|
||||||
|
local http_code
|
||||||
|
http_code="$(curl -sS -o /dev/null -w "%{http_code}" -X DELETE -H "Authorization: JWT ${token}" "${url}" || true)"
|
||||||
|
if [ "${http_code}" = "204" ] || [ "${http_code}" = "404" ]; then
|
||||||
|
echo "Docker Hub tag removed (or missing): ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})"
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
echo "Docker Hub tag delete failed: ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
delete_tag "latest-amd64"
|
||||||
|
delete_tag "latest-arm64"
|
||||||
|
delete_tag "${VERSION}-amd64"
|
||||||
|
delete_tag "${VERSION}-arm64"
|
||||||
|
|||||||
35
README.md
35
README.md
@@ -118,9 +118,44 @@ Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings,
|
|||||||
|
|
||||||
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
|
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
|
||||||
|
|
||||||
|
### [CodMate](https://github.com/loocor/CodMate)
|
||||||
|
|
||||||
|
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
|
||||||
|
|
||||||
|
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
||||||
|
|
||||||
|
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
|
||||||
|
|
||||||
|
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
||||||
|
|
||||||
|
VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
|
||||||
|
|
||||||
|
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
|
||||||
|
|
||||||
|
Windows desktop app built with Tauri + React for monitoring AI coding assistant quotas via CLIProxyAPI. Track usage across Gemini, Claude, OpenAI Codex, and Antigravity accounts with real-time dashboard, system tray integration, and one-click proxy control - no API keys needed.
|
||||||
|
|
||||||
|
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
||||||
|
|
||||||
|
A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
|
||||||
|
|
||||||
|
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||||
|
|
||||||
|
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||||
|
|
||||||
|
## More choices
|
||||||
|
|
||||||
|
Those projects are ports of CLIProxyAPI or inspired by it:
|
||||||
|
|
||||||
|
### [9Router](https://github.com/decolua/9router)
|
||||||
|
|
||||||
|
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||||
|
|||||||
35
README_CN.md
35
README_CN.md
@@ -117,9 +117,44 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
|
|||||||
|
|
||||||
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
|
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
|
||||||
|
|
||||||
|
### [CodMate](https://github.com/loocor/CodMate)
|
||||||
|
|
||||||
|
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话(Claude Code、Codex、Gemini CLI),提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
|
||||||
|
|
||||||
|
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
|
||||||
|
|
||||||
|
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
|
||||||
|
|
||||||
|
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
|
||||||
|
|
||||||
|
一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
|
||||||
|
|
||||||
|
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
|
||||||
|
|
||||||
|
Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI 监控 AI 编程助手配额。支持跨 Gemini、Claude、OpenAI Codex 和 Antigravity 账户的使用量追踪,提供实时仪表盘、系统托盘集成和一键代理控制,无需 API 密钥。
|
||||||
|
|
||||||
|
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
||||||
|
|
||||||
|
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||||
|
|
||||||
|
## 更多选择
|
||||||
|
|
||||||
|
以下项目是 CLIProxyAPI 的移植版或受其启发:
|
||||||
|
|
||||||
|
### [9Router](https://github.com/decolua/9router)
|
||||||
|
|
||||||
|
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
|
||||||
|
|
||||||
|
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||||
|
|
||||||
|
Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus)以及自动下载与更新。
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
|
||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ func main() {
|
|||||||
var iflowLogin bool
|
var iflowLogin bool
|
||||||
var iflowCookie bool
|
var iflowCookie bool
|
||||||
var noBrowser bool
|
var noBrowser bool
|
||||||
|
var oauthCallbackPort int
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
@@ -75,6 +76,7 @@ func main() {
|
|||||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||||
|
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
@@ -425,7 +427,8 @@ func main() {
|
|||||||
|
|
||||||
// Create login options to be used in authentication flows.
|
// Create login options to be used in authentication flows.
|
||||||
options := &cmd.LoginOptions{
|
options := &cmd.LoginOptions{
|
||||||
NoBrowser: noBrowser,
|
NoBrowser: noBrowser,
|
||||||
|
CallbackPort: oauthCallbackPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the shared token store once so all components use the same persistence backend.
|
// Register the shared token store once so all components use the same persistence backend.
|
||||||
|
|||||||
@@ -50,6 +50,10 @@ logging-to-file: false
|
|||||||
# files are deleted until within the limit. Set to 0 to disable.
|
# files are deleted until within the limit. Set to 0 to disable.
|
||||||
logs-max-total-size-mb: 0
|
logs-max-total-size-mb: 0
|
||||||
|
|
||||||
|
# Maximum number of error log files retained when request logging is disabled.
|
||||||
|
# When exceeded, the oldest error log files are deleted. Default is 10. Set to 0 to disable cleanup.
|
||||||
|
error-logs-max-files: 10
|
||||||
|
|
||||||
# When false, disable in-memory usage statistics aggregation
|
# When false, disable in-memory usage statistics aggregation
|
||||||
usage-statistics-enabled: false
|
usage-statistics-enabled: false
|
||||||
|
|
||||||
@@ -77,11 +81,18 @@ routing:
|
|||||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
ws-auth: false
|
ws-auth: false
|
||||||
|
|
||||||
|
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
||||||
|
nonstream-keepalive-interval: 0
|
||||||
|
|
||||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||||
# streaming:
|
# streaming:
|
||||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||||
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
||||||
|
|
||||||
|
# When true, enable official Codex instructions injection for Codex API requests.
|
||||||
|
# When false (default), CodexInstructionsForModel returns immediately without modification.
|
||||||
|
codex-instructions-enabled: false
|
||||||
|
|
||||||
# Gemini API keys
|
# Gemini API keys
|
||||||
# gemini-api-key:
|
# gemini-api-key:
|
||||||
# - api-key: "AIzaSy...01"
|
# - api-key: "AIzaSy...01"
|
||||||
@@ -134,6 +145,15 @@ ws-auth: false
|
|||||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||||
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||||
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
||||||
|
# cloak: # optional: request cloaking for non-Claude-Code clients
|
||||||
|
# mode: "auto" # "auto" (default): cloak only when client is not Claude Code
|
||||||
|
# # "always": always apply cloaking
|
||||||
|
# # "never": never apply cloaking
|
||||||
|
# strict-mode: false # false (default): prepend Claude Code prompt to user system messages
|
||||||
|
# # true: strip all user system messages, keep only Claude Code prompt
|
||||||
|
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||||
|
# - "API"
|
||||||
|
# - "proxy"
|
||||||
|
|
||||||
# OpenAI compatibility providers
|
# OpenAI compatibility providers
|
||||||
# openai-compatibility:
|
# openai-compatibility:
|
||||||
@@ -198,23 +218,37 @@ ws-auth: false
|
|||||||
# - from: "claude-haiku-4-5-20251001"
|
# - from: "claude-haiku-4-5-20251001"
|
||||||
# to: "gemini-2.5-flash"
|
# to: "gemini-2.5-flash"
|
||||||
|
|
||||||
# Global OAuth model name mappings (per channel)
|
# Global OAuth model name aliases (per channel)
|
||||||
# These mappings rename model IDs for both model listing and request routing.
|
# These aliases rename model IDs for both model listing and request routing.
|
||||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||||
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||||
# oauth-model-mappings:
|
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||||
|
oauth-model-alias:
|
||||||
|
antigravity:
|
||||||
|
- name: "rev19-uic3-1p"
|
||||||
|
alias: "gemini-2.5-computer-use-preview-10-2025"
|
||||||
|
- name: "gemini-3-pro-image"
|
||||||
|
alias: "gemini-3-pro-image-preview"
|
||||||
|
- name: "gemini-3-pro-high"
|
||||||
|
alias: "gemini-3-pro-preview"
|
||||||
|
- name: "gemini-3-flash"
|
||||||
|
alias: "gemini-3-flash-preview"
|
||||||
|
- name: "claude-sonnet-4-5"
|
||||||
|
alias: "gemini-claude-sonnet-4-5"
|
||||||
|
- name: "claude-sonnet-4-5-thinking"
|
||||||
|
alias: "gemini-claude-sonnet-4-5-thinking"
|
||||||
|
- name: "claude-opus-4-5-thinking"
|
||||||
|
alias: "gemini-claude-opus-4-5-thinking"
|
||||||
# gemini-cli:
|
# gemini-cli:
|
||||||
# - name: "gemini-2.5-pro" # original model name under this channel
|
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||||
# alias: "g2.5p" # client-visible alias
|
# alias: "g2.5p" # client-visible alias
|
||||||
|
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
|
||||||
# vertex:
|
# vertex:
|
||||||
# - name: "gemini-2.5-pro"
|
# - name: "gemini-2.5-pro"
|
||||||
# alias: "g2.5p"
|
# alias: "g2.5p"
|
||||||
# aistudio:
|
# aistudio:
|
||||||
# - name: "gemini-2.5-pro"
|
# - name: "gemini-2.5-pro"
|
||||||
# alias: "g2.5p"
|
# alias: "g2.5p"
|
||||||
# antigravity:
|
|
||||||
# - name: "gemini-3-pro-preview"
|
|
||||||
# alias: "g3p"
|
|
||||||
# claude:
|
# claude:
|
||||||
# - name: "claude-sonnet-4-5-20250929"
|
# - name: "claude-sonnet-4-5-20250929"
|
||||||
# alias: "cs4.5"
|
# alias: "cs4.5"
|
||||||
@@ -255,12 +289,31 @@ ws-auth: false
|
|||||||
# default: # Default rules only set parameters when they are missing in the payload.
|
# default: # Default rules only set parameters when they are missing in the payload.
|
||||||
# - models:
|
# - models:
|
||||||
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
# params: # JSON path (gjson/sjson syntax) -> value
|
# params: # JSON path (gjson/sjson syntax) -> value
|
||||||
# "generationConfig.thinkingConfig.thinkingBudget": 32768
|
# "generationConfig.thinkingConfig.thinkingBudget": 32768
|
||||||
|
# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON).
|
||||||
|
# - models:
|
||||||
|
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||||
|
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
|
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
||||||
|
# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}"
|
||||||
# override: # Override rules always set parameters, overwriting any existing values.
|
# override: # Override rules always set parameters, overwriting any existing values.
|
||||||
# - models:
|
# - models:
|
||||||
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
# params: # JSON path (gjson/sjson syntax) -> value
|
# params: # JSON path (gjson/sjson syntax) -> value
|
||||||
# "reasoning.effort": "high"
|
# "reasoning.effort": "high"
|
||||||
|
# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON).
|
||||||
|
# - models:
|
||||||
|
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||||
|
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
|
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
||||||
|
# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}"
|
||||||
|
# filter: # Filter rules remove specified parameters from the payload.
|
||||||
|
# - models:
|
||||||
|
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||||
|
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
|
# params: # JSON paths (gjson/sjson syntax) to remove from the payload
|
||||||
|
# - "generationConfig.thinkingConfig.thinkingBudget"
|
||||||
|
# - "generationConfig.responseJsonSchema"
|
||||||
|
|||||||
124
docker-build.sh
124
docker-build.sh
@@ -5,9 +5,115 @@
|
|||||||
# This script automates the process of building and running the Docker container
|
# This script automates the process of building and running the Docker container
|
||||||
# with version information dynamically injected at build time.
|
# with version information dynamically injected at build time.
|
||||||
|
|
||||||
# Exit immediately if a command exits with a non-zero status.
|
# Hidden feature: Preserve usage statistics across rebuilds
|
||||||
|
# Usage: ./docker-build.sh --with-usage
|
||||||
|
# First run prompts for management API key, saved to temp/stats/.api_secret
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
|
STATS_DIR="temp/stats"
|
||||||
|
STATS_FILE="${STATS_DIR}/.usage_backup.json"
|
||||||
|
SECRET_FILE="${STATS_DIR}/.api_secret"
|
||||||
|
WITH_USAGE=false
|
||||||
|
|
||||||
|
get_port() {
|
||||||
|
if [[ -f "config.yaml" ]]; then
|
||||||
|
grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
|
||||||
|
else
|
||||||
|
echo "8317"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
export_stats_api_secret() {
|
||||||
|
if [[ -f "${SECRET_FILE}" ]]; then
|
||||||
|
API_SECRET=$(cat "${SECRET_FILE}")
|
||||||
|
else
|
||||||
|
if [[ ! -d "${STATS_DIR}" ]]; then
|
||||||
|
mkdir -p "${STATS_DIR}"
|
||||||
|
fi
|
||||||
|
echo "First time using --with-usage. Management API key required."
|
||||||
|
read -r -p "Enter management key: " -s API_SECRET
|
||||||
|
echo
|
||||||
|
echo "${API_SECRET}" > "${SECRET_FILE}"
|
||||||
|
chmod 600 "${SECRET_FILE}"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
check_container_running() {
|
||||||
|
local port
|
||||||
|
port=$(get_port)
|
||||||
|
|
||||||
|
if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
|
||||||
|
echo "Error: cli-proxy-api service is not responding at localhost:${port}"
|
||||||
|
echo "Please start the container first or use without --with-usage flag."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
export_stats() {
|
||||||
|
local port
|
||||||
|
port=$(get_port)
|
||||||
|
|
||||||
|
if [[ ! -d "${STATS_DIR}" ]]; then
|
||||||
|
mkdir -p "${STATS_DIR}"
|
||||||
|
fi
|
||||||
|
check_container_running
|
||||||
|
echo "Exporting usage statistics..."
|
||||||
|
EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
|
||||||
|
"http://localhost:${port}/v0/management/usage/export")
|
||||||
|
HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
|
||||||
|
RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
|
||||||
|
|
||||||
|
if [[ "${HTTP_CODE}" != "200" ]]; then
|
||||||
|
echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "${RESPONSE_BODY}" > "${STATS_FILE}"
|
||||||
|
echo "Statistics exported to ${STATS_FILE}"
|
||||||
|
}
|
||||||
|
|
||||||
|
import_stats() {
|
||||||
|
local port
|
||||||
|
port=$(get_port)
|
||||||
|
|
||||||
|
echo "Importing usage statistics..."
|
||||||
|
IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
|
||||||
|
-H "X-Management-Key: ${API_SECRET}" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d @"${STATS_FILE}" \
|
||||||
|
"http://localhost:${port}/v0/management/usage/import")
|
||||||
|
IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
|
||||||
|
IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
|
||||||
|
|
||||||
|
if [[ "${IMPORT_CODE}" == "200" ]]; then
|
||||||
|
echo "Statistics imported successfully"
|
||||||
|
else
|
||||||
|
echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
rm -f "${STATS_FILE}"
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_for_service() {
|
||||||
|
local port
|
||||||
|
port=$(get_port)
|
||||||
|
|
||||||
|
echo "Waiting for service to be ready..."
|
||||||
|
for i in {1..30}; do
|
||||||
|
if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
sleep 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if [[ "${1:-}" == "--with-usage" ]]; then
|
||||||
|
WITH_USAGE=true
|
||||||
|
export_stats_api_secret
|
||||||
|
fi
|
||||||
|
|
||||||
# --- Step 1: Choose Environment ---
|
# --- Step 1: Choose Environment ---
|
||||||
echo "Please select an option:"
|
echo "Please select an option:"
|
||||||
echo "1) Run using Pre-built Image (Recommended)"
|
echo "1) Run using Pre-built Image (Recommended)"
|
||||||
@@ -18,7 +124,14 @@ read -r -p "Enter choice [1-2]: " choice
|
|||||||
case "$choice" in
|
case "$choice" in
|
||||||
1)
|
1)
|
||||||
echo "--- Running with Pre-built Image ---"
|
echo "--- Running with Pre-built Image ---"
|
||||||
|
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||||
|
export_stats
|
||||||
|
fi
|
||||||
docker compose up -d --remove-orphans --no-build
|
docker compose up -d --remove-orphans --no-build
|
||||||
|
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||||
|
wait_for_service
|
||||||
|
import_stats
|
||||||
|
fi
|
||||||
echo "Services are starting from remote image."
|
echo "Services are starting from remote image."
|
||||||
echo "Run 'docker compose logs -f' to see the logs."
|
echo "Run 'docker compose logs -f' to see the logs."
|
||||||
;;
|
;;
|
||||||
@@ -45,9 +158,18 @@ case "$choice" in
|
|||||||
--build-arg COMMIT="${COMMIT}" \
|
--build-arg COMMIT="${COMMIT}" \
|
||||||
--build-arg BUILD_DATE="${BUILD_DATE}"
|
--build-arg BUILD_DATE="${BUILD_DATE}"
|
||||||
|
|
||||||
|
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||||
|
export_stats
|
||||||
|
fi
|
||||||
|
|
||||||
echo "Starting the services..."
|
echo "Starting the services..."
|
||||||
docker compose up -d --remove-orphans --pull never
|
docker compose up -d --remove-orphans --pull never
|
||||||
|
|
||||||
|
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||||
|
wait_for_service
|
||||||
|
import_stats
|
||||||
|
fi
|
||||||
|
|
||||||
echo "Build complete. Services are starting."
|
echo "Build complete. Services are starting."
|
||||||
echo "Run 'docker compose logs -f' to see the logs."
|
echo "Run 'docker compose logs -f' to see the logs."
|
||||||
;;
|
;;
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ services:
|
|||||||
- "51121:51121"
|
- "51121:51121"
|
||||||
- "11451:11451"
|
- "11451:11451"
|
||||||
volumes:
|
volumes:
|
||||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
- ${CLI_PROXY_CONFIG_PATH:-./config.yaml}:/CLIProxyAPI/config.yaml
|
||||||
- ./auths:/root/.cli-proxy-api
|
- ${CLI_PROXY_AUTH_PATH:-./auths}:/root/.cli-proxy-api
|
||||||
- ./logs:/CLIProxyAPI/logs
|
- ${CLI_PROXY_LOG_PATH:-./logs}:/CLIProxyAPI/logs
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -122,7 +123,9 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
|
|||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
// Inject credentials via PrepareRequest hook.
|
// Inject credentials via PrepareRequest hook.
|
||||||
_ = (MyExecutor{}).PrepareRequest(httpReq, a)
|
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
|
||||||
|
return clipexec.Response{}, errPrep
|
||||||
|
}
|
||||||
|
|
||||||
resp, errDo := client.Do(httpReq)
|
resp, errDo := client.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
@@ -130,13 +133,28 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
// Best-effort close; log if needed in real projects.
|
fmt.Fprintf(os.Stderr, "close response body error: %v\n", errClose)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
return clipexec.Response{Payload: body}, nil
|
return clipexec.Response{Payload: body}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (MyExecutor) HttpRequest(ctx context.Context, a *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("myprov executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
|
||||||
|
return nil, errPrep
|
||||||
|
}
|
||||||
|
client := buildHTTPClient(a)
|
||||||
|
return client.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||||
}
|
}
|
||||||
@@ -187,7 +205,7 @@ func main() {
|
|||||||
// Optional: add a simple middleware + custom request logger
|
// Optional: add a simple middleware + custom request logger
|
||||||
api.WithMiddleware(func(c *gin.Context) { c.Header("X-Example", "custom-provider"); c.Next() }),
|
api.WithMiddleware(func(c *gin.Context) { c.Header("X-Example", "custom-provider"); c.Next() }),
|
||||||
api.WithRequestLoggerFactory(func(cfg *config.Config, cfgPath string) logging.RequestLogger {
|
api.WithRequestLoggerFactory(func(cfg *config.Config, cfgPath string) logging.RequestLogger {
|
||||||
return logging.NewFileRequestLogger(true, "logs", filepath.Dir(cfgPath))
|
return logging.NewFileRequestLoggerWithOptions(true, "logs", filepath.Dir(cfgPath), cfg.ErrorLogsMaxFiles)
|
||||||
}),
|
}),
|
||||||
).
|
).
|
||||||
WithHooks(hooks).
|
WithHooks(hooks).
|
||||||
@@ -199,8 +217,8 @@ func main() {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
if errRun := svc.Run(ctx); errRun != nil && !errors.Is(errRun, context.Canceled) {
|
||||||
panic(err)
|
panic(errRun)
|
||||||
}
|
}
|
||||||
_ = os.Stderr // keep os import used (demo only)
|
_ = os.Stderr // keep os import used (demo only)
|
||||||
_ = time.Second
|
_ = time.Second
|
||||||
|
|||||||
140
examples/http-request/main.go
Normal file
140
examples/http-request/main.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
// Package main demonstrates how to use coreauth.Manager.HttpRequest/NewHttpRequest
|
||||||
|
// to execute arbitrary HTTP requests with provider credentials injected.
|
||||||
|
//
|
||||||
|
// This example registers a minimal custom executor that injects an Authorization
|
||||||
|
// header from auth.Attributes["api_key"], then performs two requests against
|
||||||
|
// httpbin.org to show the injected headers.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const providerKey = "echo"
|
||||||
|
|
||||||
|
// EchoExecutor is a minimal provider implementation for demonstration purposes.
|
||||||
|
type EchoExecutor struct{}
|
||||||
|
|
||||||
|
func (EchoExecutor) Identifier() string { return providerKey }
|
||||||
|
|
||||||
|
func (EchoExecutor) PrepareRequest(req *http.Request, auth *coreauth.Auth) error {
|
||||||
|
if req == nil || auth == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
if apiKey := strings.TrimSpace(auth.Attributes["api_key"]); apiKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (EchoExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("echo executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if errPrep := (EchoExecutor{}).PrepareRequest(httpReq, auth); errPrep != nil {
|
||||||
|
return nil, errPrep
|
||||||
|
}
|
||||||
|
return http.DefaultClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||||
|
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||||
|
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (EchoExecutor) Refresh(context.Context, *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return nil, errors.New("echo executor: Refresh not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (EchoExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||||
|
return clipexec.Response{}, errors.New("echo executor: CountTokens not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
core := coreauth.NewManager(nil, nil, nil)
|
||||||
|
core.RegisterExecutor(EchoExecutor{})
|
||||||
|
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "demo-echo",
|
||||||
|
Provider: providerKey,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "demo-api-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 1: Build a prepared request and execute it using your own http.Client.
|
||||||
|
reqPrepared, errReqPrepared := core.NewHttpRequest(
|
||||||
|
ctx,
|
||||||
|
auth,
|
||||||
|
http.MethodGet,
|
||||||
|
"https://httpbin.org/anything",
|
||||||
|
nil,
|
||||||
|
http.Header{"X-Example": []string{"prepared"}},
|
||||||
|
)
|
||||||
|
if errReqPrepared != nil {
|
||||||
|
panic(errReqPrepared)
|
||||||
|
}
|
||||||
|
respPrepared, errDoPrepared := http.DefaultClient.Do(reqPrepared)
|
||||||
|
if errDoPrepared != nil {
|
||||||
|
panic(errDoPrepared)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := respPrepared.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
bodyPrepared, errReadPrepared := io.ReadAll(respPrepared.Body)
|
||||||
|
if errReadPrepared != nil {
|
||||||
|
panic(errReadPrepared)
|
||||||
|
}
|
||||||
|
fmt.Printf("Prepared request status: %d\n%s\n\n", respPrepared.StatusCode, bodyPrepared)
|
||||||
|
|
||||||
|
// Example 2: Execute a raw request via core.HttpRequest (auto inject + do).
|
||||||
|
rawBody := []byte(`{"hello":"world"}`)
|
||||||
|
rawReq, errRawReq := http.NewRequestWithContext(ctx, http.MethodPost, "https://httpbin.org/anything", bytes.NewReader(rawBody))
|
||||||
|
if errRawReq != nil {
|
||||||
|
panic(errRawReq)
|
||||||
|
}
|
||||||
|
rawReq.Header.Set("Content-Type", "application/json")
|
||||||
|
rawReq.Header.Set("X-Example", "executed")
|
||||||
|
|
||||||
|
respExec, errDoExec := core.HttpRequest(ctx, auth, rawReq)
|
||||||
|
if errDoExec != nil {
|
||||||
|
panic(errDoExec)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := respExec.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
bodyExec, errReadExec := io.ReadAll(respExec.Body)
|
||||||
|
if errReadExec != nil {
|
||||||
|
panic(errReadExec)
|
||||||
|
}
|
||||||
|
fmt.Printf("Manager HttpRequest status: %d\n%s\n", respExec.StatusCode, bodyExec)
|
||||||
|
}
|
||||||
1
go.mod
1
go.mod
@@ -13,6 +13,7 @@ require (
|
|||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/klauspost/compress v1.17.4
|
github.com/klauspost/compress v1.17.4
|
||||||
github.com/minio/minio-go/v7 v7.0.66
|
github.com/minio/minio-go/v7 v7.0.66
|
||||||
|
github.com/refraction-networking/utls v1.8.2
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -118,6 +118,8 @@ github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
|||||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||||
|
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||||
|
|||||||
@@ -33,6 +33,13 @@ var geminiOAuthScopes = []string{
|
|||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
antigravityOAuthClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
|
antigravityOAuthClientSecret = "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"
|
||||||
|
)
|
||||||
|
|
||||||
|
var antigravityOAuthTokenURL = "https://oauth2.googleapis.com/token"
|
||||||
|
|
||||||
type apiCallRequest struct {
|
type apiCallRequest struct {
|
||||||
AuthIndexSnake *string `json:"auth_index"`
|
AuthIndexSnake *string `json:"auth_index"`
|
||||||
AuthIndexCamel *string `json:"authIndex"`
|
AuthIndexCamel *string `json:"authIndex"`
|
||||||
@@ -251,6 +258,10 @@ func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth)
|
|||||||
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
|
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
|
||||||
return token, errToken
|
return token, errToken
|
||||||
}
|
}
|
||||||
|
if provider == "antigravity" {
|
||||||
|
token, errToken := h.refreshAntigravityOAuthAccessToken(ctx, auth)
|
||||||
|
return token, errToken
|
||||||
|
}
|
||||||
|
|
||||||
return tokenValueForAuth(auth), nil
|
return tokenValueForAuth(auth), nil
|
||||||
}
|
}
|
||||||
@@ -325,6 +336,161 @@ func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *corea
|
|||||||
return strings.TrimSpace(currentToken.AccessToken), nil
|
return strings.TrimSpace(currentToken.AccessToken), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if auth == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := auth.Metadata
|
||||||
|
if len(metadata) == 0 {
|
||||||
|
return "", fmt.Errorf("antigravity oauth metadata missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
current := strings.TrimSpace(tokenValueFromMetadata(metadata))
|
||||||
|
if current != "" && !antigravityTokenNeedsRefresh(metadata) {
|
||||||
|
return current, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken := stringValue(metadata, "refresh_token")
|
||||||
|
if refreshToken == "" {
|
||||||
|
return "", fmt.Errorf("antigravity refresh token missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenURL := strings.TrimSpace(antigravityOAuthTokenURL)
|
||||||
|
if tokenURL == "" {
|
||||||
|
tokenURL = "https://oauth2.googleapis.com/token"
|
||||||
|
}
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set("client_id", antigravityOAuthClientID)
|
||||||
|
form.Set("client_secret", antigravityOAuthClientSecret)
|
||||||
|
form.Set("grant_type", "refresh_token")
|
||||||
|
form.Set("refresh_token", refreshToken)
|
||||||
|
|
||||||
|
req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||||
|
if errReq != nil {
|
||||||
|
return "", errReq
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: defaultAPICallTimeout,
|
||||||
|
Transport: h.apiCallTransport(auth),
|
||||||
|
}
|
||||||
|
resp, errDo := httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", errDo
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("response body close error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
return "", errRead
|
||||||
|
}
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return "", fmt.Errorf("antigravity oauth token refresh failed: status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}
|
||||||
|
if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
|
||||||
|
return "", errUnmarshal
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||||
|
return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth.Metadata == nil {
|
||||||
|
auth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
auth.Metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
|
||||||
|
if strings.TrimSpace(tokenResp.RefreshToken) != "" {
|
||||||
|
auth.Metadata["refresh_token"] = strings.TrimSpace(tokenResp.RefreshToken)
|
||||||
|
}
|
||||||
|
if tokenResp.ExpiresIn > 0 {
|
||||||
|
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
|
||||||
|
auth.Metadata["timestamp"] = now.UnixMilli()
|
||||||
|
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
auth.Metadata["type"] = "antigravity"
|
||||||
|
|
||||||
|
if h != nil && h.authManager != nil {
|
||||||
|
auth.LastRefreshedAt = now
|
||||||
|
auth.UpdatedAt = now
|
||||||
|
_, _ = h.authManager.Update(ctx, auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(tokenResp.AccessToken), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityTokenNeedsRefresh(metadata map[string]any) bool {
|
||||||
|
// Refresh a bit early to avoid requests racing token expiry.
|
||||||
|
const skew = 30 * time.Second
|
||||||
|
|
||||||
|
if metadata == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if expStr, ok := metadata["expired"].(string); ok {
|
||||||
|
if ts, errParse := time.Parse(time.RFC3339, strings.TrimSpace(expStr)); errParse == nil {
|
||||||
|
return !ts.After(time.Now().Add(skew))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expiresIn := int64Value(metadata["expires_in"])
|
||||||
|
timestampMs := int64Value(metadata["timestamp"])
|
||||||
|
if expiresIn > 0 && timestampMs > 0 {
|
||||||
|
exp := time.UnixMilli(timestampMs).Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
return !exp.After(time.Now().Add(skew))
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func int64Value(raw any) int64 {
|
||||||
|
switch typed := raw.(type) {
|
||||||
|
case int:
|
||||||
|
return int64(typed)
|
||||||
|
case int32:
|
||||||
|
return int64(typed)
|
||||||
|
case int64:
|
||||||
|
return typed
|
||||||
|
case uint:
|
||||||
|
return int64(typed)
|
||||||
|
case uint32:
|
||||||
|
return int64(typed)
|
||||||
|
case uint64:
|
||||||
|
if typed > uint64(^uint64(0)>>1) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int64(typed)
|
||||||
|
case float32:
|
||||||
|
return int64(typed)
|
||||||
|
case float64:
|
||||||
|
return int64(typed)
|
||||||
|
case json.Number:
|
||||||
|
if i, errParse := typed.Int64(); errParse == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if s := strings.TrimSpace(typed); s != "" {
|
||||||
|
if i, errParse := json.Number(s).Int64(); errParse == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
|
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
173
internal/api/handlers/management/api_tools_test.go
Normal file
173
internal/api/handlers/management/api_tools_test.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memoryAuthStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
items map[string]*coreauth.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
||||||
|
_ = ctx
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||||
|
for _, a := range s.items {
|
||||||
|
out = append(out, a.Clone())
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||||
|
_ = ctx
|
||||||
|
if auth == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.items == nil {
|
||||||
|
s.items = make(map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
|
s.items[auth.ID] = auth.Clone()
|
||||||
|
s.mu.Unlock()
|
||||||
|
return auth.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
||||||
|
_ = ctx
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.items, id)
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
||||||
|
var callCount int
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
callCount++
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Fatalf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
||||||
|
t.Fatalf("unexpected content-type: %s", ct)
|
||||||
|
}
|
||||||
|
bodyBytes, _ := io.ReadAll(r.Body)
|
||||||
|
_ = r.Body.Close()
|
||||||
|
values, err := url.ParseQuery(string(bodyBytes))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse form: %v", err)
|
||||||
|
}
|
||||||
|
if values.Get("grant_type") != "refresh_token" {
|
||||||
|
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
||||||
|
}
|
||||||
|
if values.Get("refresh_token") != "rt" {
|
||||||
|
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
||||||
|
}
|
||||||
|
if values.Get("client_id") != antigravityOAuthClientID {
|
||||||
|
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
||||||
|
}
|
||||||
|
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
||||||
|
t.Fatalf("unexpected client_secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"access_token": "new-token",
|
||||||
|
"refresh_token": "rt2",
|
||||||
|
"expires_in": int64(3600),
|
||||||
|
"token_type": "Bearer",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
originalURL := antigravityOAuthTokenURL
|
||||||
|
antigravityOAuthTokenURL = srv.URL
|
||||||
|
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||||
|
|
||||||
|
store := &memoryAuthStore{}
|
||||||
|
manager := coreauth.NewManager(store, nil, nil)
|
||||||
|
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "antigravity-test.json",
|
||||||
|
FileName: "antigravity-test.json",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "antigravity",
|
||||||
|
"access_token": "old-token",
|
||||||
|
"refresh_token": "rt",
|
||||||
|
"expires_in": int64(3600),
|
||||||
|
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
||||||
|
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("register auth: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &Handler{authManager: manager}
|
||||||
|
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||||
|
}
|
||||||
|
if token != "new-token" {
|
||||||
|
t.Fatalf("expected refreshed token, got %q", token)
|
||||||
|
}
|
||||||
|
if callCount != 1 {
|
||||||
|
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, ok := manager.GetByID(auth.ID)
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth in manager after update")
|
||||||
|
}
|
||||||
|
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
||||||
|
t.Fatalf("expected manager metadata updated, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
||||||
|
var callCount int
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
callCount++
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
originalURL := antigravityOAuthTokenURL
|
||||||
|
antigravityOAuthTokenURL = srv.URL
|
||||||
|
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||||
|
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "antigravity-valid.json",
|
||||||
|
FileName: "antigravity-valid.json",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "antigravity",
|
||||||
|
"access_token": "ok-token",
|
||||||
|
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := &Handler{}
|
||||||
|
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||||
|
}
|
||||||
|
if token != "ok-token" {
|
||||||
|
t.Fatalf("expected existing token, got %q", token)
|
||||||
|
}
|
||||||
|
if callCount != 0 {
|
||||||
|
t.Fatalf("expected no refresh calls, got %d", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,13 +3,14 @@ package management
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -19,6 +20,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||||
@@ -230,14 +232,6 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
|||||||
log.Infof("callback forwarder on port %d stopped", port)
|
log.Infof("callback forwarder on port %d stopped", port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeAntigravityFileName(email string) string {
|
|
||||||
if strings.TrimSpace(email) == "" {
|
|
||||||
return "antigravity.json"
|
|
||||||
}
|
|
||||||
replacer := strings.NewReplacer("@", "_", ".", "_")
|
|
||||||
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
||||||
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
||||||
return "", fmt.Errorf("server port is not configured")
|
return "", fmt.Errorf("server port is not configured")
|
||||||
@@ -460,6 +454,12 @@ func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
|
|||||||
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
|
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
|
||||||
result["plan_type"] = v
|
result["plan_type"] = v
|
||||||
}
|
}
|
||||||
|
if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveStart; v != nil {
|
||||||
|
result["chatgpt_subscription_active_start"] = v
|
||||||
|
}
|
||||||
|
if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveUntil; v != nil {
|
||||||
|
result["chatgpt_subscription_active_until"] = v
|
||||||
|
}
|
||||||
|
|
||||||
if len(result) == 0 {
|
if len(result) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -741,6 +741,72 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PatchAuthFileStatus toggles the disabled state of an auth file
|
||||||
|
func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||||
|
if h.authManager == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Disabled *bool `json:"disabled"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimSpace(req.Name)
|
||||||
|
if name == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Disabled == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "disabled is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
// Find auth by name or ID
|
||||||
|
var targetAuth *coreauth.Auth
|
||||||
|
if auth, ok := h.authManager.GetByID(name); ok {
|
||||||
|
targetAuth = auth
|
||||||
|
} else {
|
||||||
|
auths := h.authManager.List()
|
||||||
|
for _, auth := range auths {
|
||||||
|
if auth.FileName == name {
|
||||||
|
targetAuth = auth
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetAuth == nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update disabled state
|
||||||
|
targetAuth.Disabled = *req.Disabled
|
||||||
|
if *req.Disabled {
|
||||||
|
targetAuth.Status = coreauth.StatusDisabled
|
||||||
|
targetAuth.StatusMessage = "disabled via management API"
|
||||||
|
} else {
|
||||||
|
targetAuth.Status = coreauth.StatusActive
|
||||||
|
targetAuth.StatusMessage = ""
|
||||||
|
}
|
||||||
|
targetAuth.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||||
if h == nil || h.authManager == nil {
|
if h == nil || h.authManager == nil {
|
||||||
return
|
return
|
||||||
@@ -907,67 +973,14 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
rawCode := resultMap["code"]
|
rawCode := resultMap["code"]
|
||||||
code := strings.Split(rawCode, "#")[0]
|
code := strings.Split(rawCode, "#")[0]
|
||||||
|
|
||||||
// Exchange code for tokens (replicate logic using updated redirect_uri)
|
// Exchange code for tokens using internal auth service
|
||||||
// Extract client_id from the modified auth URL
|
bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes)
|
||||||
clientID := ""
|
if errExchange != nil {
|
||||||
if u2, errP := url.Parse(authURL); errP == nil {
|
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange)
|
||||||
clientID = u2.Query().Get("client_id")
|
|
||||||
}
|
|
||||||
// Build request
|
|
||||||
bodyMap := map[string]any{
|
|
||||||
"code": code,
|
|
||||||
"state": state,
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"client_id": clientID,
|
|
||||||
"redirect_uri": "http://localhost:54545/callback",
|
|
||||||
"code_verifier": pkceCodes.CodeVerifier,
|
|
||||||
}
|
|
||||||
bodyJSON, _ := json.Marshal(bodyMap)
|
|
||||||
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("failed to close response body: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var tResp struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
Account struct {
|
|
||||||
EmailAddress string `json:"email_address"`
|
|
||||||
} `json:"account"`
|
|
||||||
}
|
|
||||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bundle := &claude.ClaudeAuthBundle{
|
|
||||||
TokenData: claude.ClaudeTokenData{
|
|
||||||
AccessToken: tResp.AccessToken,
|
|
||||||
RefreshToken: tResp.RefreshToken,
|
|
||||||
Email: tResp.Account.EmailAddress,
|
|
||||||
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create token storage
|
// Create token storage
|
||||||
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
||||||
@@ -1007,17 +1020,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
|
|
||||||
fmt.Println("Initializing Google authentication...")
|
fmt.Println("Initializing Google authentication...")
|
||||||
|
|
||||||
// OAuth2 configuration (mirrors internal/auth/gemini)
|
// OAuth2 configuration using exported constants from internal/auth/gemini
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com",
|
ClientID: geminiAuth.ClientID,
|
||||||
ClientSecret: "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",
|
ClientSecret: geminiAuth.ClientSecret,
|
||||||
RedirectURL: "http://localhost:8085/oauth2callback",
|
RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort),
|
||||||
Scopes: []string{
|
Scopes: geminiAuth.Scopes,
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
Endpoint: google.Endpoint,
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
},
|
|
||||||
Endpoint: google.Endpoint,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build authorization URL and return it immediately
|
// Build authorization URL and return it immediately
|
||||||
@@ -1139,13 +1148,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||||
ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
ifToken["client_id"] = geminiAuth.ClientID
|
||||||
ifToken["client_secret"] = "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"
|
ifToken["client_secret"] = geminiAuth.ClientSecret
|
||||||
ifToken["scopes"] = []string{
|
ifToken["scopes"] = geminiAuth.Scopes
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
}
|
|
||||||
ifToken["universe_domain"] = "googleapis.com"
|
ifToken["universe_domain"] = "googleapis.com"
|
||||||
|
|
||||||
ts := geminiAuth.GeminiTokenStorage{
|
ts := geminiAuth.GeminiTokenStorage{
|
||||||
@@ -1332,74 +1337,34 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("Authorization code received, exchanging for tokens...")
|
log.Debug("Authorization code received, exchanging for tokens...")
|
||||||
// Extract client_id from authURL
|
// Exchange code for tokens using internal auth service
|
||||||
clientID := ""
|
bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes)
|
||||||
if u2, errP := url.Parse(authURL); errP == nil {
|
if errExchange != nil {
|
||||||
clientID = u2.Query().Get("client_id")
|
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange)
|
||||||
}
|
|
||||||
// Exchange code for tokens with redirect equal to mgmtRedirect
|
|
||||||
form := url.Values{
|
|
||||||
"grant_type": {"authorization_code"},
|
|
||||||
"client_id": {clientID},
|
|
||||||
"code": {code},
|
|
||||||
"redirect_uri": {"http://localhost:1455/auth/callback"},
|
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
|
||||||
}
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
|
||||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
// Extract additional info for filename generation
|
||||||
if resp.StatusCode != http.StatusOK {
|
claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken)
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
planType := ""
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
hashAccountID := ""
|
||||||
return
|
|
||||||
}
|
|
||||||
var tokenResp struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
IDToken string `json:"id_token"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
}
|
|
||||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
|
|
||||||
email := ""
|
|
||||||
accountID := ""
|
|
||||||
if claims != nil {
|
if claims != nil {
|
||||||
email = claims.GetUserEmail()
|
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||||
accountID = claims.GetAccountID()
|
if accountID := claims.GetAccountID(); accountID != "" {
|
||||||
}
|
digest := sha256.Sum256([]byte(accountID))
|
||||||
// Build bundle compatible with existing storage
|
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||||
bundle := &codex.CodexAuthBundle{
|
}
|
||||||
TokenData: codex.CodexTokenData{
|
|
||||||
IDToken: tokenResp.IDToken,
|
|
||||||
AccessToken: tokenResp.AccessToken,
|
|
||||||
RefreshToken: tokenResp.RefreshToken,
|
|
||||||
AccountID: accountID,
|
|
||||||
Email: email,
|
|
||||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create token storage and persist
|
// Create token storage and persist
|
||||||
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
|
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
|
||||||
|
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
|
ID: fileName,
|
||||||
Provider: "codex",
|
Provider: "codex",
|
||||||
FileName: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
|
FileName: fileName,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"email": tokenStorage.Email,
|
"email": tokenStorage.Email,
|
||||||
@@ -1425,23 +1390,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||||
const (
|
|
||||||
antigravityCallbackPort = 51121
|
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
|
||||||
antigravityClientSecret = "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"
|
|
||||||
)
|
|
||||||
var antigravityScopes = []string{
|
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
"https://www.googleapis.com/auth/cclog",
|
|
||||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
fmt.Println("Initializing Antigravity authentication...")
|
fmt.Println("Initializing Antigravity authentication...")
|
||||||
|
|
||||||
|
authSvc := antigravity.NewAntigravityAuth(h.cfg, nil)
|
||||||
|
|
||||||
state, errState := misc.GenerateRandomState()
|
state, errState := misc.GenerateRandomState()
|
||||||
if errState != nil {
|
if errState != nil {
|
||||||
log.Errorf("Failed to generate state parameter: %v", errState)
|
log.Errorf("Failed to generate state parameter: %v", errState)
|
||||||
@@ -1449,17 +1403,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
|
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort)
|
||||||
|
authURL := authSvc.BuildAuthURL(state, redirectURI)
|
||||||
params := url.Values{}
|
|
||||||
params.Set("access_type", "offline")
|
|
||||||
params.Set("client_id", antigravityClientID)
|
|
||||||
params.Set("prompt", "consent")
|
|
||||||
params.Set("redirect_uri", redirectURI)
|
|
||||||
params.Set("response_type", "code")
|
|
||||||
params.Set("scope", strings.Join(antigravityScopes, " "))
|
|
||||||
params.Set("state", state)
|
|
||||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
|
||||||
|
|
||||||
RegisterOAuthSession(state, "antigravity")
|
RegisterOAuthSession(state, "antigravity")
|
||||||
|
|
||||||
@@ -1473,7 +1418,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var errStart error
|
var errStart error
|
||||||
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil {
|
||||||
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
return
|
return
|
||||||
@@ -1482,7 +1427,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if isWebUI {
|
if isWebUI {
|
||||||
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
|
defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder)
|
||||||
}
|
}
|
||||||
|
|
||||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
||||||
@@ -1522,93 +1467,36 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI)
|
||||||
form := url.Values{}
|
if errToken != nil {
|
||||||
form.Set("code", authCode)
|
log.Errorf("Failed to exchange token: %v", errToken)
|
||||||
form.Set("client_id", antigravityClientID)
|
|
||||||
form.Set("client_secret", antigravityClientSecret)
|
|
||||||
form.Set("redirect_uri", redirectURI)
|
|
||||||
form.Set("grant_type", "authorization_code")
|
|
||||||
|
|
||||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
|
||||||
if errNewRequest != nil {
|
|
||||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
|
||||||
SetOAuthSessionError(state, "Failed to build token request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
log.Errorf("Failed to execute token request: %v", errDo)
|
|
||||||
SetOAuthSessionError(state, "Failed to exchange token")
|
SetOAuthSessionError(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity token exchange close error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
if accessToken == "" {
|
||||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
log.Error("antigravity: token exchange returned empty access token")
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
SetOAuthSessionError(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenResp struct {
|
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
|
||||||
AccessToken string `json:"access_token"`
|
if errInfo != nil {
|
||||||
RefreshToken string `json:"refresh_token"`
|
log.Errorf("Failed to fetch user info: %v", errInfo)
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
SetOAuthSessionError(state, "Failed to fetch user info")
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
}
|
|
||||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
|
||||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
email := ""
|
if email == "" {
|
||||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
log.Error("antigravity: user info returned empty email")
|
||||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
SetOAuthSessionError(state, "Failed to fetch user info")
|
||||||
if errInfoReq != nil {
|
return
|
||||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
|
||||||
SetOAuthSessionError(state, "Failed to build user info request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
|
||||||
|
|
||||||
infoResp, errInfo := httpClient.Do(infoReq)
|
|
||||||
if errInfo != nil {
|
|
||||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
|
||||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if errClose := infoResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity user info close error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
|
|
||||||
var infoPayload struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
}
|
|
||||||
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
|
|
||||||
email = strings.TrimSpace(infoPayload.Email)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
|
||||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
projectID := ""
|
projectID := ""
|
||||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
if accessToken != "" {
|
||||||
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
|
||||||
if errProject != nil {
|
if errProject != nil {
|
||||||
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||||
} else {
|
} else {
|
||||||
@@ -1633,7 +1521,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
metadata["project_id"] = projectID
|
metadata["project_id"] = projectID
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := sanitizeAntigravityFileName(email)
|
fileName := antigravity.CredentialFileName(email)
|
||||||
label := strings.TrimSpace(email)
|
label := strings.TrimSpace(email)
|
||||||
if label == "" {
|
if label == "" {
|
||||||
label = "antigravity"
|
label = "antigravity"
|
||||||
@@ -1697,7 +1585,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
// Create token storage
|
// Create token storage
|
||||||
tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
|
tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
|
||||||
|
|
||||||
tokenStorage.Email = fmt.Sprintf("qwen-%d", time.Now().UnixMilli())
|
tokenStorage.Email = fmt.Sprintf("%d", time.Now().UnixMilli())
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
|
ID: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
|
||||||
Provider: "qwen",
|
Provider: "qwen",
|
||||||
@@ -1802,7 +1690,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
tokenStorage := authSvc.CreateTokenStorage(tokenData)
|
tokenStorage := authSvc.CreateTokenStorage(tokenData)
|
||||||
identifier := strings.TrimSpace(tokenStorage.Email)
|
identifier := strings.TrimSpace(tokenStorage.Email)
|
||||||
if identifier == "" {
|
if identifier == "" {
|
||||||
identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
|
identifier = fmt.Sprintf("%d", time.Now().UnixMilli())
|
||||||
tokenStorage.Email = identifier
|
tokenStorage.Email = identifier
|
||||||
}
|
}
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
@@ -1887,15 +1775,17 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
|||||||
fileName := iflowauth.SanitizeIFlowFileName(email)
|
fileName := iflowauth.SanitizeIFlowFileName(email)
|
||||||
if fileName == "" {
|
if fileName == "" {
|
||||||
fileName = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
|
fileName = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
|
||||||
|
} else {
|
||||||
|
fileName = fmt.Sprintf("iflow-%s", fileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenStorage.Email = email
|
tokenStorage.Email = email
|
||||||
timestamp := time.Now().Unix()
|
timestamp := time.Now().Unix()
|
||||||
|
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
ID: fmt.Sprintf("%s-%d.json", fileName, timestamp),
|
||||||
Provider: "iflow",
|
Provider: "iflow",
|
||||||
FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
FileName: fmt.Sprintf("%s-%d.json", fileName, timestamp),
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"email": email,
|
"email": email,
|
||||||
@@ -2102,7 +1992,20 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
|||||||
finalProjectID := projectID
|
finalProjectID := projectID
|
||||||
if responseProjectID != "" {
|
if responseProjectID != "" {
|
||||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||||
|
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||||
|
strings.EqualFold(tierID, "FREE") ||
|
||||||
|
strings.EqualFold(tierID, "LEGACY")
|
||||||
|
|
||||||
|
if isFreeUser {
|
||||||
|
// For free users, use backend project ID for preview model access
|
||||||
|
log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
|
||||||
|
log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
|
||||||
|
finalProjectID = responseProjectID
|
||||||
|
} else {
|
||||||
|
// Pro users: keep requested project ID (original behavior)
|
||||||
|
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
finalProjectID = responseProjectID
|
finalProjectID = responseProjectID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -202,6 +202,46 @@ func (h *Handler) PutLoggingToFile(c *gin.Context) {
|
|||||||
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
|
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogsMaxTotalSizeMB
|
||||||
|
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
|
||||||
|
c.JSON(200, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB})
|
||||||
|
}
|
||||||
|
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value *int `json:"value"`
|
||||||
|
}
|
||||||
|
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := *body.Value
|
||||||
|
if value < 0 {
|
||||||
|
value = 0
|
||||||
|
}
|
||||||
|
h.cfg.LogsMaxTotalSizeMB = value
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorLogsMaxFiles
|
||||||
|
func (h *Handler) GetErrorLogsMaxFiles(c *gin.Context) {
|
||||||
|
c.JSON(200, gin.H{"error-logs-max-files": h.cfg.ErrorLogsMaxFiles})
|
||||||
|
}
|
||||||
|
func (h *Handler) PutErrorLogsMaxFiles(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value *int `json:"value"`
|
||||||
|
}
|
||||||
|
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := *body.Value
|
||||||
|
if value < 0 {
|
||||||
|
value = 10
|
||||||
|
}
|
||||||
|
h.cfg.ErrorLogsMaxFiles = value
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
// Request log
|
// Request log
|
||||||
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
||||||
func (h *Handler) PutRequestLog(c *gin.Context) {
|
func (h *Handler) PutRequestLog(c *gin.Context) {
|
||||||
@@ -232,6 +272,52 @@ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
|
|||||||
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
|
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForceModelPrefix
|
||||||
|
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
|
||||||
|
c.JSON(200, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix})
|
||||||
|
}
|
||||||
|
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
|
||||||
|
h.updateBoolField(c, func(v bool) { h.cfg.ForceModelPrefix = v })
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeRoutingStrategy(strategy string) (string, bool) {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(strategy))
|
||||||
|
switch normalized {
|
||||||
|
case "", "round-robin", "roundrobin", "rr":
|
||||||
|
return "round-robin", true
|
||||||
|
case "fill-first", "fillfirst", "ff":
|
||||||
|
return "fill-first", true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoutingStrategy
|
||||||
|
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
|
||||||
|
strategy, ok := normalizeRoutingStrategy(h.cfg.Routing.Strategy)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(200, gin.H{"strategy": strings.TrimSpace(h.cfg.Routing.Strategy)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"strategy": strategy})
|
||||||
|
}
|
||||||
|
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value *string `json:"value"`
|
||||||
|
}
|
||||||
|
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
normalized, ok := normalizeRoutingStrategy(*body.Value)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.cfg.Routing.Strategy = normalized
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
// Proxy URL
|
// Proxy URL
|
||||||
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
||||||
func (h *Handler) PutProxyURL(c *gin.Context) {
|
func (h *Handler) PutProxyURL(c *gin.Context) {
|
||||||
|
|||||||
@@ -487,6 +487,137 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
|||||||
c.JSON(400, gin.H{"error": "missing name or index"})
|
c.JSON(400, gin.H{"error": "missing name or index"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// vertex-api-key: []VertexCompatKey
|
||||||
|
func (h *Handler) GetVertexCompatKeys(c *gin.Context) {
|
||||||
|
c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey})
|
||||||
|
}
|
||||||
|
func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
||||||
|
data, err := c.GetRawData()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var arr []config.VertexCompatKey
|
||||||
|
if err = json.Unmarshal(data, &arr); err != nil {
|
||||||
|
var obj struct {
|
||||||
|
Items []config.VertexCompatKey `json:"items"`
|
||||||
|
}
|
||||||
|
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
arr = obj.Items
|
||||||
|
}
|
||||||
|
for i := range arr {
|
||||||
|
normalizeVertexCompatKey(&arr[i])
|
||||||
|
}
|
||||||
|
h.cfg.VertexCompatAPIKey = arr
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
||||||
|
type vertexCompatPatch struct {
|
||||||
|
APIKey *string `json:"api-key"`
|
||||||
|
Prefix *string `json:"prefix"`
|
||||||
|
BaseURL *string `json:"base-url"`
|
||||||
|
ProxyURL *string `json:"proxy-url"`
|
||||||
|
Headers *map[string]string `json:"headers"`
|
||||||
|
Models *[]config.VertexCompatModel `json:"models"`
|
||||||
|
}
|
||||||
|
var body struct {
|
||||||
|
Index *int `json:"index"`
|
||||||
|
Match *string `json:"match"`
|
||||||
|
Value *vertexCompatPatch `json:"value"`
|
||||||
|
}
|
||||||
|
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
targetIndex := -1
|
||||||
|
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) {
|
||||||
|
targetIndex = *body.Index
|
||||||
|
}
|
||||||
|
if targetIndex == -1 && body.Match != nil {
|
||||||
|
match := strings.TrimSpace(*body.Match)
|
||||||
|
if match != "" {
|
||||||
|
for i := range h.cfg.VertexCompatAPIKey {
|
||||||
|
if h.cfg.VertexCompatAPIKey[i].APIKey == match {
|
||||||
|
targetIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if targetIndex == -1 {
|
||||||
|
c.JSON(404, gin.H{"error": "item not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := h.cfg.VertexCompatAPIKey[targetIndex]
|
||||||
|
if body.Value.APIKey != nil {
|
||||||
|
trimmed := strings.TrimSpace(*body.Value.APIKey)
|
||||||
|
if trimmed == "" {
|
||||||
|
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry.APIKey = trimmed
|
||||||
|
}
|
||||||
|
if body.Value.Prefix != nil {
|
||||||
|
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||||
|
}
|
||||||
|
if body.Value.BaseURL != nil {
|
||||||
|
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||||
|
if trimmed == "" {
|
||||||
|
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry.BaseURL = trimmed
|
||||||
|
}
|
||||||
|
if body.Value.ProxyURL != nil {
|
||||||
|
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||||
|
}
|
||||||
|
if body.Value.Headers != nil {
|
||||||
|
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||||
|
}
|
||||||
|
if body.Value.Models != nil {
|
||||||
|
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
|
||||||
|
}
|
||||||
|
normalizeVertexCompatKey(&entry)
|
||||||
|
h.cfg.VertexCompatAPIKey[targetIndex] = entry
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
||||||
|
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||||
|
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
||||||
|
for _, v := range h.cfg.VertexCompatAPIKey {
|
||||||
|
if v.APIKey != val {
|
||||||
|
out = append(out, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.cfg.VertexCompatAPIKey = out
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if idxStr := c.Query("index"); idxStr != "" {
|
||||||
|
var idx int
|
||||||
|
_, errScan := fmt.Sscanf(idxStr, "%d", &idx)
|
||||||
|
if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) {
|
||||||
|
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...)
|
||||||
|
h.cfg.SanitizeVertexCompatKeys()
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
||||||
|
}
|
||||||
|
|
||||||
// oauth-excluded-models: map[string][]string
|
// oauth-excluded-models: map[string][]string
|
||||||
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
|
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
|
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
|
||||||
@@ -572,6 +703,103 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
|
|||||||
h.persist(c)
|
h.persist(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oauth-model-alias: map[string][]OAuthModelAlias
|
||||||
|
func (h *Handler) GetOAuthModelAlias(c *gin.Context) {
|
||||||
|
c.JSON(200, gin.H{"oauth-model-alias": sanitizedOAuthModelAlias(h.cfg.OAuthModelAlias)})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) PutOAuthModelAlias(c *gin.Context) {
|
||||||
|
data, err := c.GetRawData()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var entries map[string][]config.OAuthModelAlias
|
||||||
|
if err = json.Unmarshal(data, &entries); err != nil {
|
||||||
|
var wrapper struct {
|
||||||
|
Items map[string][]config.OAuthModelAlias `json:"items"`
|
||||||
|
}
|
||||||
|
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entries = wrapper.Items
|
||||||
|
}
|
||||||
|
h.cfg.OAuthModelAlias = sanitizedOAuthModelAlias(entries)
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) PatchOAuthModelAlias(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Provider *string `json:"provider"`
|
||||||
|
Channel *string `json:"channel"`
|
||||||
|
Aliases []config.OAuthModelAlias `json:"aliases"`
|
||||||
|
}
|
||||||
|
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channelRaw := ""
|
||||||
|
if body.Channel != nil {
|
||||||
|
channelRaw = *body.Channel
|
||||||
|
} else if body.Provider != nil {
|
||||||
|
channelRaw = *body.Provider
|
||||||
|
}
|
||||||
|
channel := strings.ToLower(strings.TrimSpace(channelRaw))
|
||||||
|
if channel == "" {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid channel"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedMap := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases})
|
||||||
|
normalized := normalizedMap[channel]
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
if h.cfg.OAuthModelAlias == nil {
|
||||||
|
c.JSON(404, gin.H{"error": "channel not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
|
||||||
|
c.JSON(404, gin.H{"error": "channel not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(h.cfg.OAuthModelAlias, channel)
|
||||||
|
if len(h.cfg.OAuthModelAlias) == 0 {
|
||||||
|
h.cfg.OAuthModelAlias = nil
|
||||||
|
}
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.cfg.OAuthModelAlias == nil {
|
||||||
|
h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias)
|
||||||
|
}
|
||||||
|
h.cfg.OAuthModelAlias[channel] = normalized
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
|
||||||
|
channel := strings.ToLower(strings.TrimSpace(c.Query("channel")))
|
||||||
|
if channel == "" {
|
||||||
|
channel = strings.ToLower(strings.TrimSpace(c.Query("provider")))
|
||||||
|
}
|
||||||
|
if channel == "" {
|
||||||
|
c.JSON(400, gin.H{"error": "missing channel"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.cfg.OAuthModelAlias == nil {
|
||||||
|
c.JSON(404, gin.H{"error": "channel not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
|
||||||
|
c.JSON(404, gin.H{"error": "channel not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(h.cfg.OAuthModelAlias, channel)
|
||||||
|
if len(h.cfg.OAuthModelAlias) == 0 {
|
||||||
|
h.cfg.OAuthModelAlias = nil
|
||||||
|
}
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
// codex-api-key: []CodexKey
|
// codex-api-key: []CodexKey
|
||||||
func (h *Handler) GetCodexKeys(c *gin.Context) {
|
func (h *Handler) GetCodexKeys(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
|
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
|
||||||
@@ -789,6 +1017,53 @@ func normalizeCodexKey(entry *config.CodexKey) {
|
|||||||
entry.Models = normalized
|
entry.Models = normalized
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
|
||||||
|
if entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||||
|
entry.Prefix = strings.TrimSpace(entry.Prefix)
|
||||||
|
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||||
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
|
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||||
|
if len(entry.Models) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
normalized := make([]config.VertexCompatModel, 0, len(entry.Models))
|
||||||
|
for i := range entry.Models {
|
||||||
|
model := entry.Models[i]
|
||||||
|
model.Name = strings.TrimSpace(model.Name)
|
||||||
|
model.Alias = strings.TrimSpace(model.Alias)
|
||||||
|
if model.Name == "" || model.Alias == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalized = append(normalized, model)
|
||||||
|
}
|
||||||
|
entry.Models = normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizedOAuthModelAlias(entries map[string][]config.OAuthModelAlias) map[string][]config.OAuthModelAlias {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
copied := make(map[string][]config.OAuthModelAlias, len(entries))
|
||||||
|
for channel, aliases := range entries {
|
||||||
|
if len(aliases) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copied[channel] = append([]config.OAuthModelAlias(nil), aliases...)
|
||||||
|
}
|
||||||
|
if len(copied) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cfg := config.Config{OAuthModelAlias: copied}
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
if len(cfg.OAuthModelAlias) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cfg.OAuthModelAlias
|
||||||
|
}
|
||||||
|
|
||||||
// GetAmpCode returns the complete ampcode configuration.
|
// GetAmpCode returns the complete ampcode configuration.
|
||||||
func (h *Handler) GetAmpCode(c *gin.Context) {
|
func (h *Handler) GetAmpCode(c *gin.Context) {
|
||||||
if h == nil || h.cfg == nil {
|
if h == nil || h.cfg == nil {
|
||||||
|
|||||||
@@ -24,8 +24,15 @@ import (
|
|||||||
type attemptInfo struct {
|
type attemptInfo struct {
|
||||||
count int
|
count int
|
||||||
blockedUntil time.Time
|
blockedUntil time.Time
|
||||||
|
lastActivity time.Time // track last activity for cleanup
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// attemptCleanupInterval controls how often stale IP entries are purged
|
||||||
|
const attemptCleanupInterval = 1 * time.Hour
|
||||||
|
|
||||||
|
// attemptMaxIdleTime controls how long an IP can be idle before cleanup
|
||||||
|
const attemptMaxIdleTime = 2 * time.Hour
|
||||||
|
|
||||||
// Handler aggregates config reference, persistence path and helpers.
|
// Handler aggregates config reference, persistence path and helpers.
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -47,7 +54,7 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
|||||||
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
|
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
|
||||||
envSecret = strings.TrimSpace(envSecret)
|
envSecret = strings.TrimSpace(envSecret)
|
||||||
|
|
||||||
return &Handler{
|
h := &Handler{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
configFilePath: configFilePath,
|
configFilePath: configFilePath,
|
||||||
failedAttempts: make(map[string]*attemptInfo),
|
failedAttempts: make(map[string]*attemptInfo),
|
||||||
@@ -57,6 +64,38 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
|||||||
allowRemoteOverride: envSecret != "",
|
allowRemoteOverride: envSecret != "",
|
||||||
envSecret: envSecret,
|
envSecret: envSecret,
|
||||||
}
|
}
|
||||||
|
h.startAttemptCleanup()
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// startAttemptCleanup launches a background goroutine that periodically
|
||||||
|
// removes stale IP entries from failedAttempts to prevent memory leaks.
|
||||||
|
func (h *Handler) startAttemptCleanup() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(attemptCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
h.purgeStaleAttempts()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// purgeStaleAttempts removes IP entries that have been idle beyond attemptMaxIdleTime
|
||||||
|
// and whose ban (if any) has expired.
|
||||||
|
func (h *Handler) purgeStaleAttempts() {
|
||||||
|
now := time.Now()
|
||||||
|
h.attemptsMu.Lock()
|
||||||
|
defer h.attemptsMu.Unlock()
|
||||||
|
for ip, ai := range h.failedAttempts {
|
||||||
|
// Skip if still banned
|
||||||
|
if !ai.blockedUntil.IsZero() && now.Before(ai.blockedUntil) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Remove if idle too long
|
||||||
|
if now.Sub(ai.lastActivity) > attemptMaxIdleTime {
|
||||||
|
delete(h.failedAttempts, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new management handler instance.
|
// NewHandler creates a new management handler instance.
|
||||||
@@ -149,6 +188,7 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
|||||||
h.failedAttempts[clientIP] = aip
|
h.failedAttempts[clientIP] = aip
|
||||||
}
|
}
|
||||||
aip.count++
|
aip.count++
|
||||||
|
aip.lastActivity = time.Now()
|
||||||
if aip.count >= maxFailures {
|
if aip.count >= maxFailures {
|
||||||
aip.blockedUntil = time.Now().Add(banDuration)
|
aip.blockedUntil = time.Now().Add(banDuration)
|
||||||
aip.count = 0
|
aip.count = 0
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -360,16 +360,7 @@ func (h *Handler) logDirectory() string {
|
|||||||
if h.logDir != "" {
|
if h.logDir != "" {
|
||||||
return h.logDir
|
return h.logDir
|
||||||
}
|
}
|
||||||
if base := util.WritablePath(); base != "" {
|
return logging.ResolveLogDirectory(h.cfg)
|
||||||
return filepath.Join(base, "logs")
|
|
||||||
}
|
|
||||||
if h.configFilePath != "" {
|
|
||||||
dir := filepath.Dir(h.configFilePath)
|
|
||||||
if dir != "" && dir != "." {
|
|
||||||
return filepath.Join(dir, "logs")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "logs"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) collectLogFiles(dir string) ([]string, error) {
|
func (h *Handler) collectLogFiles(dir string) ([]string, error) {
|
||||||
|
|||||||
33
internal/api/handlers/management/model_definitions.go
Normal file
33
internal/api/handlers/management/model_definitions.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetStaticModelDefinitions returns static model metadata for a given channel.
|
||||||
|
// Channel is provided via path param (:channel) or query param (?channel=...).
|
||||||
|
func (h *Handler) GetStaticModelDefinitions(c *gin.Context) {
|
||||||
|
channel := strings.TrimSpace(c.Param("channel"))
|
||||||
|
if channel == "" {
|
||||||
|
channel = strings.TrimSpace(c.Query("channel"))
|
||||||
|
}
|
||||||
|
if channel == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "channel is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
models := registry.GetStaticModelDefinitionsByChannel(channel)
|
||||||
|
if models == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "unknown channel", "channel": channel})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"channel": strings.ToLower(strings.TrimSpace(channel)),
|
||||||
|
"models": models,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
@@ -103,6 +104,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
|||||||
Headers: headers,
|
Headers: headers,
|
||||||
Body: body,
|
Body: body,
|
||||||
RequestID: logging.GetGinRequestID(c),
|
RequestID: logging.GetGinRequestID(c),
|
||||||
|
Timestamp: time.Now(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
@@ -20,22 +21,24 @@ type RequestInfo struct {
|
|||||||
Headers map[string][]string // Headers contains the request headers.
|
Headers map[string][]string // Headers contains the request headers.
|
||||||
Body []byte // Body is the raw request body.
|
Body []byte // Body is the raw request body.
|
||||||
RequestID string // RequestID is the unique identifier for the request.
|
RequestID string // RequestID is the unique identifier for the request.
|
||||||
|
Timestamp time.Time // Timestamp is when the request was received.
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
||||||
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
|
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
|
||||||
type ResponseWriterWrapper struct {
|
type ResponseWriterWrapper struct {
|
||||||
gin.ResponseWriter
|
gin.ResponseWriter
|
||||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||||
statusCode int // statusCode stores the HTTP status code of the response.
|
statusCode int // statusCode stores the HTTP status code of the response.
|
||||||
headers map[string][]string // headers stores the response headers.
|
headers map[string][]string // headers stores the response headers.
|
||||||
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
||||||
|
firstChunkTimestamp time.Time // firstChunkTimestamp captures TTFB for streaming responses.
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
|
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
|
||||||
@@ -73,6 +76,10 @@ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
|||||||
|
|
||||||
// THEN: Handle logging based on response type
|
// THEN: Handle logging based on response type
|
||||||
if w.isStreaming && w.chunkChannel != nil {
|
if w.isStreaming && w.chunkChannel != nil {
|
||||||
|
// Capture TTFB on first chunk (synchronous, before async channel send)
|
||||||
|
if w.firstChunkTimestamp.IsZero() {
|
||||||
|
w.firstChunkTimestamp = time.Now()
|
||||||
|
}
|
||||||
// For streaming responses: Send to async logging channel (non-blocking)
|
// For streaming responses: Send to async logging channel (non-blocking)
|
||||||
select {
|
select {
|
||||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||||
@@ -117,6 +124,10 @@ func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
|
|||||||
|
|
||||||
// THEN: Capture for logging
|
// THEN: Capture for logging
|
||||||
if w.isStreaming && w.chunkChannel != nil {
|
if w.isStreaming && w.chunkChannel != nil {
|
||||||
|
// Capture TTFB on first chunk (synchronous, before async channel send)
|
||||||
|
if w.firstChunkTimestamp.IsZero() {
|
||||||
|
w.firstChunkTimestamp = time.Now()
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case w.chunkChannel <- []byte(data):
|
case w.chunkChannel <- []byte(data):
|
||||||
default:
|
default:
|
||||||
@@ -280,6 +291,8 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
w.streamDone = nil
|
w.streamDone = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w.streamWriter.SetFirstChunkTimestamp(w.firstChunkTimestamp)
|
||||||
|
|
||||||
// Write API Request and Response to the streaming log before closing
|
// Write API Request and Response to the streaming log before closing
|
||||||
apiRequest := w.extractAPIRequest(c)
|
apiRequest := w.extractAPIRequest(c)
|
||||||
if len(apiRequest) > 0 {
|
if len(apiRequest) > 0 {
|
||||||
@@ -297,7 +310,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||||
@@ -337,7 +350,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
||||||
|
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
||||||
|
if !isExist {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
if t, ok := ts.(time.Time); ok {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||||
if w.requestInfo == nil {
|
if w.requestInfo == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -348,7 +372,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
|||||||
}
|
}
|
||||||
|
|
||||||
if loggerWithOptions, ok := w.logger.(interface {
|
if loggerWithOptions, ok := w.logger.(interface {
|
||||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
|
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||||
}); ok {
|
}); ok {
|
||||||
return loggerWithOptions.LogRequestWithOptions(
|
return loggerWithOptions.LogRequestWithOptions(
|
||||||
w.requestInfo.URL,
|
w.requestInfo.URL,
|
||||||
@@ -363,6 +387,8 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
|||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
forceLog,
|
forceLog,
|
||||||
w.requestInfo.RequestID,
|
w.requestInfo.RequestID,
|
||||||
|
w.requestInfo.Timestamp,
|
||||||
|
apiResponseTimestamp,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -378,5 +404,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
|||||||
apiResponseBody,
|
apiResponseBody,
|
||||||
apiResponseErrors,
|
apiResponseErrors,
|
||||||
w.requestInfo.RequestID,
|
w.requestInfo.RequestID,
|
||||||
|
w.requestInfo.Timestamp,
|
||||||
|
apiResponseTimestamp,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -134,10 +135,11 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Normalize model (handles dynamic thinking suffixes)
|
// Normalize model (handles dynamic thinking suffixes)
|
||||||
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
|
suffixResult := thinking.ParseSuffix(modelName)
|
||||||
|
normalizedModel := suffixResult.ModelName
|
||||||
thinkingSuffix := ""
|
thinkingSuffix := ""
|
||||||
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
|
if suffixResult.HasSuffix {
|
||||||
thinkingSuffix = modelName[len(normalizedModel):]
|
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||||
}
|
}
|
||||||
|
|
||||||
resolveMappedModel := func() (string, []string) {
|
resolveMappedModel := func() (string, []string) {
|
||||||
@@ -157,13 +159,13 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||||
// already specifies its own thinking suffix.
|
// already specifies its own thinking suffix.
|
||||||
if thinkingSuffix != "" {
|
if thinkingSuffix != "" {
|
||||||
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
|
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
|
||||||
if mappedThinkingMetadata == nil {
|
if !mappedSuffixResult.HasSuffix {
|
||||||
mappedModel += thinkingSuffix
|
mappedModel += thinkingSuffix
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
|
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||||
if len(mappedProviders) == 0 {
|
if len(mappedProviders) == 0 {
|
||||||
return "", nil
|
return "", nil
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@@ -44,6 +45,11 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
|||||||
// MapModel checks if a mapping exists for the requested model and if the
|
// MapModel checks if a mapping exists for the requested model and if the
|
||||||
// target model has available local providers. Returns the mapped model name
|
// target model has available local providers. Returns the mapped model name
|
||||||
// or empty string if no valid mapping exists.
|
// or empty string if no valid mapping exists.
|
||||||
|
//
|
||||||
|
// If the requested model contains a thinking suffix (e.g., "g25p(8192)"),
|
||||||
|
// the suffix is preserved in the returned model name (e.g., "gemini-2.5-pro(8192)").
|
||||||
|
// However, if the mapping target already contains a suffix, the config suffix
|
||||||
|
// takes priority over the user's suffix.
|
||||||
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||||
if requestedModel == "" {
|
if requestedModel == "" {
|
||||||
return ""
|
return ""
|
||||||
@@ -52,16 +58,20 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
|||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
// Normalize the requested model for lookup
|
// Extract thinking suffix from requested model using ParseSuffix
|
||||||
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
|
requestResult := thinking.ParseSuffix(requestedModel)
|
||||||
|
baseModel := requestResult.ModelName
|
||||||
|
|
||||||
// Check for direct mapping
|
// Normalize the base model for lookup (case-insensitive)
|
||||||
targetModel, exists := m.mappings[normalizedRequest]
|
normalizedBase := strings.ToLower(strings.TrimSpace(baseModel))
|
||||||
|
|
||||||
|
// Check for direct mapping using base model name
|
||||||
|
targetModel, exists := m.mappings[normalizedBase]
|
||||||
if !exists {
|
if !exists {
|
||||||
// Try regex mappings in order
|
// Try regex mappings in order using base model only
|
||||||
base, _ := util.NormalizeThinkingModel(requestedModel)
|
// (suffix is handled separately via ParseSuffix)
|
||||||
for _, rm := range m.regexps {
|
for _, rm := range m.regexps {
|
||||||
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
|
if rm.re.MatchString(baseModel) {
|
||||||
targetModel = rm.to
|
targetModel = rm.to
|
||||||
exists = true
|
exists = true
|
||||||
break
|
break
|
||||||
@@ -72,14 +82,28 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify target model has available providers
|
// Check if target model already has a thinking suffix (config priority)
|
||||||
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
|
targetResult := thinking.ParseSuffix(targetModel)
|
||||||
providers := util.GetProviderName(normalizedTarget)
|
|
||||||
|
// Verify target model has available providers (use base model for lookup)
|
||||||
|
providers := util.GetProviderName(targetResult.ModelName)
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
|
||||||
|
if targetResult.HasSuffix {
|
||||||
|
// Config's "to" already contains a suffix - use it as-is (config priority)
|
||||||
|
return targetModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve user's thinking suffix on the mapped model
|
||||||
|
// (skip empty suffixes to avoid returning "model()")
|
||||||
|
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||||
|
return targetModel + "(" + requestResult.RawSuffix + ")"
|
||||||
|
}
|
||||||
|
|
||||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||||
return targetModel
|
return targetModel
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -217,10 +217,10 @@ func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
|
|||||||
|
|
||||||
mapper := NewModelMapper(mappings)
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
// Incoming model has reasoning suffix but should match base via regex
|
// Incoming model has reasoning suffix, regex matches base, suffix is preserved
|
||||||
result := mapper.MapModel("gpt-5(high)")
|
result := mapper.MapModel("gpt-5(high)")
|
||||||
if result != "gemini-2.5-pro" {
|
if result != "gemini-2.5-pro(high)" {
|
||||||
t.Errorf("Expected gemini-2.5-pro, got %s", result)
|
t.Errorf("Expected gemini-2.5-pro(high), got %s", result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,3 +281,95 @@ func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
|
|||||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_SuffixPreservation(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
|
||||||
|
// Register test models
|
||||||
|
reg.RegisterClient("test-client-suffix", "gemini", []*registry.ModelInfo{
|
||||||
|
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||||
|
})
|
||||||
|
reg.RegisterClient("test-client-suffix-2", "claude", []*registry.ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("test-client-suffix")
|
||||||
|
defer reg.UnregisterClient("test-client-suffix-2")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mappings []config.AmpModelMapping
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "numeric suffix preserved",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||||
|
input: "g25p(8192)",
|
||||||
|
want: "gemini-2.5-pro(8192)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "level suffix preserved",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||||
|
input: "g25p(high)",
|
||||||
|
want: "gemini-2.5-pro(high)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no suffix unchanged",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||||
|
input: "g25p",
|
||||||
|
want: "gemini-2.5-pro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "config suffix takes priority",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "alias", To: "gemini-2.5-pro(medium)"}},
|
||||||
|
input: "alias(high)",
|
||||||
|
want: "gemini-2.5-pro(medium)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regex with suffix preserved",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "^g25.*", To: "gemini-2.5-pro", Regex: true}},
|
||||||
|
input: "g25p(8192)",
|
||||||
|
want: "gemini-2.5-pro(8192)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auto suffix preserved",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||||
|
input: "g25p(auto)",
|
||||||
|
want: "gemini-2.5-pro(auto)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "none suffix preserved",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||||
|
input: "g25p(none)",
|
||||||
|
want: "gemini-2.5-pro(none)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive base lookup with suffix",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "G25P", To: "gemini-2.5-pro"}},
|
||||||
|
input: "g25p(high)",
|
||||||
|
want: "gemini-2.5-pro(high)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty suffix filtered out",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||||
|
input: "g25p()",
|
||||||
|
want: "gemini-2.5-pro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete suffix treated as no suffix",
|
||||||
|
mappings: []config.AmpModelMapping{{From: "g25p(high", To: "gemini-2.5-pro"}},
|
||||||
|
input: "g25p(high",
|
||||||
|
want: "gemini-2.5-pro",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mapper := NewModelMapper(tt.mappings)
|
||||||
|
got := mapper.MapModel(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("MapModel(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -69,7 +69,30 @@ func (rw *ResponseRewriter) Flush() {
|
|||||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||||
|
|
||||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||||
|
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
|
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
||||||
|
// The Amp client struggles when both thinking and tool_use blocks are present
|
||||||
|
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||||
|
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||||
|
if filtered.Exists() {
|
||||||
|
originalCount := gjson.GetBytes(data, "content.#").Int()
|
||||||
|
filteredCount := filtered.Get("#").Int()
|
||||||
|
|
||||||
|
if originalCount > filteredCount {
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
||||||
|
// Log the result for verification
|
||||||
|
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if rw.originalModel == "" {
|
if rw.originalModel == "" {
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -26,6 +27,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||||
@@ -33,6 +35,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
@@ -57,9 +60,9 @@ type ServerOption func(*serverOptionConfig)
|
|||||||
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
||||||
configDir := filepath.Dir(configPath)
|
configDir := filepath.Dir(configPath)
|
||||||
if base := util.WritablePath(); base != "" {
|
if base := util.WritablePath(); base != "" {
|
||||||
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir)
|
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir, cfg.ErrorLogsMaxFiles)
|
||||||
}
|
}
|
||||||
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir)
|
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir, cfg.ErrorLogsMaxFiles)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMiddleware appends additional Gin middleware during server construction.
|
// WithMiddleware appends additional Gin middleware during server construction.
|
||||||
@@ -253,15 +256,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
}
|
}
|
||||||
managementasset.SetCurrentConfig(cfg)
|
managementasset.SetCurrentConfig(cfg)
|
||||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
|
misc.SetCodexInstructionsEnabled(cfg.CodexInstructionsEnabled)
|
||||||
// Initialize management handler
|
// Initialize management handler
|
||||||
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
||||||
if optionState.localPassword != "" {
|
if optionState.localPassword != "" {
|
||||||
s.mgmt.SetLocalPassword(optionState.localPassword)
|
s.mgmt.SetLocalPassword(optionState.localPassword)
|
||||||
}
|
}
|
||||||
logDir := filepath.Join(s.currentPath, "logs")
|
logDir := logging.ResolveLogDirectory(cfg)
|
||||||
if base := util.WritablePath(); base != "" {
|
|
||||||
logDir = filepath.Join(base, "logs")
|
|
||||||
}
|
|
||||||
s.mgmt.SetLogDirectory(logDir)
|
s.mgmt.SetLogDirectory(logDir)
|
||||||
s.localPassword = optionState.localPassword
|
s.localPassword = optionState.localPassword
|
||||||
|
|
||||||
@@ -325,6 +326,7 @@ func (s *Server) setupRoutes() {
|
|||||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||||
|
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemini compatible API routes
|
// Gemini compatible API routes
|
||||||
@@ -491,6 +493,14 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
|
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
|
||||||
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
|
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
|
||||||
|
|
||||||
|
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
|
||||||
|
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||||
|
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||||
|
|
||||||
|
mgmt.GET("/error-logs-max-files", s.mgmt.GetErrorLogsMaxFiles)
|
||||||
|
mgmt.PUT("/error-logs-max-files", s.mgmt.PutErrorLogsMaxFiles)
|
||||||
|
mgmt.PATCH("/error-logs-max-files", s.mgmt.PutErrorLogsMaxFiles)
|
||||||
|
|
||||||
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
||||||
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||||
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||||
@@ -563,6 +573,14 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||||
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||||
|
|
||||||
|
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
|
||||||
|
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
||||||
|
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
||||||
|
|
||||||
|
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
|
||||||
|
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
||||||
|
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
||||||
|
|
||||||
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
||||||
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
||||||
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
|
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
|
||||||
@@ -578,16 +596,28 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
|
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
|
||||||
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
|
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
|
||||||
|
|
||||||
|
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
|
||||||
|
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
|
||||||
|
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
|
||||||
|
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
|
||||||
|
|
||||||
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
|
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
|
||||||
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
|
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
|
||||||
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
|
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
|
||||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||||
|
|
||||||
|
mgmt.GET("/oauth-model-alias", s.mgmt.GetOAuthModelAlias)
|
||||||
|
mgmt.PUT("/oauth-model-alias", s.mgmt.PutOAuthModelAlias)
|
||||||
|
mgmt.PATCH("/oauth-model-alias", s.mgmt.PatchOAuthModelAlias)
|
||||||
|
mgmt.DELETE("/oauth-model-alias", s.mgmt.DeleteOAuthModelAlias)
|
||||||
|
|
||||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||||
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
||||||
|
mgmt.GET("/model-definitions/:channel", s.mgmt.GetStaticModelDefinitions)
|
||||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||||
|
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||||
|
|
||||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||||
@@ -881,6 +911,15 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.requestLogger != nil && (oldCfg == nil || oldCfg.ErrorLogsMaxFiles != cfg.ErrorLogsMaxFiles) {
|
||||||
|
if setter, ok := s.requestLogger.(interface{ SetErrorLogsMaxFiles(int) }); ok {
|
||||||
|
setter.SetErrorLogsMaxFiles(cfg.ErrorLogsMaxFiles)
|
||||||
|
}
|
||||||
|
if oldCfg != nil {
|
||||||
|
log.Debugf("error_logs_max_files updated from %d to %d", oldCfg.ErrorLogsMaxFiles, cfg.ErrorLogsMaxFiles)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
|
if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
|
||||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
if oldCfg != nil {
|
if oldCfg != nil {
|
||||||
@@ -889,6 +928,16 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
|
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if oldCfg == nil || oldCfg.CodexInstructionsEnabled != cfg.CodexInstructionsEnabled {
|
||||||
|
misc.SetCodexInstructionsEnabled(cfg.CodexInstructionsEnabled)
|
||||||
|
if oldCfg != nil {
|
||||||
|
log.Debugf("codex_instructions_enabled updated from %t to %t", oldCfg.CodexInstructionsEnabled, cfg.CodexInstructionsEnabled)
|
||||||
|
} else {
|
||||||
|
log.Debugf("codex_instructions_enabled toggled to %t", cfg.CodexInstructionsEnabled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||||
}
|
}
|
||||||
@@ -956,18 +1005,25 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify Amp module of config changes (for model mapping hot-reload)
|
// Notify Amp module only when Amp config has changed.
|
||||||
if s.ampModule != nil {
|
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
|
||||||
log.Debugf("triggering amp module config update")
|
if ampConfigChanged {
|
||||||
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
if s.ampModule != nil {
|
||||||
log.Errorf("failed to update Amp module config: %v", err)
|
log.Debugf("triggering amp module config update")
|
||||||
|
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
||||||
|
log.Errorf("failed to update Amp module config: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Warnf("amp module is nil, skipping config update")
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
log.Warnf("amp module is nil, skipping config update")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count client sources from configuration and auth directory
|
// Count client sources from configuration and auth store.
|
||||||
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
tokenStore := sdkAuth.GetTokenStore()
|
||||||
|
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
|
||||||
|
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||||
|
}
|
||||||
|
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
|
||||||
geminiAPIKeyCount := len(cfg.GeminiKey)
|
geminiAPIKeyCount := len(cfg.GeminiKey)
|
||||||
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
||||||
codexAPIKeyCount := len(cfg.CodexKey)
|
codexAPIKeyCount := len(cfg.CodexKey)
|
||||||
@@ -978,10 +1034,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
openAICompatCount += len(entry.APIKeyEntries)
|
openAICompatCount += len(entry.APIKeyEntries)
|
||||||
}
|
}
|
||||||
|
|
||||||
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||||
total,
|
total,
|
||||||
authFiles,
|
authEntries,
|
||||||
geminiAPIKeyCount,
|
geminiAPIKeyCount,
|
||||||
claudeAPIKeyCount,
|
claudeAPIKeyCount,
|
||||||
codexAPIKeyCount,
|
codexAPIKeyCount,
|
||||||
|
|||||||
344
internal/auth/antigravity/auth.go
Normal file
344
internal/auth/antigravity/auth.go
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenResponse represents OAuth token response from Google
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// userInfo represents Google user profile
|
||||||
|
type userInfo struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityAuth handles Antigravity OAuth authentication
|
||||||
|
type AntigravityAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAntigravityAuth creates a new Antigravity auth service.
|
||||||
|
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
|
||||||
|
if httpClient != nil {
|
||||||
|
return &AntigravityAuth{httpClient: httpClient}
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &config.Config{}
|
||||||
|
}
|
||||||
|
return &AntigravityAuth{
|
||||||
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildAuthURL generates the OAuth authorization URL.
|
||||||
|
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
|
||||||
|
if strings.TrimSpace(redirectURI) == "" {
|
||||||
|
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
|
||||||
|
}
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("access_type", "offline")
|
||||||
|
params.Set("client_id", ClientID)
|
||||||
|
params.Set("prompt", "consent")
|
||||||
|
params.Set("redirect_uri", redirectURI)
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Set("scope", strings.Join(Scopes, " "))
|
||||||
|
params.Set("state", state)
|
||||||
|
return AuthEndpoint + "?" + params.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
|
||||||
|
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("code", code)
|
||||||
|
data.Set("client_id", ClientID)
|
||||||
|
data.Set("client_secret", ClientSecret)
|
||||||
|
data.Set("redirect_uri", redirectURI)
|
||||||
|
data.Set("grant_type", "authorization_code")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity token exchange: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||||
|
if errRead != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if body == "" {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token TokenResponse
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUserInfo retrieves user email from Google
|
||||||
|
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||||
|
accessToken = strings.TrimSpace(accessToken)
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: missing access token")
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity userinfo: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if body == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
var info userInfo
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
email := strings.TrimSpace(info.Email)
|
||||||
|
if email == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: response missing email")
|
||||||
|
}
|
||||||
|
return email, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
|
||||||
|
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
|
||||||
|
loadReqBody := map[string]any{
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"ideType": "ANTIGRAVITY",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody, errMarshal := json.Marshal(loadReqBody)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", APIUserAgent)
|
||||||
|
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||||
|
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("read response: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var loadResp map[string]any
|
||||||
|
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract projectID from response
|
||||||
|
projectID := ""
|
||||||
|
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
if projectID == "" {
|
||||||
|
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
||||||
|
if id, okID := projectMap["id"].(string); okID {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID == "" {
|
||||||
|
tierID := "legacy-tier"
|
||||||
|
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||||
|
for _, rawTier := range tiers {
|
||||||
|
tier, okTier := rawTier.(map[string]any)
|
||||||
|
if !okTier {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||||
|
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||||
|
tierID = strings.TrimSpace(id)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
|
||||||
|
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||||
|
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
|
||||||
|
requestBody := map[string]any{
|
||||||
|
"tierId": tierID,
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"ideType": "ANTIGRAVITY",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody, errMarshal := json.Marshal(requestBody)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAttempts := 5
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||||
|
|
||||||
|
reqCtx := ctx
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
if reqCtx == nil {
|
||||||
|
reqCtx = context.Background()
|
||||||
|
}
|
||||||
|
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||||
|
|
||||||
|
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
|
||||||
|
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||||
|
if errRequest != nil {
|
||||||
|
cancel()
|
||||||
|
return "", fmt.Errorf("create request: %w", errRequest)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", APIUserAgent)
|
||||||
|
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||||
|
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
cancel()
|
||||||
|
return "", fmt.Errorf("execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("read response: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
var data map[string]any
|
||||||
|
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if done, okDone := data["done"].(bool); okDone && done {
|
||||||
|
projectID := ""
|
||||||
|
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||||
|
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if id, okID := projectValue["id"].(string); okID {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
projectID = strings.TrimSpace(projectValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID != "" {
|
||||||
|
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("no project_id in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if len(responsePreview) > 500 {
|
||||||
|
responsePreview = responsePreview[:500]
|
||||||
|
}
|
||||||
|
|
||||||
|
responseErr := responsePreview
|
||||||
|
if len(responseErr) > 200 {
|
||||||
|
responseErr = responseErr[:200]
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
34
internal/auth/antigravity/constants.go
Normal file
34
internal/auth/antigravity/constants.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||||
|
package antigravity
|
||||||
|
|
||||||
|
// OAuth client credentials and configuration
|
||||||
|
const (
|
||||||
|
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
|
ClientSecret = "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"
|
||||||
|
CallbackPort = 51121
|
||||||
|
)
|
||||||
|
|
||||||
|
// Scopes defines the OAuth scopes required for Antigravity authentication
|
||||||
|
var Scopes = []string{
|
||||||
|
"https://www.googleapis.com/auth/cloud-platform",
|
||||||
|
"https://www.googleapis.com/auth/userinfo.email",
|
||||||
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
|
"https://www.googleapis.com/auth/cclog",
|
||||||
|
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2 endpoints for Google authentication
|
||||||
|
const (
|
||||||
|
TokenEndpoint = "https://oauth2.googleapis.com/token"
|
||||||
|
AuthEndpoint = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Antigravity API configuration
|
||||||
|
const (
|
||||||
|
APIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||||
|
APIVersion = "v1internal"
|
||||||
|
APIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||||
|
APIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||||
|
ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
|
||||||
|
)
|
||||||
16
internal/auth/antigravity/filename.go
Normal file
16
internal/auth/antigravity/filename.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used to persist Antigravity credentials.
|
||||||
|
// It uses the email as a suffix to disambiguate accounts.
|
||||||
|
func CredentialFileName(email string) string {
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
if email == "" {
|
||||||
|
return "antigravity.json"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("antigravity-%s.json", email)
|
||||||
|
}
|
||||||
@@ -14,15 +14,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for Claude/Anthropic
|
||||||
const (
|
const (
|
||||||
anthropicAuthURL = "https://claude.ai/oauth/authorize"
|
AuthURL = "https://claude.ai/oauth/authorize"
|
||||||
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||||
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||||
redirectURI = "http://localhost:54545/callback"
|
RedirectURI = "http://localhost:54545/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
||||||
@@ -50,7 +50,8 @@ type ClaudeAuth struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClaudeAuth creates a new Anthropic authentication service.
|
// NewClaudeAuth creates a new Anthropic authentication service.
|
||||||
// It initializes the HTTP client with proxy settings from the configuration.
|
// It initializes the HTTP client with a custom TLS transport that uses Firefox
|
||||||
|
// fingerprint to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - cfg: The application configuration containing proxy settings
|
// - cfg: The application configuration containing proxy settings
|
||||||
@@ -58,8 +59,10 @@ type ClaudeAuth struct {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *ClaudeAuth: A new Claude authentication service instance
|
// - *ClaudeAuth: A new Claude authentication service instance
|
||||||
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
||||||
|
// Use custom HTTP client with Firefox TLS fingerprint to bypass
|
||||||
|
// Cloudflare's bot detection on Anthropic domains
|
||||||
return &ClaudeAuth{
|
return &ClaudeAuth{
|
||||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
httpClient: NewAnthropicHttpClient(&cfg.SDKConfig),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,16 +85,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
|||||||
|
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"code": {"true"},
|
"code": {"true"},
|
||||||
"client_id": {anthropicClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"org:create_api_key user:profile user:inference"},
|
"scope": {"org:create_api_key user:profile user:inference"},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
"code_challenge_method": {"S256"},
|
"code_challenge_method": {"S256"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
|
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||||
return authURL, state, nil
|
return authURL, state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,8 +140,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
|||||||
"code": newCode,
|
"code": newCode,
|
||||||
"state": state,
|
"state": state,
|
||||||
"grant_type": "authorization_code",
|
"grant_type": "authorization_code",
|
||||||
"client_id": anthropicClientID,
|
"client_id": ClientID,
|
||||||
"redirect_uri": redirectURI,
|
"redirect_uri": RedirectURI,
|
||||||
"code_verifier": pkceCodes.CodeVerifier,
|
"code_verifier": pkceCodes.CodeVerifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,7 +157,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
|||||||
|
|
||||||
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -221,7 +224,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
|||||||
}
|
}
|
||||||
|
|
||||||
reqBody := map[string]interface{}{
|
reqBody := map[string]interface{}{
|
||||||
"client_id": anthropicClientID,
|
"client_id": ClientID,
|
||||||
"grant_type": "refresh_token",
|
"grant_type": "refresh_token",
|
||||||
"refresh_token": refreshToken,
|
"refresh_token": refreshToken,
|
||||||
}
|
}
|
||||||
@@ -231,7 +234,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
|||||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
165
internal/auth/claude/utls_transport.go
Normal file
165
internal/auth/claude/utls_transport.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
// Package claude provides authentication functionality for Anthropic's Claude API.
|
||||||
|
// This file implements a custom HTTP transport using utls to bypass TLS fingerprinting.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
tls "github.com/refraction-networking/utls"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// utlsRoundTripper implements http.RoundTripper using utls with Firefox fingerprint
|
||||||
|
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||||
|
type utlsRoundTripper struct {
|
||||||
|
// mu protects the connections map and pending map
|
||||||
|
mu sync.Mutex
|
||||||
|
// connections caches HTTP/2 client connections per host
|
||||||
|
connections map[string]*http2.ClientConn
|
||||||
|
// pending tracks hosts that are currently being connected to (prevents race condition)
|
||||||
|
pending map[string]*sync.Cond
|
||||||
|
// dialer is used to create network connections, supporting proxies
|
||||||
|
dialer proxy.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
||||||
|
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||||
|
var dialer proxy.Dialer = proxy.Direct
|
||||||
|
if cfg != nil && cfg.ProxyURL != "" {
|
||||||
|
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
|
||||||
|
} else {
|
||||||
|
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
|
||||||
|
} else {
|
||||||
|
dialer = pDialer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &utlsRoundTripper{
|
||||||
|
connections: make(map[string]*http2.ClientConn),
|
||||||
|
pending: make(map[string]*sync.Cond),
|
||||||
|
dialer: dialer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOrCreateConnection gets an existing connection or creates a new one.
|
||||||
|
// It uses a per-host locking mechanism to prevent multiple goroutines from
|
||||||
|
// creating connections to the same host simultaneously.
|
||||||
|
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||||
|
t.mu.Lock()
|
||||||
|
|
||||||
|
// Check if connection exists and is usable
|
||||||
|
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return h2Conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if another goroutine is already creating a connection
|
||||||
|
if cond, ok := t.pending[host]; ok {
|
||||||
|
// Wait for the other goroutine to finish
|
||||||
|
cond.Wait()
|
||||||
|
// Check if connection is now available
|
||||||
|
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return h2Conn, nil
|
||||||
|
}
|
||||||
|
// Connection still not available, we'll create one
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark this host as pending
|
||||||
|
cond := sync.NewCond(&t.mu)
|
||||||
|
t.pending[host] = cond
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
// Create connection outside the lock
|
||||||
|
h2Conn, err := t.createConnection(host, addr)
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
// Remove pending marker and wake up waiting goroutines
|
||||||
|
delete(t.pending, host)
|
||||||
|
cond.Broadcast()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the new connection
|
||||||
|
t.connections[host] = h2Conn
|
||||||
|
return h2Conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createConnection creates a new HTTP/2 connection with Firefox TLS fingerprint
|
||||||
|
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||||
|
conn, err := t.dialer.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{ServerName: host}
|
||||||
|
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloFirefox_Auto)
|
||||||
|
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tr := &http2.Transport{}
|
||||||
|
h2Conn, err := tr.NewClientConn(tlsConn)
|
||||||
|
if err != nil {
|
||||||
|
tlsConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return h2Conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip implements http.RoundTripper
|
||||||
|
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
host := req.URL.Host
|
||||||
|
addr := host
|
||||||
|
if !strings.Contains(addr, ":") {
|
||||||
|
addr += ":443"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get hostname without port for TLS ServerName
|
||||||
|
hostname := req.URL.Hostname()
|
||||||
|
|
||||||
|
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h2Conn.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
// Connection failed, remove it from cache
|
||||||
|
t.mu.Lock()
|
||||||
|
if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
|
||||||
|
delete(t.connections, hostname)
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
|
||||||
|
// for Anthropic domains by using utls with Firefox fingerprint.
|
||||||
|
// It accepts optional SDK configuration for proxy settings.
|
||||||
|
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: newUtlsRoundTripper(cfg),
|
||||||
|
}
|
||||||
|
}
|
||||||
46
internal/auth/codex/filename.go
Normal file
46
internal/auth/codex/filename.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
|
||||||
|
// When planType is available (e.g. "plus", "team"), it is appended after the email
|
||||||
|
// as a suffix to disambiguate subscriptions.
|
||||||
|
func CredentialFileName(email, planType, hashAccountID string, includeProviderPrefix bool) string {
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
plan := normalizePlanTypeForFilename(planType)
|
||||||
|
|
||||||
|
prefix := ""
|
||||||
|
if includeProviderPrefix {
|
||||||
|
prefix = "codex"
|
||||||
|
}
|
||||||
|
|
||||||
|
if plan == "" {
|
||||||
|
return fmt.Sprintf("%s-%s.json", prefix, email)
|
||||||
|
} else if plan == "team" {
|
||||||
|
return fmt.Sprintf("%s-%s-%s-%s.json", prefix, hashAccountID, email, plan)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s-%s-%s.json", prefix, email, plan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePlanTypeForFilename(planType string) string {
|
||||||
|
planType = strings.TrimSpace(planType)
|
||||||
|
if planType == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.FieldsFunc(planType, func(r rune) bool {
|
||||||
|
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
|
||||||
|
})
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, part := range parts {
|
||||||
|
parts[i] = strings.ToLower(strings.TrimSpace(part))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "-")
|
||||||
|
}
|
||||||
@@ -19,11 +19,12 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for OpenAI Codex
|
||||||
const (
|
const (
|
||||||
openaiAuthURL = "https://auth.openai.com/oauth/authorize"
|
AuthURL = "https://auth.openai.com/oauth/authorize"
|
||||||
openaiTokenURL = "https://auth.openai.com/oauth/token"
|
TokenURL = "https://auth.openai.com/oauth/token"
|
||||||
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
redirectURI = "http://localhost:1455/auth/callback"
|
RedirectURI = "http://localhost:1455/auth/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
||||||
@@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"openid email profile offline_access"},
|
"scope": {"openid email profile offline_access"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
@@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
"codex_cli_simplified_flow": {"true"},
|
"codex_cli_simplified_flow": {"true"},
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
|
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||||
return authURL, nil
|
return authURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,13 +78,13 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce
|
|||||||
// Prepare token exchange request
|
// Prepare token exchange request
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"code": {code},
|
"code": {code},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -163,13 +164,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"grant_type": {"refresh_token"},
|
"grant_type": {"refresh_token"},
|
||||||
"refresh_token": {refreshToken},
|
"refresh_token": {refreshToken},
|
||||||
"scope": {"openid profile email"},
|
"scope": {"openid profile email"},
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,18 +28,19 @@ import (
|
|||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for Gemini
|
||||||
const (
|
const (
|
||||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||||
geminiOauthClientSecret = "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"
|
ClientSecret = "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"
|
||||||
|
DefaultCallbackPort = 8085
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// OAuth scopes for Gemini authentication
|
||||||
geminiOauthScopes = []string{
|
var Scopes = []string{
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
"https://www.googleapis.com/auth/cloud-platform",
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
"https://www.googleapis.com/auth/userinfo.email",
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
||||||
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
||||||
@@ -49,8 +50,9 @@ type GeminiAuth struct {
|
|||||||
|
|
||||||
// WebLoginOptions customizes the interactive OAuth flow.
|
// WebLoginOptions customizes the interactive OAuth flow.
|
||||||
type WebLoginOptions struct {
|
type WebLoginOptions struct {
|
||||||
NoBrowser bool
|
NoBrowser bool
|
||||||
Prompt func(string) (string, error)
|
CallbackPort int
|
||||||
|
Prompt func(string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||||
@@ -72,6 +74,12 @@ func NewGeminiAuth() *GeminiAuth {
|
|||||||
// - *http.Client: An HTTP client configured with authentication
|
// - *http.Client: An HTTP client configured with authentication
|
||||||
// - error: An error if the client configuration fails, nil otherwise
|
// - error: An error if the client configuration fails, nil otherwise
|
||||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||||
|
callbackPort := DefaultCallbackPort
|
||||||
|
if opts != nil && opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||||
|
|
||||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -104,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
|
|
||||||
// Configure the OAuth2 client.
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: geminiOauthClientID,
|
ClientID: ClientID,
|
||||||
ClientSecret: geminiOauthClientSecret,
|
ClientSecret: ClientSecret,
|
||||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
RedirectURL: callbackURL, // This will be used by the local server.
|
||||||
Scopes: geminiOauthScopes,
|
Scopes: Scopes,
|
||||||
Endpoint: google.Endpoint,
|
Endpoint: google.Endpoint,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||||
ifToken["client_id"] = geminiOauthClientID
|
ifToken["client_id"] = ClientID
|
||||||
ifToken["client_secret"] = geminiOauthClientSecret
|
ifToken["client_secret"] = ClientSecret
|
||||||
ifToken["scopes"] = geminiOauthScopes
|
ifToken["scopes"] = Scopes
|
||||||
ifToken["universe_domain"] = "googleapis.com"
|
ifToken["universe_domain"] = "googleapis.com"
|
||||||
|
|
||||||
ts := GeminiTokenStorage{
|
ts := GeminiTokenStorage{
|
||||||
@@ -218,14 +226,20 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
|||||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||||
// - error: An error if the token acquisition fails, nil otherwise
|
// - error: An error if the token acquisition fails, nil otherwise
|
||||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||||
|
callbackPort := DefaultCallbackPort
|
||||||
|
if opts != nil && opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||||
|
|
||||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||||
codeChan := make(chan string, 1)
|
codeChan := make(chan string, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
// Create a new HTTP server with its own multiplexer.
|
// Create a new HTTP server with its own multiplexer.
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
server := &http.Server{Addr: ":8085", Handler: mux}
|
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
|
||||||
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
config.RedirectURL = callbackURL
|
||||||
|
|
||||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.URL.Query().Get("error"); err != "" {
|
if err := r.URL.Query().Get("error"); err != "" {
|
||||||
@@ -277,13 +291,13 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
// Check if browser is available
|
// Check if browser is available
|
||||||
if !browser.IsAvailable() {
|
if !browser.IsAvailable() {
|
||||||
log.Warn("No browser available on this system")
|
log.Warn("No browser available on this system")
|
||||||
util.PrintSSHTunnelInstructions(8085)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||||
} else {
|
} else {
|
||||||
if err := browser.OpenURL(authURL); err != nil {
|
if err := browser.OpenURL(authURL); err != nil {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||||
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||||
util.PrintSSHTunnelInstructions(8085)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||||
|
|
||||||
// Log platform info for debugging
|
// Log platform info for debugging
|
||||||
@@ -294,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
util.PrintSSHTunnelInstructions(8085)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
189
internal/cache/signature_cache.go
vendored
189
internal/cache/signature_cache.go
vendored
@@ -3,7 +3,7 @@ package cache
|
|||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"sort"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -16,23 +16,26 @@ type SignatureEntry struct {
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// SignatureCacheTTL is how long signatures are valid
|
// SignatureCacheTTL is how long signatures are valid
|
||||||
SignatureCacheTTL = 1 * time.Hour
|
SignatureCacheTTL = 3 * time.Hour
|
||||||
|
|
||||||
// MaxEntriesPerSession limits memory usage per session
|
|
||||||
MaxEntriesPerSession = 100
|
|
||||||
|
|
||||||
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
|
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
|
||||||
SignatureTextHashLen = 16
|
SignatureTextHashLen = 16
|
||||||
|
|
||||||
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
||||||
MinValidSignatureLen = 50
|
MinValidSignatureLen = 50
|
||||||
|
|
||||||
|
// CacheCleanupInterval controls how often stale entries are purged
|
||||||
|
CacheCleanupInterval = 10 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
|
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
|
||||||
var signatureCache sync.Map
|
var signatureCache sync.Map
|
||||||
|
|
||||||
// sessionCache is the inner map type
|
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
|
||||||
type sessionCache struct {
|
var cacheCleanupOnce sync.Once
|
||||||
|
|
||||||
|
// groupCache is the inner map type
|
||||||
|
type groupCache struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
entries map[string]SignatureEntry
|
entries map[string]SignatureEntry
|
||||||
}
|
}
|
||||||
@@ -43,122 +46,150 @@ func hashText(text string) string {
|
|||||||
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrCreateSession gets or creates a session cache
|
// getOrCreateGroupCache gets or creates a cache bucket for a model group
|
||||||
func getOrCreateSession(sessionID string) *sessionCache {
|
func getOrCreateGroupCache(groupKey string) *groupCache {
|
||||||
if val, ok := signatureCache.Load(sessionID); ok {
|
// Start background cleanup on first access
|
||||||
return val.(*sessionCache)
|
cacheCleanupOnce.Do(startCacheCleanup)
|
||||||
|
|
||||||
|
if val, ok := signatureCache.Load(groupKey); ok {
|
||||||
|
return val.(*groupCache)
|
||||||
}
|
}
|
||||||
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
|
sc := &groupCache{entries: make(map[string]SignatureEntry)}
|
||||||
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
|
actual, _ := signatureCache.LoadOrStore(groupKey, sc)
|
||||||
return actual.(*sessionCache)
|
return actual.(*groupCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CacheSignature stores a thinking signature for a given session and text.
|
// startCacheCleanup launches a background goroutine that periodically
|
||||||
|
// removes caches where all entries have expired.
|
||||||
|
func startCacheCleanup() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(CacheCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
purgeExpiredCaches()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// purgeExpiredCaches removes caches with no valid (non-expired) entries.
|
||||||
|
func purgeExpiredCaches() {
|
||||||
|
now := time.Now()
|
||||||
|
signatureCache.Range(func(key, value any) bool {
|
||||||
|
sc := value.(*groupCache)
|
||||||
|
sc.mu.Lock()
|
||||||
|
// Remove expired entries
|
||||||
|
for k, entry := range sc.entries {
|
||||||
|
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||||
|
delete(sc.entries, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isEmpty := len(sc.entries) == 0
|
||||||
|
sc.mu.Unlock()
|
||||||
|
// Remove cache bucket if empty
|
||||||
|
if isEmpty {
|
||||||
|
signatureCache.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheSignature stores a thinking signature for a given model group and text.
|
||||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||||
func CacheSignature(sessionID, text, signature string) {
|
func CacheSignature(modelName, text, signature string) {
|
||||||
if sessionID == "" || text == "" || signature == "" {
|
if text == "" || signature == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(signature) < MinValidSignatureLen {
|
if len(signature) < MinValidSignatureLen {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sc := getOrCreateSession(sessionID)
|
groupKey := GetModelGroup(modelName)
|
||||||
textHash := hashText(text)
|
textHash := hashText(text)
|
||||||
|
sc := getOrCreateGroupCache(groupKey)
|
||||||
sc.mu.Lock()
|
sc.mu.Lock()
|
||||||
defer sc.mu.Unlock()
|
defer sc.mu.Unlock()
|
||||||
|
|
||||||
// Evict expired entries if at capacity
|
|
||||||
if len(sc.entries) >= MaxEntriesPerSession {
|
|
||||||
now := time.Now()
|
|
||||||
for key, entry := range sc.entries {
|
|
||||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
|
||||||
delete(sc.entries, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If still at capacity, remove oldest entries
|
|
||||||
if len(sc.entries) >= MaxEntriesPerSession {
|
|
||||||
// Find and remove oldest quarter
|
|
||||||
oldest := make([]struct {
|
|
||||||
key string
|
|
||||||
ts time.Time
|
|
||||||
}, 0, len(sc.entries))
|
|
||||||
for key, entry := range sc.entries {
|
|
||||||
oldest = append(oldest, struct {
|
|
||||||
key string
|
|
||||||
ts time.Time
|
|
||||||
}{key, entry.Timestamp})
|
|
||||||
}
|
|
||||||
// Sort by timestamp (oldest first) using sort.Slice
|
|
||||||
sort.Slice(oldest, func(i, j int) bool {
|
|
||||||
return oldest[i].ts.Before(oldest[j].ts)
|
|
||||||
})
|
|
||||||
|
|
||||||
toRemove := len(oldest) / 4
|
|
||||||
if toRemove < 1 {
|
|
||||||
toRemove = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < toRemove; i++ {
|
|
||||||
delete(sc.entries, oldest[i].key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.entries[textHash] = SignatureEntry{
|
sc.entries[textHash] = SignatureEntry{
|
||||||
Signature: signature,
|
Signature: signature,
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCachedSignature retrieves a cached signature for a given session and text.
|
// GetCachedSignature retrieves a cached signature for a given model group and text.
|
||||||
// Returns empty string if not found or expired.
|
// Returns empty string if not found or expired.
|
||||||
func GetCachedSignature(sessionID, text string) string {
|
func GetCachedSignature(modelName, text string) string {
|
||||||
if sessionID == "" || text == "" {
|
groupKey := GetModelGroup(modelName)
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
val, ok := signatureCache.Load(sessionID)
|
if text == "" {
|
||||||
if !ok {
|
if groupKey == "gemini" {
|
||||||
|
return "skip_thought_signature_validator"
|
||||||
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
sc := val.(*sessionCache)
|
val, ok := signatureCache.Load(groupKey)
|
||||||
|
if !ok {
|
||||||
|
if groupKey == "gemini" {
|
||||||
|
return "skip_thought_signature_validator"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
sc := val.(*groupCache)
|
||||||
|
|
||||||
textHash := hashText(text)
|
textHash := hashText(text)
|
||||||
|
|
||||||
sc.mu.RLock()
|
now := time.Now()
|
||||||
entry, exists := sc.entries[textHash]
|
|
||||||
sc.mu.RUnlock()
|
|
||||||
|
|
||||||
|
sc.mu.Lock()
|
||||||
|
entry, exists := sc.entries[textHash]
|
||||||
if !exists {
|
if !exists {
|
||||||
|
sc.mu.Unlock()
|
||||||
|
if groupKey == "gemini" {
|
||||||
|
return "skip_thought_signature_validator"
|
||||||
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||||
// Check if expired
|
|
||||||
if time.Since(entry.Timestamp) > SignatureCacheTTL {
|
|
||||||
sc.mu.Lock()
|
|
||||||
delete(sc.entries, textHash)
|
delete(sc.entries, textHash)
|
||||||
sc.mu.Unlock()
|
sc.mu.Unlock()
|
||||||
|
if groupKey == "gemini" {
|
||||||
|
return "skip_thought_signature_validator"
|
||||||
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Refresh TTL on access (sliding expiration).
|
||||||
|
entry.Timestamp = now
|
||||||
|
sc.entries[textHash] = entry
|
||||||
|
sc.mu.Unlock()
|
||||||
|
|
||||||
return entry.Signature
|
return entry.Signature
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearSignatureCache clears signature cache for a specific session or all sessions.
|
// ClearSignatureCache clears signature cache for a specific model group or all groups.
|
||||||
func ClearSignatureCache(sessionID string) {
|
func ClearSignatureCache(modelName string) {
|
||||||
if sessionID != "" {
|
if modelName == "" {
|
||||||
signatureCache.Delete(sessionID)
|
|
||||||
} else {
|
|
||||||
signatureCache.Range(func(key, _ any) bool {
|
signatureCache.Range(func(key, _ any) bool {
|
||||||
signatureCache.Delete(key)
|
signatureCache.Delete(key)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
groupKey := GetModelGroup(modelName)
|
||||||
|
signatureCache.Delete(groupKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
||||||
func HasValidSignature(signature string) bool {
|
func HasValidSignature(modelName, signature string) bool {
|
||||||
return signature != "" && len(signature) >= MinValidSignatureLen
|
return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModelGroup(modelName string) string {
|
||||||
|
if strings.Contains(modelName, "gpt") {
|
||||||
|
return "gpt"
|
||||||
|
} else if strings.Contains(modelName, "claude") {
|
||||||
|
return "claude"
|
||||||
|
} else if strings.Contains(modelName, "gemini") {
|
||||||
|
return "gemini"
|
||||||
|
}
|
||||||
|
return modelName
|
||||||
}
|
}
|
||||||
|
|||||||
110
internal/cache/signature_cache_test.go
vendored
110
internal/cache/signature_cache_test.go
vendored
@@ -5,38 +5,40 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const testModelName = "claude-sonnet-4-5"
|
||||||
|
|
||||||
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "test-session-1"
|
|
||||||
text := "This is some thinking text content"
|
text := "This is some thinking text content"
|
||||||
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
|
||||||
// Store signature
|
// Store signature
|
||||||
CacheSignature(sessionID, text, signature)
|
CacheSignature(testModelName, text, signature)
|
||||||
|
|
||||||
// Retrieve signature
|
// Retrieve signature
|
||||||
retrieved := GetCachedSignature(sessionID, text)
|
retrieved := GetCachedSignature(testModelName, text)
|
||||||
if retrieved != signature {
|
if retrieved != signature {
|
||||||
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCacheSignature_DifferentSessions(t *testing.T) {
|
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
text := "Same text in different sessions"
|
text := "Same text across models"
|
||||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||||
|
|
||||||
CacheSignature("session-a", text, sig1)
|
geminiModel := "gemini-3-pro-preview"
|
||||||
CacheSignature("session-b", text, sig2)
|
CacheSignature(testModelName, text, sig1)
|
||||||
|
CacheSignature(geminiModel, text, sig2)
|
||||||
|
|
||||||
if GetCachedSignature("session-a", text) != sig1 {
|
if GetCachedSignature(testModelName, text) != sig1 {
|
||||||
t.Error("Session-a signature mismatch")
|
t.Error("Claude signature mismatch")
|
||||||
}
|
}
|
||||||
if GetCachedSignature("session-b", text) != sig2 {
|
if GetCachedSignature(geminiModel, text) != sig2 {
|
||||||
t.Error("Session-b signature mismatch")
|
t.Error("Gemini signature mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,13 +46,13 @@ func TestCacheSignature_NotFound(t *testing.T) {
|
|||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
// Non-existent session
|
// Non-existent session
|
||||||
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
|
if got := GetCachedSignature(testModelName, "some text"); got != "" {
|
||||||
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Existing session but different text
|
// Existing session but different text
|
||||||
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
CacheSignature(testModelName, "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||||
if got := GetCachedSignature("session-x", "text-b"); got != "" {
|
if got := GetCachedSignature(testModelName, "text-b"); got != "" {
|
||||||
t.Errorf("Expected empty string for different text, got '%s'", got)
|
t.Errorf("Expected empty string for different text, got '%s'", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -59,12 +61,11 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
|||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
// All empty/invalid inputs should be no-ops
|
// All empty/invalid inputs should be no-ops
|
||||||
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
|
CacheSignature(testModelName, "", "sig12345678901234567890123456789012345678901234567890")
|
||||||
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
|
CacheSignature(testModelName, "text", "")
|
||||||
CacheSignature("session", "text", "")
|
CacheSignature(testModelName, "text", "short") // Too short
|
||||||
CacheSignature("session", "text", "short") // Too short
|
|
||||||
|
|
||||||
if got := GetCachedSignature("session", "text"); got != "" {
|
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||||
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -72,31 +73,27 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
|||||||
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "test-short-sig"
|
|
||||||
text := "Some text"
|
text := "Some text"
|
||||||
shortSig := "abc123" // Less than 50 chars
|
shortSig := "abc123" // Less than 50 chars
|
||||||
|
|
||||||
CacheSignature(sessionID, text, shortSig)
|
CacheSignature(testModelName, text, shortSig)
|
||||||
|
|
||||||
if got := GetCachedSignature(sessionID, text); got != "" {
|
if got := GetCachedSignature(testModelName, text); got != "" {
|
||||||
t.Errorf("Short signature should be rejected, got '%s'", got)
|
t.Errorf("Short signature should be rejected, got '%s'", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClearSignatureCache_SpecificSession(t *testing.T) {
|
func TestClearSignatureCache_ModelGroup(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||||
CacheSignature("session-1", "text", sig)
|
CacheSignature(testModelName, "text", sig)
|
||||||
CacheSignature("session-2", "text", sig)
|
CacheSignature(testModelName, "text-2", sig)
|
||||||
|
|
||||||
ClearSignatureCache("session-1")
|
ClearSignatureCache("session-1")
|
||||||
|
|
||||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
if got := GetCachedSignature(testModelName, "text"); got != sig {
|
||||||
t.Error("session-1 should be cleared")
|
t.Error("signature should remain when clearing unknown session")
|
||||||
}
|
|
||||||
if got := GetCachedSignature("session-2", "text"); got != sig {
|
|
||||||
t.Error("session-2 should still exist")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,35 +101,37 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
|
|||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||||
CacheSignature("session-1", "text", sig)
|
CacheSignature(testModelName, "text", sig)
|
||||||
CacheSignature("session-2", "text", sig)
|
CacheSignature(testModelName, "text-2", sig)
|
||||||
|
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||||
t.Error("session-1 should be cleared")
|
t.Error("text should be cleared")
|
||||||
}
|
}
|
||||||
if got := GetCachedSignature("session-2", "text"); got != "" {
|
if got := GetCachedSignature(testModelName, "text-2"); got != "" {
|
||||||
t.Error("session-2 should be cleared")
|
t.Error("text-2 should be cleared")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHasValidSignature(t *testing.T) {
|
func TestHasValidSignature(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
modelName string
|
||||||
signature string
|
signature string
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
|
{"valid long signature", testModelName, "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||||
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
|
{"exactly 50 chars", testModelName, "12345678901234567890123456789012345678901234567890", true},
|
||||||
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
|
{"49 chars - invalid", testModelName, "1234567890123456789012345678901234567890123456789", false},
|
||||||
{"empty string", "", false},
|
{"empty string", testModelName, "", false},
|
||||||
{"short signature", "abc", false},
|
{"short signature", testModelName, "abc", false},
|
||||||
|
{"gemini sentinel", "gemini-3-pro-preview", "skip_thought_signature_validator", true},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := HasValidSignature(tt.signature)
|
result := HasValidSignature(tt.modelName, tt.signature)
|
||||||
if result != tt.expected {
|
if result != tt.expected {
|
||||||
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
||||||
}
|
}
|
||||||
@@ -143,21 +142,19 @@ func TestHasValidSignature(t *testing.T) {
|
|||||||
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "hash-test-session"
|
|
||||||
|
|
||||||
// Different texts should produce different hashes
|
// Different texts should produce different hashes
|
||||||
text1 := "First thinking text"
|
text1 := "First thinking text"
|
||||||
text2 := "Second thinking text"
|
text2 := "Second thinking text"
|
||||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||||
|
|
||||||
CacheSignature(sessionID, text1, sig1)
|
CacheSignature(testModelName, text1, sig1)
|
||||||
CacheSignature(sessionID, text2, sig2)
|
CacheSignature(testModelName, text2, sig2)
|
||||||
|
|
||||||
if GetCachedSignature(sessionID, text1) != sig1 {
|
if GetCachedSignature(testModelName, text1) != sig1 {
|
||||||
t.Error("text1 signature mismatch")
|
t.Error("text1 signature mismatch")
|
||||||
}
|
}
|
||||||
if GetCachedSignature(sessionID, text2) != sig2 {
|
if GetCachedSignature(testModelName, text2) != sig2 {
|
||||||
t.Error("text2 signature mismatch")
|
t.Error("text2 signature mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,13 +162,12 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
|||||||
func TestCacheSignature_UnicodeText(t *testing.T) {
|
func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "unicode-session"
|
|
||||||
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
||||||
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
||||||
|
|
||||||
CacheSignature(sessionID, text, sig)
|
CacheSignature(testModelName, text, sig)
|
||||||
|
|
||||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||||
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -179,15 +175,14 @@ func TestCacheSignature_UnicodeText(t *testing.T) {
|
|||||||
func TestCacheSignature_Overwrite(t *testing.T) {
|
func TestCacheSignature_Overwrite(t *testing.T) {
|
||||||
ClearSignatureCache("")
|
ClearSignatureCache("")
|
||||||
|
|
||||||
sessionID := "overwrite-session"
|
|
||||||
text := "Same text"
|
text := "Same text"
|
||||||
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
||||||
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
||||||
|
|
||||||
CacheSignature(sessionID, text, sig1)
|
CacheSignature(testModelName, text, sig1)
|
||||||
CacheSignature(sessionID, text, sig2) // Overwrite
|
CacheSignature(testModelName, text, sig2) // Overwrite
|
||||||
|
|
||||||
if got := GetCachedSignature(sessionID, text); got != sig2 {
|
if got := GetCachedSignature(testModelName, text); got != sig2 {
|
||||||
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,14 +194,13 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
|||||||
|
|
||||||
// This test verifies the expiration check exists
|
// This test verifies the expiration check exists
|
||||||
// In a real scenario, we'd mock time.Now()
|
// In a real scenario, we'd mock time.Now()
|
||||||
sessionID := "expiration-test"
|
|
||||||
text := "text"
|
text := "text"
|
||||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||||
|
|
||||||
CacheSignature(sessionID, text, sig)
|
CacheSignature(testModelName, text, sig)
|
||||||
|
|
||||||
// Fresh entry should be retrievable
|
// Fresh entry should be retrievable
|
||||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||||
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,10 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
manager := newAuthManager()
|
manager := newAuthManager()
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||||
|
|||||||
@@ -22,9 +22,10 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
|
|
||||||
manager := newAuthManager()
|
manager := newAuthManager()
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||||
|
|||||||
@@ -67,10 +67,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginOpts := &sdkAuth.LoginOptions{
|
loginOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
ProjectID: trimmedProjectID,
|
ProjectID: trimmedProjectID,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: callbackPrompt,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: callbackPrompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
authenticator := sdkAuth.NewGeminiAuthenticator()
|
authenticator := sdkAuth.NewGeminiAuthenticator()
|
||||||
@@ -88,8 +89,9 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
|
|
||||||
geminiAuth := gemini.NewGeminiAuth()
|
geminiAuth := gemini.NewGeminiAuth()
|
||||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Prompt: callbackPrompt,
|
CallbackPort: options.CallbackPort,
|
||||||
|
Prompt: callbackPrompt,
|
||||||
})
|
})
|
||||||
if errClient != nil {
|
if errClient != nil {
|
||||||
log.Errorf("Gemini authentication failed: %v", errClient)
|
log.Errorf("Gemini authentication failed: %v", errClient)
|
||||||
@@ -116,6 +118,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
activatedProjects := make([]string, 0, len(projectSelections))
|
activatedProjects := make([]string, 0, len(projectSelections))
|
||||||
|
seenProjects := make(map[string]bool)
|
||||||
for _, candidateID := range projectSelections {
|
for _, candidateID := range projectSelections {
|
||||||
log.Infof("Activating project %s", candidateID)
|
log.Infof("Activating project %s", candidateID)
|
||||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||||
@@ -132,6 +135,13 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
if finalID == "" {
|
if finalID == "" {
|
||||||
finalID = candidateID
|
finalID = candidateID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip duplicates
|
||||||
|
if seenProjects[finalID] {
|
||||||
|
log.Infof("Project %s already activated, skipping", finalID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenProjects[finalID] = true
|
||||||
activatedProjects = append(activatedProjects, finalID)
|
activatedProjects = append(activatedProjects, finalID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,7 +269,39 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
|||||||
finalProjectID := projectID
|
finalProjectID := projectID
|
||||||
if responseProjectID != "" {
|
if responseProjectID != "" {
|
||||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||||
|
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||||
|
strings.EqualFold(tierID, "FREE") ||
|
||||||
|
strings.EqualFold(tierID, "LEGACY")
|
||||||
|
|
||||||
|
if isFreeUser {
|
||||||
|
// Interactive prompt for free users
|
||||||
|
fmt.Printf("\nGoogle returned a different project ID:\n")
|
||||||
|
fmt.Printf(" Requested (frontend): %s\n", projectID)
|
||||||
|
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
|
||||||
|
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
|
||||||
|
fmt.Printf(" This is normal for free tier users.\n\n")
|
||||||
|
fmt.Printf("Which project ID would you like to use?\n")
|
||||||
|
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
|
||||||
|
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
|
||||||
|
fmt.Printf("Enter choice [1]: ")
|
||||||
|
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
choice, _ := reader.ReadString('\n')
|
||||||
|
choice = strings.TrimSpace(choice)
|
||||||
|
|
||||||
|
if choice == "2" {
|
||||||
|
log.Infof("Using frontend project ID: %s", projectID)
|
||||||
|
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
|
||||||
|
finalProjectID = projectID
|
||||||
|
} else {
|
||||||
|
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
|
||||||
|
finalProjectID = responseProjectID
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Pro users: keep requested project ID (original behavior)
|
||||||
|
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
finalProjectID = responseProjectID
|
finalProjectID = responseProjectID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ type LoginOptions struct {
|
|||||||
// NoBrowser indicates whether to skip opening the browser automatically.
|
// NoBrowser indicates whether to skip opening the browser automatically.
|
||||||
NoBrowser bool
|
NoBrowser bool
|
||||||
|
|
||||||
|
// CallbackPort overrides the local OAuth callback port when set (>0).
|
||||||
|
CallbackPort int
|
||||||
|
|
||||||
// Prompt allows the caller to provide interactive input when needed.
|
// Prompt allows the caller to provide interactive input when needed.
|
||||||
Prompt func(prompt string) (string, error)
|
Prompt func(prompt string) (string, error)
|
||||||
}
|
}
|
||||||
@@ -43,9 +46,10 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
manager := newAuthManager()
|
manager := newAuthManager()
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||||
|
|||||||
@@ -36,9 +36,10 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
@@ -49,6 +51,10 @@ type Config struct {
|
|||||||
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
||||||
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
||||||
|
|
||||||
|
// ErrorLogsMaxFiles limits the number of error log files retained when request logging is disabled.
|
||||||
|
// When exceeded, the oldest error log files are deleted. Default is 10. Set to 0 to disable cleanup.
|
||||||
|
ErrorLogsMaxFiles int `yaml:"error-logs-max-files" json:"error-logs-max-files"`
|
||||||
|
|
||||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||||
|
|
||||||
@@ -69,6 +75,11 @@ type Config struct {
|
|||||||
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
||||||
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
||||||
|
|
||||||
|
// CodexInstructionsEnabled controls whether official Codex instructions are injected.
|
||||||
|
// When false (default), CodexInstructionsForModel returns immediately without modification.
|
||||||
|
// When true, the original instruction injection logic is used.
|
||||||
|
CodexInstructionsEnabled bool `yaml:"codex-instructions-enabled" json:"codex-instructions-enabled"`
|
||||||
|
|
||||||
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
||||||
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
||||||
|
|
||||||
@@ -91,13 +102,13 @@ type Config struct {
|
|||||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||||
|
|
||||||
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
|
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
||||||
// These mappings affect both model listing and model routing for supported channels:
|
// These aliases affect both model listing and model routing for supported channels:
|
||||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||||
//
|
//
|
||||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||||
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
|
OAuthModelAlias map[string][]OAuthModelAlias `yaml:"oauth-model-alias,omitempty" json:"oauth-model-alias,omitempty"`
|
||||||
|
|
||||||
// Payload defines default and override rules for provider payload parameters.
|
// Payload defines default and override rules for provider payload parameters.
|
||||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||||
@@ -145,11 +156,14 @@ type RoutingConfig struct {
|
|||||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelNameMapping defines a model ID rename mapping for a specific channel.
|
// OAuthModelAlias defines a model ID alias for a specific channel.
|
||||||
// It maps the original model name (Name) to the client-visible alias (Alias).
|
// It maps the upstream model name (Name) to the client-visible alias (Alias).
|
||||||
type ModelNameMapping struct {
|
// When Fork is true, the alias is added as an additional model in listings while
|
||||||
|
// keeping the original model ID available.
|
||||||
|
type OAuthModelAlias struct {
|
||||||
Name string `yaml:"name" json:"name"`
|
Name string `yaml:"name" json:"name"`
|
||||||
Alias string `yaml:"alias" json:"alias"`
|
Alias string `yaml:"alias" json:"alias"`
|
||||||
|
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||||
@@ -213,8 +227,22 @@ type AmpUpstreamAPIKeyEntry struct {
|
|||||||
type PayloadConfig struct {
|
type PayloadConfig struct {
|
||||||
// Default defines rules that only set parameters when they are missing in the payload.
|
// Default defines rules that only set parameters when they are missing in the payload.
|
||||||
Default []PayloadRule `yaml:"default" json:"default"`
|
Default []PayloadRule `yaml:"default" json:"default"`
|
||||||
|
// DefaultRaw defines rules that set raw JSON values only when they are missing.
|
||||||
|
DefaultRaw []PayloadRule `yaml:"default-raw" json:"default-raw"`
|
||||||
// Override defines rules that always set parameters, overwriting any existing values.
|
// Override defines rules that always set parameters, overwriting any existing values.
|
||||||
Override []PayloadRule `yaml:"override" json:"override"`
|
Override []PayloadRule `yaml:"override" json:"override"`
|
||||||
|
// OverrideRaw defines rules that always set raw JSON values, overwriting any existing values.
|
||||||
|
OverrideRaw []PayloadRule `yaml:"override-raw" json:"override-raw"`
|
||||||
|
// Filter defines rules that remove parameters from the payload by JSON path.
|
||||||
|
Filter []PayloadFilterRule `yaml:"filter" json:"filter"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayloadFilterRule describes a rule to remove specific JSON paths from matching model payloads.
|
||||||
|
type PayloadFilterRule struct {
|
||||||
|
// Models lists model entries with name pattern and protocol constraint.
|
||||||
|
Models []PayloadModelRule `yaml:"models" json:"models"`
|
||||||
|
// Params lists JSON paths (gjson/sjson syntax) to remove from the payload.
|
||||||
|
Params []string `yaml:"params" json:"params"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PayloadRule describes a single rule targeting a list of models with parameter updates.
|
// PayloadRule describes a single rule targeting a list of models with parameter updates.
|
||||||
@@ -222,6 +250,7 @@ type PayloadRule struct {
|
|||||||
// Models lists model entries with name pattern and protocol constraint.
|
// Models lists model entries with name pattern and protocol constraint.
|
||||||
Models []PayloadModelRule `yaml:"models" json:"models"`
|
Models []PayloadModelRule `yaml:"models" json:"models"`
|
||||||
// Params maps JSON paths (gjson/sjson syntax) to values written into the payload.
|
// Params maps JSON paths (gjson/sjson syntax) to values written into the payload.
|
||||||
|
// For *-raw rules, values are treated as raw JSON fragments (strings are used as-is).
|
||||||
Params map[string]any `yaml:"params" json:"params"`
|
Params map[string]any `yaml:"params" json:"params"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -233,12 +262,35 @@ type PayloadModelRule struct {
|
|||||||
Protocol string `yaml:"protocol" json:"protocol"`
|
Protocol string `yaml:"protocol" json:"protocol"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloakConfig configures request cloaking for non-Claude-Code clients.
|
||||||
|
// Cloaking disguises API requests to appear as originating from the official Claude Code CLI.
|
||||||
|
type CloakConfig struct {
|
||||||
|
// Mode controls cloaking behavior: "auto" (default), "always", or "never".
|
||||||
|
// - "auto": cloak only when client is not Claude Code (based on User-Agent)
|
||||||
|
// - "always": always apply cloaking regardless of client
|
||||||
|
// - "never": never apply cloaking
|
||||||
|
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||||
|
|
||||||
|
// StrictMode controls how system prompts are handled when cloaking.
|
||||||
|
// - false (default): prepend Claude Code prompt to user system messages
|
||||||
|
// - true: strip all user system messages, keep only Claude Code prompt
|
||||||
|
StrictMode bool `yaml:"strict-mode,omitempty" json:"strict-mode,omitempty"`
|
||||||
|
|
||||||
|
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||||
|
// This can help bypass certain content filters.
|
||||||
|
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// ClaudeKey represents the configuration for a Claude API key,
|
// ClaudeKey represents the configuration for a Claude API key,
|
||||||
// including the API key itself and an optional base URL for the API endpoint.
|
// including the API key itself and an optional base URL for the API endpoint.
|
||||||
type ClaudeKey struct {
|
type ClaudeKey struct {
|
||||||
// APIKey is the authentication key for accessing Claude API services.
|
// APIKey is the authentication key for accessing Claude API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -257,8 +309,14 @@ type ClaudeKey struct {
|
|||||||
|
|
||||||
// ExcludedModels lists model IDs that should be excluded for this provider.
|
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||||
|
|
||||||
|
// Cloak configures request cloaking for non-Claude-Code clients.
|
||||||
|
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
||||||
|
func (k ClaudeKey) GetBaseURL() string { return k.BaseURL }
|
||||||
|
|
||||||
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
|
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
|
||||||
type ClaudeModel struct {
|
type ClaudeModel struct {
|
||||||
// Name is the upstream model identifier used when issuing requests.
|
// Name is the upstream model identifier used when issuing requests.
|
||||||
@@ -277,6 +335,10 @@ type CodexKey struct {
|
|||||||
// APIKey is the authentication key for accessing Codex API services.
|
// APIKey is the authentication key for accessing Codex API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -297,6 +359,9 @@ type CodexKey struct {
|
|||||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k CodexKey) GetAPIKey() string { return k.APIKey }
|
||||||
|
func (k CodexKey) GetBaseURL() string { return k.BaseURL }
|
||||||
|
|
||||||
// CodexModel describes a mapping between an alias and the actual upstream model name.
|
// CodexModel describes a mapping between an alias and the actual upstream model name.
|
||||||
type CodexModel struct {
|
type CodexModel struct {
|
||||||
// Name is the upstream model identifier used when issuing requests.
|
// Name is the upstream model identifier used when issuing requests.
|
||||||
@@ -315,6 +380,10 @@ type GeminiKey struct {
|
|||||||
// APIKey is the authentication key for accessing Gemini API services.
|
// APIKey is the authentication key for accessing Gemini API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -334,6 +403,9 @@ type GeminiKey struct {
|
|||||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k GeminiKey) GetAPIKey() string { return k.APIKey }
|
||||||
|
func (k GeminiKey) GetBaseURL() string { return k.BaseURL }
|
||||||
|
|
||||||
// GeminiModel describes a mapping between an alias and the actual upstream model name.
|
// GeminiModel describes a mapping between an alias and the actual upstream model name.
|
||||||
type GeminiModel struct {
|
type GeminiModel struct {
|
||||||
// Name is the upstream model identifier used when issuing requests.
|
// Name is the upstream model identifier used when issuing requests.
|
||||||
@@ -352,6 +424,10 @@ type OpenAICompatibility struct {
|
|||||||
// Name is the identifier for this OpenAI compatibility configuration.
|
// Name is the identifier for this OpenAI compatibility configuration.
|
||||||
Name string `yaml:"name" json:"name"`
|
Name string `yaml:"name" json:"name"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple providers or credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -387,6 +463,9 @@ type OpenAICompatibilityModel struct {
|
|||||||
Alias string `yaml:"alias" json:"alias"`
|
Alias string `yaml:"alias" json:"alias"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
||||||
|
func (m OpenAICompatibilityModel) GetAlias() string { return m.Alias }
|
||||||
|
|
||||||
// LoadConfig reads a YAML configuration file from the given path,
|
// LoadConfig reads a YAML configuration file from the given path,
|
||||||
// unmarshals it into a Config struct, applies environment variable overrides,
|
// unmarshals it into a Config struct, applies environment variable overrides,
|
||||||
// and returns it.
|
// and returns it.
|
||||||
@@ -405,6 +484,15 @@ func LoadConfig(configFile string) (*Config, error) {
|
|||||||
// If optional is true and the file is missing, it returns an empty Config.
|
// If optional is true and the file is missing, it returns an empty Config.
|
||||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||||
|
// Perform oauth-model-alias migration before loading config.
|
||||||
|
// This migrates oauth-model-mappings to oauth-model-alias if needed.
|
||||||
|
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||||
|
// Log warning but don't fail - config loading should still work
|
||||||
|
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||||
|
} else if migrated {
|
||||||
|
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||||
|
}
|
||||||
|
|
||||||
// Read the entire configuration file into memory.
|
// Read the entire configuration file into memory.
|
||||||
data, err := os.ReadFile(configFile)
|
data, err := os.ReadFile(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -428,6 +516,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||||
cfg.LoggingToFile = false
|
cfg.LoggingToFile = false
|
||||||
cfg.LogsMaxTotalSizeMB = 0
|
cfg.LogsMaxTotalSizeMB = 0
|
||||||
|
cfg.ErrorLogsMaxFiles = 10
|
||||||
cfg.UsageStatisticsEnabled = false
|
cfg.UsageStatisticsEnabled = false
|
||||||
cfg.DisableCooling = false
|
cfg.DisableCooling = false
|
||||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||||
@@ -476,6 +565,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.LogsMaxTotalSizeMB = 0
|
cfg.LogsMaxTotalSizeMB = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.ErrorLogsMaxFiles < 0 {
|
||||||
|
cfg.ErrorLogsMaxFiles = 10
|
||||||
|
}
|
||||||
|
|
||||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||||
syncInlineAccessProvider(&cfg)
|
syncInlineAccessProvider(&cfg)
|
||||||
|
|
||||||
@@ -497,8 +590,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Normalize OAuth provider model exclusion map.
|
// Normalize OAuth provider model exclusion map.
|
||||||
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
||||||
|
|
||||||
// Normalize global OAuth model name mappings.
|
// Normalize global OAuth model name aliases.
|
||||||
cfg.SanitizeOAuthModelMappings()
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
// Validate raw payload rules and drop invalid entries.
|
||||||
|
cfg.SanitizePayloadRules()
|
||||||
|
|
||||||
if cfg.legacyMigrationPending {
|
if cfg.legacyMigrationPending {
|
||||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||||
@@ -516,48 +612,97 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
|
// SanitizePayloadRules validates raw JSON payload rule params and drops invalid rules.
|
||||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
func (cfg *Config) SanitizePayloadRules() {
|
||||||
// and ensures (From, To) pairs are unique within each channel.
|
if cfg == nil {
|
||||||
func (cfg *Config) SanitizeOAuthModelMappings() {
|
|
||||||
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
|
cfg.Payload.DefaultRaw = sanitizePayloadRawRules(cfg.Payload.DefaultRaw, "default-raw")
|
||||||
for rawChannel, mappings := range cfg.OAuthModelMappings {
|
cfg.Payload.OverrideRaw = sanitizePayloadRawRules(cfg.Payload.OverrideRaw, "override-raw")
|
||||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
}
|
||||||
if channel == "" || len(mappings) == 0 {
|
|
||||||
|
func sanitizePayloadRawRules(rules []PayloadRule, section string) []PayloadRule {
|
||||||
|
if len(rules) == 0 {
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
out := make([]PayloadRule, 0, len(rules))
|
||||||
|
for i := range rules {
|
||||||
|
rule := rules[i]
|
||||||
|
if len(rule.Params) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seenName := make(map[string]struct{}, len(mappings))
|
invalid := false
|
||||||
seenAlias := make(map[string]struct{}, len(mappings))
|
for path, value := range rule.Params {
|
||||||
clean := make([]ModelNameMapping, 0, len(mappings))
|
raw, ok := payloadRawString(value)
|
||||||
for _, mapping := range mappings {
|
if !ok {
|
||||||
name := strings.TrimSpace(mapping.Name)
|
continue
|
||||||
alias := strings.TrimSpace(mapping.Alias)
|
}
|
||||||
|
trimmed := bytes.TrimSpace(raw)
|
||||||
|
if len(trimmed) == 0 || !json.Valid(trimmed) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"section": section,
|
||||||
|
"rule_index": i + 1,
|
||||||
|
"param": path,
|
||||||
|
}).Warn("payload rule dropped: invalid raw JSON")
|
||||||
|
invalid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if invalid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, rule)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func payloadRawString(value any) ([]byte, bool) {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case string:
|
||||||
|
return []byte(typed), true
|
||||||
|
case []byte:
|
||||||
|
return typed, true
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||||
|
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||||
|
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||||
|
func (cfg *Config) SanitizeOAuthModelAlias() {
|
||||||
|
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
|
||||||
|
for rawChannel, aliases := range cfg.OAuthModelAlias {
|
||||||
|
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||||
|
if channel == "" || len(aliases) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenAlias := make(map[string]struct{}, len(aliases))
|
||||||
|
clean := make([]OAuthModelAlias, 0, len(aliases))
|
||||||
|
for _, entry := range aliases {
|
||||||
|
name := strings.TrimSpace(entry.Name)
|
||||||
|
alias := strings.TrimSpace(entry.Alias)
|
||||||
if name == "" || alias == "" {
|
if name == "" || alias == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.EqualFold(name, alias) {
|
if strings.EqualFold(name, alias) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
nameKey := strings.ToLower(name)
|
|
||||||
aliasKey := strings.ToLower(alias)
|
aliasKey := strings.ToLower(alias)
|
||||||
if _, ok := seenName[nameKey]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ok := seenAlias[aliasKey]; ok {
|
if _, ok := seenAlias[aliasKey]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seenName[nameKey] = struct{}{}
|
|
||||||
seenAlias[aliasKey] = struct{}{}
|
seenAlias[aliasKey] = struct{}{}
|
||||||
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
|
clean = append(clean, OAuthModelAlias{Name: name, Alias: alias, Fork: entry.Fork})
|
||||||
}
|
}
|
||||||
if len(clean) > 0 {
|
if len(clean) > 0 {
|
||||||
out[channel] = clean
|
out[channel] = clean
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cfg.OAuthModelMappings = out
|
cfg.OAuthModelAlias = out
|
||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
|
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
|
||||||
@@ -797,6 +942,7 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
|||||||
removeLegacyGenerativeLanguageKeys(original.Content[0])
|
removeLegacyGenerativeLanguageKeys(original.Content[0])
|
||||||
|
|
||||||
pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-excluded-models")
|
pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-excluded-models")
|
||||||
|
pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-model-alias")
|
||||||
|
|
||||||
// Merge generated into original in-place, preserving comments/order of existing nodes.
|
// Merge generated into original in-place, preserving comments/order of existing nodes.
|
||||||
mergeMappingPreserve(original.Content[0], generated.Content[0])
|
mergeMappingPreserve(original.Content[0], generated.Content[0])
|
||||||
@@ -1287,6 +1433,16 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
|
|||||||
}
|
}
|
||||||
srcIdx := findMapKeyIndex(srcRoot, key)
|
srcIdx := findMapKeyIndex(srcRoot, key)
|
||||||
if srcIdx < 0 {
|
if srcIdx < 0 {
|
||||||
|
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
|
||||||
|
//
|
||||||
|
// Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the
|
||||||
|
// oauth-model-alias key is missing, migration will add the default antigravity aliases.
|
||||||
|
// When users delete the last channel from oauth-model-alias via the management API,
|
||||||
|
// we want that deletion to persist across hot reloads and restarts.
|
||||||
|
if key == "oauth-model-alias" {
|
||||||
|
dstRoot.Content[dstIdx+1] = &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||||
|
return
|
||||||
|
}
|
||||||
removeMapKey(dstRoot, key)
|
removeMapKey(dstRoot, key)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
275
internal/config/oauth_model_alias_migration.go
Normal file
275
internal/config/oauth_model_alias_migration.go
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// antigravityModelConversionTable maps old built-in aliases to actual model names
|
||||||
|
// for the antigravity channel during migration.
|
||||||
|
var antigravityModelConversionTable = map[string]string{
|
||||||
|
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
|
||||||
|
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||||
|
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||||
|
"gemini-3-flash-preview": "gemini-3-flash",
|
||||||
|
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||||
|
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||||
|
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||||
|
// for the antigravity channel when neither field exists.
|
||||||
|
func defaultAntigravityAliases() []OAuthModelAlias {
|
||||||
|
return []OAuthModelAlias{
|
||||||
|
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
|
||||||
|
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
|
||||||
|
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
|
||||||
|
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
|
||||||
|
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
||||||
|
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
||||||
|
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
|
||||||
|
// to oauth-model-alias at startup. Returns true if migration was performed.
|
||||||
|
//
|
||||||
|
// Migration flow:
|
||||||
|
// 1. Check if oauth-model-alias exists -> skip migration
|
||||||
|
// 2. Check if oauth-model-mappings exists -> convert and migrate
|
||||||
|
// - For antigravity channel, convert old built-in aliases to actual model names
|
||||||
|
//
|
||||||
|
// 3. Neither exists -> add default antigravity config
|
||||||
|
func MigrateOAuthModelAlias(configFile string) (bool, error) {
|
||||||
|
data, err := os.ReadFile(configFile)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse YAML into node tree to preserve structure
|
||||||
|
var root yaml.Node
|
||||||
|
if err := yaml.Unmarshal(data, &root); err != nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
rootMap := root.Content[0]
|
||||||
|
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if oauth-model-alias already exists
|
||||||
|
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if oauth-model-mappings exists
|
||||||
|
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
|
||||||
|
if oldIdx >= 0 {
|
||||||
|
// Migrate from old field
|
||||||
|
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Neither field exists - add default antigravity config
|
||||||
|
return addDefaultAntigravityConfig(configFile, &root, rootMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
|
||||||
|
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
|
||||||
|
if oldIdx+1 >= len(rootMap.Content) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
oldValue := rootMap.Content[oldIdx+1]
|
||||||
|
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the old aliases
|
||||||
|
oldAliases := parseOldAliasNode(oldValue)
|
||||||
|
if len(oldAliases) == 0 {
|
||||||
|
// Remove the old field and write
|
||||||
|
removeMapKeyByIndex(rootMap, oldIdx)
|
||||||
|
return writeYAMLNode(configFile, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert model names for antigravity channel
|
||||||
|
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
|
||||||
|
for channel, entries := range oldAliases {
|
||||||
|
converted := make([]OAuthModelAlias, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
newEntry := OAuthModelAlias{
|
||||||
|
Name: entry.Name,
|
||||||
|
Alias: entry.Alias,
|
||||||
|
Fork: entry.Fork,
|
||||||
|
}
|
||||||
|
// Convert model names for antigravity channel
|
||||||
|
if strings.EqualFold(channel, "antigravity") {
|
||||||
|
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
|
||||||
|
newEntry.Name = actual
|
||||||
|
}
|
||||||
|
}
|
||||||
|
converted = append(converted, newEntry)
|
||||||
|
}
|
||||||
|
newAliases[channel] = converted
|
||||||
|
}
|
||||||
|
|
||||||
|
// For antigravity channel, supplement missing default aliases
|
||||||
|
if antigravityEntries, exists := newAliases["antigravity"]; exists {
|
||||||
|
// Build a set of already configured model names (upstream names)
|
||||||
|
configuredModels := make(map[string]bool, len(antigravityEntries))
|
||||||
|
for _, entry := range antigravityEntries {
|
||||||
|
configuredModels[entry.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add missing default aliases
|
||||||
|
for _, defaultAlias := range defaultAntigravityAliases() {
|
||||||
|
if !configuredModels[defaultAlias.Name] {
|
||||||
|
antigravityEntries = append(antigravityEntries, defaultAlias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newAliases["antigravity"] = antigravityEntries
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build new node
|
||||||
|
newNode := buildOAuthModelAliasNode(newAliases)
|
||||||
|
|
||||||
|
// Replace old key with new key and value
|
||||||
|
rootMap.Content[oldIdx].Value = "oauth-model-alias"
|
||||||
|
rootMap.Content[oldIdx+1] = newNode
|
||||||
|
|
||||||
|
return writeYAMLNode(configFile, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addDefaultAntigravityConfig adds the default antigravity configuration
|
||||||
|
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
|
||||||
|
defaults := map[string][]OAuthModelAlias{
|
||||||
|
"antigravity": defaultAntigravityAliases(),
|
||||||
|
}
|
||||||
|
newNode := buildOAuthModelAliasNode(defaults)
|
||||||
|
|
||||||
|
// Add new key-value pair
|
||||||
|
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
|
||||||
|
rootMap.Content = append(rootMap.Content, keyNode, newNode)
|
||||||
|
|
||||||
|
return writeYAMLNode(configFile, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOldAliasNode parses the old oauth-model-mappings node structure
|
||||||
|
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
|
||||||
|
if node == nil || node.Kind != yaml.MappingNode {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make(map[string][]OAuthModelAlias)
|
||||||
|
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||||
|
channelNode := node.Content[i]
|
||||||
|
entriesNode := node.Content[i+1]
|
||||||
|
if channelNode == nil || entriesNode == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
|
||||||
|
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
|
||||||
|
for _, entryNode := range entriesNode.Content {
|
||||||
|
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry := parseAliasEntry(entryNode)
|
||||||
|
if entry.Name != "" && entry.Alias != "" {
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(entries) > 0 {
|
||||||
|
result[channel] = entries
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAliasEntry parses a single alias entry node
|
||||||
|
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
|
||||||
|
var entry OAuthModelAlias
|
||||||
|
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||||
|
keyNode := node.Content[i]
|
||||||
|
valNode := node.Content[i+1]
|
||||||
|
if keyNode == nil || valNode == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
|
||||||
|
case "name":
|
||||||
|
entry.Name = strings.TrimSpace(valNode.Value)
|
||||||
|
case "alias":
|
||||||
|
entry.Alias = strings.TrimSpace(valNode.Value)
|
||||||
|
case "fork":
|
||||||
|
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
|
||||||
|
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
|
||||||
|
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||||
|
for channel, entries := range aliases {
|
||||||
|
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
|
||||||
|
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
|
||||||
|
for _, entry := range entries {
|
||||||
|
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||||
|
entryNode.Content = append(entryNode.Content,
|
||||||
|
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
|
||||||
|
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
|
||||||
|
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
|
||||||
|
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
|
||||||
|
)
|
||||||
|
if entry.Fork {
|
||||||
|
entryNode.Content = append(entryNode.Content,
|
||||||
|
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
|
||||||
|
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
entriesNode.Content = append(entriesNode.Content, entryNode)
|
||||||
|
}
|
||||||
|
node.Content = append(node.Content, channelNode, entriesNode)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
|
||||||
|
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
|
||||||
|
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeYAMLNode writes the YAML node tree back to file
|
||||||
|
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
|
||||||
|
f, err := os.Create(configFile)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
enc := yaml.NewEncoder(f)
|
||||||
|
enc.SetIndent(2)
|
||||||
|
if err := enc.Encode(root); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if err := enc.Close(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
242
internal/config/oauth_model_alias_migration_test.go
Normal file
242
internal/config/oauth_model_alias_migration_test.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
configFile := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
content := `oauth-model-alias:
|
||||||
|
gemini-cli:
|
||||||
|
- name: "gemini-2.5-pro"
|
||||||
|
alias: "g2.5p"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if migrated {
|
||||||
|
t.Fatal("expected no migration when oauth-model-alias already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file unchanged
|
||||||
|
data, _ := os.ReadFile(configFile)
|
||||||
|
if !strings.Contains(string(data), "oauth-model-alias:") {
|
||||||
|
t.Fatal("file should still contain oauth-model-alias")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
configFile := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
content := `oauth-model-mappings:
|
||||||
|
gemini-cli:
|
||||||
|
- name: "gemini-2.5-pro"
|
||||||
|
alias: "g2.5p"
|
||||||
|
fork: true
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !migrated {
|
||||||
|
t.Fatal("expected migration to occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify new field exists and old field removed
|
||||||
|
data, _ := os.ReadFile(configFile)
|
||||||
|
if strings.Contains(string(data), "oauth-model-mappings:") {
|
||||||
|
t.Fatal("old field should be removed")
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(data), "oauth-model-alias:") {
|
||||||
|
t.Fatal("new field should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and verify structure
|
||||||
|
var root yaml.Node
|
||||||
|
if err := yaml.Unmarshal(data, &root); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
configFile := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
// Use old model names that should be converted
|
||||||
|
content := `oauth-model-mappings:
|
||||||
|
antigravity:
|
||||||
|
- name: "gemini-2.5-computer-use-preview-10-2025"
|
||||||
|
alias: "computer-use"
|
||||||
|
- name: "gemini-3-pro-preview"
|
||||||
|
alias: "g3p"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !migrated {
|
||||||
|
t.Fatal("expected migration to occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify model names were converted
|
||||||
|
data, _ := os.ReadFile(configFile)
|
||||||
|
content = string(data)
|
||||||
|
if !strings.Contains(content, "rev19-uic3-1p") {
|
||||||
|
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "gemini-3-pro-high") {
|
||||||
|
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify missing default aliases were supplemented
|
||||||
|
if !strings.Contains(content, "gemini-3-pro-image") {
|
||||||
|
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "gemini-3-flash") {
|
||||||
|
t.Fatal("expected missing default alias gemini-3-flash to be added")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "claude-sonnet-4-5") {
|
||||||
|
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
|
||||||
|
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
||||||
|
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
configFile := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
content := `debug: true
|
||||||
|
port: 8080
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !migrated {
|
||||||
|
t.Fatal("expected migration to add default config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify default antigravity config was added
|
||||||
|
data, _ := os.ReadFile(configFile)
|
||||||
|
content = string(data)
|
||||||
|
if !strings.Contains(content, "oauth-model-alias:") {
|
||||||
|
t.Fatal("expected oauth-model-alias to be added")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "antigravity:") {
|
||||||
|
t.Fatal("expected antigravity channel to be added")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "rev19-uic3-1p") {
|
||||||
|
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
configFile := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
content := `debug: true
|
||||||
|
port: 8080
|
||||||
|
oauth-model-mappings:
|
||||||
|
gemini-cli:
|
||||||
|
- name: "test"
|
||||||
|
alias: "t"
|
||||||
|
api-keys:
|
||||||
|
- "key1"
|
||||||
|
- "key2"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !migrated {
|
||||||
|
t.Fatal("expected migration to occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify other config preserved
|
||||||
|
data, _ := os.ReadFile(configFile)
|
||||||
|
content = string(data)
|
||||||
|
if !strings.Contains(content, "debug: true") {
|
||||||
|
t.Fatal("expected debug field to be preserved")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "port: 8080") {
|
||||||
|
t.Fatal("expected port field to be preserved")
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "api-keys:") {
|
||||||
|
t.Fatal("expected api-keys field to be preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error for nonexistent file: %v", err)
|
||||||
|
}
|
||||||
|
if migrated {
|
||||||
|
t.Fatal("expected no migration for nonexistent file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
configFile := filepath.Join(dir, "config.yaml")
|
||||||
|
|
||||||
|
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if migrated {
|
||||||
|
t.Fatal("expected no migration for empty file")
|
||||||
|
}
|
||||||
|
}
|
||||||
56
internal/config/oauth_model_alias_test.go
Normal file
56
internal/config/oauth_model_alias_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSanitizeOAuthModelAlias_PreservesForkFlag(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||||
|
" CoDeX ": {
|
||||||
|
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
|
||||||
|
{Name: "gpt-6", Alias: "g6"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
aliases := cfg.OAuthModelAlias["codex"]
|
||||||
|
if len(aliases) != 2 {
|
||||||
|
t.Fatalf("expected 2 sanitized aliases, got %d", len(aliases))
|
||||||
|
}
|
||||||
|
if aliases[0].Name != "gpt-5" || aliases[0].Alias != "g5" || !aliases[0].Fork {
|
||||||
|
t.Fatalf("expected first alias to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", aliases[0].Name, aliases[0].Alias, aliases[0].Fork)
|
||||||
|
}
|
||||||
|
if aliases[1].Name != "gpt-6" || aliases[1].Alias != "g6" || aliases[1].Fork {
|
||||||
|
t.Fatalf("expected second alias to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", aliases[1].Name, aliases[1].Alias, aliases[1].Fork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||||
|
"antigravity": {
|
||||||
|
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||||
|
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
|
||||||
|
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
aliases := cfg.OAuthModelAlias["antigravity"]
|
||||||
|
expected := []OAuthModelAlias{
|
||||||
|
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||||
|
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
|
||||||
|
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
|
||||||
|
}
|
||||||
|
if len(aliases) != len(expected) {
|
||||||
|
t.Fatalf("expected %d sanitized aliases, got %d", len(expected), len(aliases))
|
||||||
|
}
|
||||||
|
for i, exp := range expected {
|
||||||
|
if aliases[i].Name != exp.Name || aliases[i].Alias != exp.Alias || aliases[i].Fork != exp.Fork {
|
||||||
|
t.Fatalf("expected alias %d to be name=%q alias=%q fork=%v, got name=%q alias=%q fork=%v", i, exp.Name, exp.Alias, exp.Fork, aliases[i].Name, aliases[i].Alias, aliases[i].Fork)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,6 +25,10 @@ type SDKConfig struct {
|
|||||||
|
|
||||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||||
|
|
||||||
|
// NonStreamKeepAliveInterval controls how often blank lines are emitted for non-streaming responses.
|
||||||
|
// <= 0 disables keep-alives. Value is in seconds.
|
||||||
|
NonStreamKeepAliveInterval int `yaml:"nonstream-keepalive-interval,omitempty" json:"nonstream-keepalive-interval,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamingConfig holds server streaming behavior configuration.
|
// StreamingConfig holds server streaming behavior configuration.
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ type VertexCompatKey struct {
|
|||||||
// Maps to the x-goog-api-key header.
|
// Maps to the x-goog-api-key header.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -32,6 +36,9 @@ type VertexCompatKey struct {
|
|||||||
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
|
||||||
|
func (k VertexCompatKey) GetBaseURL() string { return k.BaseURL }
|
||||||
|
|
||||||
// VertexCompatModel represents a model configuration for Vertex compatibility,
|
// VertexCompatModel represents a model configuration for Vertex compatibility,
|
||||||
// including the actual model name and its alias for API routing.
|
// including the actual model name and its alias for API routing.
|
||||||
type VertexCompatModel struct {
|
type VertexCompatModel struct {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package logging
|
package logging
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
@@ -112,6 +113,11 @@ func isAIAPIPath(path string) bool {
|
|||||||
// - gin.HandlerFunc: A middleware handler for panic recovery
|
// - gin.HandlerFunc: A middleware handler for panic recovery
|
||||||
func GinLogrusRecovery() gin.HandlerFunc {
|
func GinLogrusRecovery() gin.HandlerFunc {
|
||||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||||
|
if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||||
|
// Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs.
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
}
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"panic": recovered,
|
"panic": recovered,
|
||||||
"stack": string(debug.Stack()),
|
"stack": string(debug.Stack()),
|
||||||
|
|||||||
60
internal/logging/gin_logger_test.go
Normal file
60
internal/logging/gin_logger_test.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
engine := gin.New()
|
||||||
|
engine.Use(GinLogrusRecovery())
|
||||||
|
engine.GET("/abort", func(c *gin.Context) {
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/abort", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
recovered := recover()
|
||||||
|
if recovered == nil {
|
||||||
|
t.Fatalf("expected panic, got nil")
|
||||||
|
}
|
||||||
|
err, ok := recovered.(error)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected error panic, got %T", recovered)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, http.ErrAbortHandler) {
|
||||||
|
t.Fatalf("expected ErrAbortHandler, got %v", err)
|
||||||
|
}
|
||||||
|
if err != http.ErrAbortHandler {
|
||||||
|
t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
engine.ServeHTTP(recorder, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
engine := gin.New()
|
||||||
|
engine.Use(GinLogrusRecovery())
|
||||||
|
engine.GET("/panic", func(c *gin.Context) {
|
||||||
|
panic("boom")
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
engine.ServeHTTP(recorder, req)
|
||||||
|
if recorder.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -29,6 +29,9 @@ var (
|
|||||||
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
|
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
|
||||||
type LogFormatter struct{}
|
type LogFormatter struct{}
|
||||||
|
|
||||||
|
// logFieldOrder defines the display order for common log fields.
|
||||||
|
var logFieldOrder = []string{"provider", "model", "mode", "budget", "level", "original_mode", "original_value", "min", "max", "clamped_to", "error"}
|
||||||
|
|
||||||
// Format renders a single log entry with custom formatting.
|
// Format renders a single log entry with custom formatting.
|
||||||
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
||||||
var buffer *bytes.Buffer
|
var buffer *bytes.Buffer
|
||||||
@@ -52,11 +55,25 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
levelStr := fmt.Sprintf("%-5s", level)
|
levelStr := fmt.Sprintf("%-5s", level)
|
||||||
|
|
||||||
|
// Build fields string (only print fields in logFieldOrder)
|
||||||
|
var fieldsStr string
|
||||||
|
if len(entry.Data) > 0 {
|
||||||
|
var fields []string
|
||||||
|
for _, k := range logFieldOrder {
|
||||||
|
if v, ok := entry.Data[k]; ok {
|
||||||
|
fields = append(fields, fmt.Sprintf("%s=%v", k, v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(fields) > 0 {
|
||||||
|
fieldsStr = " " + strings.Join(fields, " ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var formatted string
|
var formatted string
|
||||||
if entry.Caller != nil {
|
if entry.Caller != nil {
|
||||||
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
|
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s%s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message, fieldsStr)
|
||||||
} else {
|
} else {
|
||||||
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
|
formatted = fmt.Sprintf("[%s] [%s] [%s] %s%s\n", timestamp, reqID, levelStr, message, fieldsStr)
|
||||||
}
|
}
|
||||||
buffer.WriteString(formatted)
|
buffer.WriteString(formatted)
|
||||||
|
|
||||||
@@ -104,6 +121,24 @@ func isDirWritable(dir string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveLogDirectory determines the directory used for application logs.
|
||||||
|
func ResolveLogDirectory(cfg *config.Config) string {
|
||||||
|
logDir := "logs"
|
||||||
|
if base := util.WritablePath(); base != "" {
|
||||||
|
return filepath.Join(base, "logs")
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return logDir
|
||||||
|
}
|
||||||
|
if !isDirWritable(logDir) {
|
||||||
|
authDir := strings.TrimSpace(cfg.AuthDir)
|
||||||
|
if authDir != "" {
|
||||||
|
logDir = filepath.Join(authDir, "logs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logDir
|
||||||
|
}
|
||||||
|
|
||||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||||
// until the total size is within the limit.
|
// until the total size is within the limit.
|
||||||
@@ -113,12 +148,7 @@ func ConfigureLogOutput(cfg *config.Config) error {
|
|||||||
writerMu.Lock()
|
writerMu.Lock()
|
||||||
defer writerMu.Unlock()
|
defer writerMu.Unlock()
|
||||||
|
|
||||||
logDir := "logs"
|
logDir := ResolveLogDirectory(cfg)
|
||||||
if base := util.WritablePath(); base != "" {
|
|
||||||
logDir = filepath.Join(base, "logs")
|
|
||||||
} else if !isDirWritable(logDir) {
|
|
||||||
logDir = filepath.Join(cfg.AuthDir, "logs")
|
|
||||||
}
|
|
||||||
|
|
||||||
protectedPath := ""
|
protectedPath := ""
|
||||||
if cfg.LoggingToFile {
|
if cfg.LoggingToFile {
|
||||||
|
|||||||
@@ -44,10 +44,12 @@ type RequestLogger interface {
|
|||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
// - requestID: Optional request ID for log file naming
|
// - requestID: Optional request ID for log file naming
|
||||||
|
// - requestTimestamp: When the request was received
|
||||||
|
// - apiResponseTimestamp: When the API response was received
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
|
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||||
|
|
||||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||||
//
|
//
|
||||||
@@ -109,6 +111,12 @@ type StreamingLogWriter interface {
|
|||||||
// - error: An error if writing fails, nil otherwise
|
// - error: An error if writing fails, nil otherwise
|
||||||
WriteAPIResponse(apiResponse []byte) error
|
WriteAPIResponse(apiResponse []byte) error
|
||||||
|
|
||||||
|
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - timestamp: The time when first response chunk was received
|
||||||
|
SetFirstChunkTimestamp(timestamp time.Time)
|
||||||
|
|
||||||
// Close finalizes the log file and cleans up resources.
|
// Close finalizes the log file and cleans up resources.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
@@ -124,6 +132,9 @@ type FileRequestLogger struct {
|
|||||||
|
|
||||||
// logsDir is the directory where log files are stored.
|
// logsDir is the directory where log files are stored.
|
||||||
logsDir string
|
logsDir string
|
||||||
|
|
||||||
|
// errorLogsMaxFiles limits the number of error log files retained.
|
||||||
|
errorLogsMaxFiles int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFileRequestLogger creates a new file-based request logger.
|
// NewFileRequestLogger creates a new file-based request logger.
|
||||||
@@ -133,10 +144,11 @@ type FileRequestLogger struct {
|
|||||||
// - logsDir: The directory where log files should be stored (can be relative)
|
// - logsDir: The directory where log files should be stored (can be relative)
|
||||||
// - configDir: The directory of the configuration file; when logsDir is
|
// - configDir: The directory of the configuration file; when logsDir is
|
||||||
// relative, it will be resolved relative to this directory
|
// relative, it will be resolved relative to this directory
|
||||||
|
// - errorLogsMaxFiles: Maximum number of error log files to retain (0 = no cleanup)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *FileRequestLogger: A new file-based request logger instance
|
// - *FileRequestLogger: A new file-based request logger instance
|
||||||
func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger {
|
func NewFileRequestLogger(enabled bool, logsDir string, configDir string, errorLogsMaxFiles int) *FileRequestLogger {
|
||||||
// Resolve logsDir relative to the configuration file directory when it's not absolute.
|
// Resolve logsDir relative to the configuration file directory when it's not absolute.
|
||||||
if !filepath.IsAbs(logsDir) {
|
if !filepath.IsAbs(logsDir) {
|
||||||
// If configDir is provided, resolve logsDir relative to it.
|
// If configDir is provided, resolve logsDir relative to it.
|
||||||
@@ -145,8 +157,9 @@ func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileR
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &FileRequestLogger{
|
return &FileRequestLogger{
|
||||||
enabled: enabled,
|
enabled: enabled,
|
||||||
logsDir: logsDir,
|
logsDir: logsDir,
|
||||||
|
errorLogsMaxFiles: errorLogsMaxFiles,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,6 +180,11 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
|||||||
l.enabled = enabled
|
l.enabled = enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetErrorLogsMaxFiles updates the maximum number of error log files to retain.
|
||||||
|
func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
|
||||||
|
l.errorLogsMaxFiles = maxFiles
|
||||||
|
}
|
||||||
|
|
||||||
// LogRequest logs a complete non-streaming request/response cycle to a file.
|
// LogRequest logs a complete non-streaming request/response cycle to a file.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
@@ -180,20 +198,22 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
|||||||
// - apiRequest: The API request data
|
// - apiRequest: The API request data
|
||||||
// - apiResponse: The API response data
|
// - apiResponse: The API response data
|
||||||
// - requestID: Optional request ID for log file naming
|
// - requestID: Optional request ID for log file naming
|
||||||
|
// - requestTimestamp: When the request was received
|
||||||
|
// - apiResponseTimestamp: When the API response was received
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if logging fails, nil otherwise
|
// - error: An error if logging fails, nil otherwise
|
||||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
|
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
|
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||||
if !l.enabled && !force {
|
if !l.enabled && !force {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -247,6 +267,8 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
|||||||
responseHeaders,
|
responseHeaders,
|
||||||
responseToWrite,
|
responseToWrite,
|
||||||
decompressErr,
|
decompressErr,
|
||||||
|
requestTimestamp,
|
||||||
|
apiResponseTimestamp,
|
||||||
)
|
)
|
||||||
if errClose := logFile.Close(); errClose != nil {
|
if errClose := logFile.Close(); errClose != nil {
|
||||||
log.WithError(errClose).Warn("failed to close request log file")
|
log.WithError(errClose).Warn("failed to close request log file")
|
||||||
@@ -421,8 +443,12 @@ func (l *FileRequestLogger) sanitizeForFilename(path string) string {
|
|||||||
return sanitized
|
return sanitized
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupOldErrorLogs keeps only the newest 10 forced error log files.
|
// cleanupOldErrorLogs keeps only the newest errorLogsMaxFiles forced error log files.
|
||||||
func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
||||||
|
if l.errorLogsMaxFiles <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
entries, errRead := os.ReadDir(l.logsDir)
|
entries, errRead := os.ReadDir(l.logsDir)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
return errRead
|
return errRead
|
||||||
@@ -450,7 +476,7 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
|||||||
files = append(files, logFile{name: name, modTime: info.ModTime()})
|
files = append(files, logFile{name: name, modTime: info.ModTime()})
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(files) <= 10 {
|
if len(files) <= l.errorLogsMaxFiles {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,7 +484,7 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
|||||||
return files[i].modTime.After(files[j].modTime)
|
return files[i].modTime.After(files[j].modTime)
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, file := range files[10:] {
|
for _, file := range files[l.errorLogsMaxFiles:] {
|
||||||
if errRemove := os.Remove(filepath.Join(l.logsDir, file.name)); errRemove != nil {
|
if errRemove := os.Remove(filepath.Join(l.logsDir, file.name)); errRemove != nil {
|
||||||
log.WithError(errRemove).Warnf("failed to remove old error log: %s", file.name)
|
log.WithError(errRemove).Warnf("failed to remove old error log: %s", file.name)
|
||||||
}
|
}
|
||||||
@@ -499,17 +525,22 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
|||||||
responseHeaders map[string][]string,
|
responseHeaders map[string][]string,
|
||||||
response []byte,
|
response []byte,
|
||||||
decompressErr error,
|
decompressErr error,
|
||||||
|
requestTimestamp time.Time,
|
||||||
|
apiResponseTimestamp time.Time,
|
||||||
) error {
|
) error {
|
||||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, time.Now()); errWrite != nil {
|
if requestTimestamp.IsZero() {
|
||||||
|
requestTimestamp = time.Now()
|
||||||
|
}
|
||||||
|
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest); errWrite != nil {
|
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil {
|
if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse); errWrite != nil {
|
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||||
@@ -583,7 +614,7 @@ func writeRequestInfoWithBody(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte) error {
|
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -601,6 +632,11 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
|||||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
if !timestamp.IsZero() {
|
||||||
|
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
@@ -974,6 +1010,9 @@ type FileStreamingLogWriter struct {
|
|||||||
|
|
||||||
// apiResponse stores the upstream API response data.
|
// apiResponse stores the upstream API response data.
|
||||||
apiResponse []byte
|
apiResponse []byte
|
||||||
|
|
||||||
|
// apiResponseTimestamp captures when the API response was received.
|
||||||
|
apiResponseTimestamp time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
|
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
|
||||||
@@ -1053,6 +1092,12 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||||
|
if !timestamp.IsZero() {
|
||||||
|
w.apiResponseTimestamp = timestamp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Close finalizes the log file and cleans up resources.
|
// Close finalizes the log file and cleans up resources.
|
||||||
// It writes all buffered data to the file in the correct order:
|
// It writes all buffered data to the file in the correct order:
|
||||||
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||||
@@ -1140,10 +1185,10 @@ func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
|||||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest); errWrite != nil {
|
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse); errWrite != nil {
|
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse, w.apiResponseTimestamp); errWrite != nil {
|
||||||
return errWrite
|
return errWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1220,6 +1265,8 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
||||||
|
|
||||||
// Close is a no-op implementation that does nothing and always returns nil.
|
// Close is a no-op implementation that does nothing and always returns nil.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
|
|||||||
@@ -7,12 +7,93 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// codexInstructionsEnabled controls whether CodexInstructionsForModel returns official instructions.
|
||||||
|
// When false (default), CodexInstructionsForModel returns (true, "") immediately.
|
||||||
|
// Set via SetCodexInstructionsEnabled from config.
|
||||||
|
var codexInstructionsEnabled atomic.Bool
|
||||||
|
|
||||||
|
// SetCodexInstructionsEnabled sets whether codex instructions processing is enabled.
|
||||||
|
func SetCodexInstructionsEnabled(enabled bool) {
|
||||||
|
codexInstructionsEnabled.Store(enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexInstructionsEnabled returns whether codex instructions processing is enabled.
|
||||||
|
func GetCodexInstructionsEnabled() bool {
|
||||||
|
return codexInstructionsEnabled.Load()
|
||||||
|
}
|
||||||
|
|
||||||
//go:embed codex_instructions
|
//go:embed codex_instructions
|
||||||
var codexInstructionsDir embed.FS
|
var codexInstructionsDir embed.FS
|
||||||
|
|
||||||
func CodexInstructionsForModel(modelName, systemInstructions string) (bool, string) {
|
//go:embed opencode_codex_instructions.txt
|
||||||
|
var opencodeCodexInstructions string
|
||||||
|
|
||||||
|
const (
|
||||||
|
codexUserAgentKey = "__cpa_user_agent"
|
||||||
|
userAgentOpenAISDK = "opencode/"
|
||||||
|
)
|
||||||
|
|
||||||
|
func InjectCodexUserAgent(raw []byte, userAgent string) []byte {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(userAgent)
|
||||||
|
if trimmed == "" {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
updated, err := sjson.SetBytes(raw, codexUserAgentKey, trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractCodexUserAgent(raw []byte) string {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(gjson.GetBytes(raw, codexUserAgentKey).String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripCodexUserAgent(raw []byte) []byte {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(raw, codexUserAgentKey).Exists() {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
updated, err := sjson.DeleteBytes(raw, codexUserAgentKey)
|
||||||
|
if err != nil {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
func codexInstructionsForOpenCode(systemInstructions string) (bool, string) {
|
||||||
|
if opencodeCodexInstructions == "" {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(systemInstructions, opencodeCodexInstructions) {
|
||||||
|
return true, ""
|
||||||
|
}
|
||||||
|
return false, opencodeCodexInstructions
|
||||||
|
}
|
||||||
|
|
||||||
|
func useOpenCodeInstructions(userAgent string) bool {
|
||||||
|
return strings.Contains(strings.ToLower(userAgent), userAgentOpenAISDK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsOpenCodeUserAgent(userAgent string) bool {
|
||||||
|
return useOpenCodeInstructions(userAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func codexInstructionsForCodex(modelName, systemInstructions string) (bool, string) {
|
||||||
entries, _ := codexInstructionsDir.ReadDir("codex_instructions")
|
entries, _ := codexInstructionsDir.ReadDir("codex_instructions")
|
||||||
|
|
||||||
lastPrompt := ""
|
lastPrompt := ""
|
||||||
@@ -57,3 +138,13 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
|||||||
return false, lastPrompt
|
return false, lastPrompt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CodexInstructionsForModel(modelName, systemInstructions, userAgent string) (bool, string) {
|
||||||
|
if !GetCodexInstructionsEnabled() {
|
||||||
|
return true, ""
|
||||||
|
}
|
||||||
|
if IsOpenCodeUserAgent(userAgent) {
|
||||||
|
return codexInstructionsForOpenCode(systemInstructions)
|
||||||
|
}
|
||||||
|
return codexInstructionsForCodex(modelName, systemInstructions)
|
||||||
|
}
|
||||||
|
|||||||
79
internal/misc/opencode_codex_instructions.txt
Normal file
79
internal/misc/opencode_codex_instructions.txt
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
You are OpenCode, the best coding agent on the planet.
|
||||||
|
|
||||||
|
You are an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
|
||||||
|
|
||||||
|
## Editing constraints
|
||||||
|
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||||
|
- Only add comments if they are necessary to make a non-obvious block easier to understand.
|
||||||
|
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||||
|
|
||||||
|
## Tool usage
|
||||||
|
- Prefer specialized tools over shell for file operations:
|
||||||
|
- Use Read to view files, Edit to modify files, and Write only when needed.
|
||||||
|
- Use Glob to find files by name and Grep to search file contents.
|
||||||
|
- Use Bash for terminal operations (git, bun, builds, tests, running scripts).
|
||||||
|
- Run tool calls in parallel when neither call needs the other’s output; otherwise run sequentially.
|
||||||
|
|
||||||
|
## Git and workspace hygiene
|
||||||
|
- You may be in a dirty git worktree.
|
||||||
|
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||||
|
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||||
|
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||||
|
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||||
|
- Do not amend commits unless explicitly requested.
|
||||||
|
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||||
|
|
||||||
|
## Frontend tasks
|
||||||
|
When doing frontend design tasks, avoid collapsing into bland, generic layouts.
|
||||||
|
Aim for interfaces that feel intentional and deliberate.
|
||||||
|
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||||
|
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||||
|
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||||
|
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||||
|
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||||
|
- Ensure the page loads properly on both desktop and mobile.
|
||||||
|
|
||||||
|
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||||
|
|
||||||
|
## Presenting your work and final message
|
||||||
|
|
||||||
|
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||||
|
|
||||||
|
- Default: be very concise; friendly coding teammate tone.
|
||||||
|
- Default: do the work without asking questions. Treat short tasks as sufficient direction; infer missing details by reading the codebase and following existing conventions.
|
||||||
|
- Questions: only ask when you are truly blocked after checking relevant context AND you cannot safely pick a reasonable default. This usually means one of:
|
||||||
|
* The request is ambiguous in a way that materially changes the result and you cannot disambiguate by reading the repo.
|
||||||
|
* The action is destructive/irreversible, touches production, or changes billing/security posture.
|
||||||
|
* You need a secret/credential/value that cannot be inferred (API key, account id, etc.).
|
||||||
|
- If you must ask: do all non-blocked work first, then ask exactly one targeted question, include your recommended default, and state what would change based on the answer.
|
||||||
|
- Never ask permission questions like "Should I proceed?" or "Do you want me to run tests?"; proceed with the most reasonable option and mention what you did.
|
||||||
|
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||||
|
- Skip heavy formatting for simple confirmations.
|
||||||
|
- Don't dump large files you've written; reference paths only.
|
||||||
|
- No "save/copy this file" - User is on the same machine.
|
||||||
|
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||||
|
- For code changes:
|
||||||
|
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
|
||||||
|
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||||
|
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||||
|
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||||
|
|
||||||
|
## Final answer structure and style guidelines
|
||||||
|
|
||||||
|
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||||
|
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||||
|
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||||
|
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||||
|
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||||
|
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||||
|
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
|
||||||
|
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||||
|
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||||
|
- File References: When referencing files in your response follow the below rules:
|
||||||
|
* Use inline code to make file paths clickable.
|
||||||
|
* Each reference should have a stand alone path. Even if it's the same file.
|
||||||
|
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||||
|
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||||
|
* Do not use URIs like file://, vscode://, or https://.
|
||||||
|
* Do not provide range of lines
|
||||||
|
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||||
@@ -1,784 +1,69 @@
|
|||||||
// Package registry provides model definitions for various AI service providers.
|
// Package registry provides model definitions and lookup helpers for various AI providers.
|
||||||
// This file contains static model definitions that can be used by clients
|
// Static model metadata is stored in model_definitions_static_data.go.
|
||||||
// when registering their supported models.
|
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
// GetClaudeModels returns the standard Claude model definitions
|
import (
|
||||||
func GetClaudeModels() []*ModelInfo {
|
"sort"
|
||||||
return []*ModelInfo{
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
{
|
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
||||||
ID: "claude-haiku-4-5-20251001",
|
// It returns nil when the channel is unknown.
|
||||||
Object: "model",
|
//
|
||||||
Created: 1759276800, // 2025-10-01
|
// Supported channels:
|
||||||
OwnedBy: "anthropic",
|
// - claude
|
||||||
Type: "claude",
|
// - gemini
|
||||||
DisplayName: "Claude 4.5 Haiku",
|
// - vertex
|
||||||
ContextLength: 200000,
|
// - gemini-cli
|
||||||
MaxCompletionTokens: 64000,
|
// - aistudio
|
||||||
// Thinking: not supported for Haiku models
|
// - codex
|
||||||
},
|
// - qwen
|
||||||
{
|
// - iflow
|
||||||
ID: "claude-sonnet-4-5-20250929",
|
// - antigravity (returns static overrides only)
|
||||||
Object: "model",
|
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||||
Created: 1759104000, // 2025-09-29
|
key := strings.ToLower(strings.TrimSpace(channel))
|
||||||
OwnedBy: "anthropic",
|
switch key {
|
||||||
Type: "claude",
|
case "claude":
|
||||||
DisplayName: "Claude 4.5 Sonnet",
|
return GetClaudeModels()
|
||||||
ContextLength: 200000,
|
case "gemini":
|
||||||
MaxCompletionTokens: 64000,
|
return GetGeminiModels()
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
case "vertex":
|
||||||
},
|
return GetGeminiVertexModels()
|
||||||
{
|
case "gemini-cli":
|
||||||
ID: "claude-opus-4-5-20251101",
|
return GetGeminiCLIModels()
|
||||||
Object: "model",
|
case "aistudio":
|
||||||
Created: 1761955200, // 2025-11-01
|
return GetAIStudioModels()
|
||||||
OwnedBy: "anthropic",
|
case "codex":
|
||||||
Type: "claude",
|
return GetOpenAIModels()
|
||||||
DisplayName: "Claude 4.5 Opus",
|
case "qwen":
|
||||||
Description: "Premium model combining maximum intelligence with practical performance",
|
return GetQwenModels()
|
||||||
ContextLength: 200000,
|
case "iflow":
|
||||||
MaxCompletionTokens: 64000,
|
return GetIFlowModels()
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
case "antigravity":
|
||||||
},
|
cfg := GetAntigravityModelConfig()
|
||||||
{
|
if len(cfg) == 0 {
|
||||||
ID: "claude-opus-4-1-20250805",
|
return nil
|
||||||
Object: "model",
|
}
|
||||||
Created: 1722945600, // 2025-08-05
|
models := make([]*ModelInfo, 0, len(cfg))
|
||||||
OwnedBy: "anthropic",
|
for modelID, entry := range cfg {
|
||||||
Type: "claude",
|
if modelID == "" || entry == nil {
|
||||||
DisplayName: "Claude 4.1 Opus",
|
continue
|
||||||
ContextLength: 200000,
|
}
|
||||||
MaxCompletionTokens: 32000,
|
models = append(models, &ModelInfo{
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
ID: modelID,
|
||||||
},
|
Object: "model",
|
||||||
{
|
OwnedBy: "antigravity",
|
||||||
ID: "claude-opus-4-20250514",
|
Type: "antigravity",
|
||||||
Object: "model",
|
Thinking: entry.Thinking,
|
||||||
Created: 1715644800, // 2025-05-14
|
MaxCompletionTokens: entry.MaxCompletionTokens,
|
||||||
OwnedBy: "anthropic",
|
})
|
||||||
Type: "claude",
|
}
|
||||||
DisplayName: "Claude 4 Opus",
|
sort.Slice(models, func(i, j int) bool {
|
||||||
ContextLength: 200000,
|
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
|
||||||
MaxCompletionTokens: 32000,
|
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "claude-sonnet-4-20250514",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1715644800, // 2025-05-14
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Type: "claude",
|
|
||||||
DisplayName: "Claude 4 Sonnet",
|
|
||||||
ContextLength: 200000,
|
|
||||||
MaxCompletionTokens: 64000,
|
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "claude-3-7-sonnet-20250219",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1708300800, // 2025-02-19
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Type: "claude",
|
|
||||||
DisplayName: "Claude 3.7 Sonnet",
|
|
||||||
ContextLength: 128000,
|
|
||||||
MaxCompletionTokens: 8192,
|
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "claude-3-5-haiku-20241022",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1729555200, // 2024-10-22
|
|
||||||
OwnedBy: "anthropic",
|
|
||||||
Type: "claude",
|
|
||||||
DisplayName: "Claude 3.5 Haiku",
|
|
||||||
ContextLength: 128000,
|
|
||||||
MaxCompletionTokens: 8192,
|
|
||||||
// Thinking: not supported for Haiku models
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGeminiModels returns the standard Gemini model definitions
|
|
||||||
func GetGeminiModels() []*ModelInfo {
|
|
||||||
return []*ModelInfo{
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-pro",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-pro",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Pro",
|
|
||||||
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash",
|
|
||||||
Version: "001",
|
|
||||||
DisplayName: "Gemini 2.5 Flash",
|
|
||||||
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash-lite",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1753142400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash-lite",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Flash Lite",
|
|
||||||
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-pro-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1737158400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-pro-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Pro Preview",
|
|
||||||
Description: "Gemini 3 Pro Preview",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-flash-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1765929600,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-flash-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Flash Preview",
|
|
||||||
Description: "Gemini 3 Flash Preview",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-pro-image-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1737158400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-pro-image-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Pro Image Preview",
|
|
||||||
Description: "Gemini 3 Pro Image Preview",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetGeminiVertexModels() []*ModelInfo {
|
|
||||||
return []*ModelInfo{
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-pro",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-pro",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Pro",
|
|
||||||
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash",
|
|
||||||
Version: "001",
|
|
||||||
DisplayName: "Gemini 2.5 Flash",
|
|
||||||
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash-lite",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1753142400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash-lite",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Flash Lite",
|
|
||||||
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-pro-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1737158400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-pro-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Pro Preview",
|
|
||||||
Description: "Gemini 3 Pro Preview",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-flash-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1765929600,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-flash-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Flash Preview",
|
|
||||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-pro-image-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1737158400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-pro-image-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Pro Image Preview",
|
|
||||||
Description: "Gemini 3 Pro Image Preview",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetGeminiCLIModels returns the standard Gemini model definitions
|
|
||||||
func GetGeminiCLIModels() []*ModelInfo {
|
|
||||||
return []*ModelInfo{
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-pro",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-pro",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Pro",
|
|
||||||
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash",
|
|
||||||
Version: "001",
|
|
||||||
DisplayName: "Gemini 2.5 Flash",
|
|
||||||
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash-lite",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1753142400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash-lite",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Flash Lite",
|
|
||||||
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-pro-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1737158400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-pro-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Pro Preview",
|
|
||||||
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-flash-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1765929600,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-flash-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Flash Preview",
|
|
||||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAIStudioModels returns the Gemini model definitions for AI Studio integrations
|
|
||||||
func GetAIStudioModels() []*ModelInfo {
|
|
||||||
return []*ModelInfo{
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-pro",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-pro",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Pro",
|
|
||||||
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash",
|
|
||||||
Version: "001",
|
|
||||||
DisplayName: "Gemini 2.5 Flash",
|
|
||||||
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash-lite",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1753142400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash-lite",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Flash Lite",
|
|
||||||
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-pro-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1737158400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-pro-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Pro Preview",
|
|
||||||
Description: "Gemini 3 Pro Preview",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-3-flash-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1765929600,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-3-flash-preview",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Gemini 3 Flash Preview",
|
|
||||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-pro-latest",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-pro-latest",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini Pro Latest",
|
|
||||||
Description: "Latest release of Gemini Pro",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-flash-latest",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1750118400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-flash-latest",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini Flash Latest",
|
|
||||||
Description: "Latest release of Gemini Flash",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-flash-lite-latest",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1753142400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-flash-lite-latest",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini Flash-Lite Latest",
|
|
||||||
Description: "Latest release of Gemini Flash-Lite",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 65536,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
Thinking: &ThinkingSupport{Min: 512, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash-image-preview",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1756166400,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash-image-preview",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Flash Image Preview",
|
|
||||||
Description: "State-of-the-art image generation and editing model.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 8192,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
// image models don't support thinkingConfig; leave Thinking nil
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gemini-2.5-flash-image",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1759363200,
|
|
||||||
OwnedBy: "google",
|
|
||||||
Type: "gemini",
|
|
||||||
Name: "models/gemini-2.5-flash-image",
|
|
||||||
Version: "2.5",
|
|
||||||
DisplayName: "Gemini 2.5 Flash Image",
|
|
||||||
Description: "State-of-the-art image generation and editing model.",
|
|
||||||
InputTokenLimit: 1048576,
|
|
||||||
OutputTokenLimit: 8192,
|
|
||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
|
||||||
// image models don't support thinkingConfig; leave Thinking nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOpenAIModels returns the standard OpenAI model definitions
|
|
||||||
func GetOpenAIModels() []*ModelInfo {
|
|
||||||
return []*ModelInfo{
|
|
||||||
{
|
|
||||||
ID: "gpt-5",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1754524800,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5-2025-08-07",
|
|
||||||
DisplayName: "GPT 5",
|
|
||||||
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"minimal", "low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gpt-5-codex",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1757894400,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5-2025-09-15",
|
|
||||||
DisplayName: "GPT 5 Codex",
|
|
||||||
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gpt-5-codex-mini",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1762473600,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5-2025-11-07",
|
|
||||||
DisplayName: "GPT 5 Codex Mini",
|
|
||||||
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gpt-5.1",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1762905600,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5.1-2025-11-12",
|
|
||||||
DisplayName: "GPT 5",
|
|
||||||
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gpt-5.1-codex",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1762905600,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5.1-2025-11-12",
|
|
||||||
DisplayName: "GPT 5.1 Codex",
|
|
||||||
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gpt-5.1-codex-mini",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1762905600,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5.1-2025-11-12",
|
|
||||||
DisplayName: "GPT 5.1 Codex Mini",
|
|
||||||
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gpt-5.1-codex-max",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1763424000,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5.1-max",
|
|
||||||
DisplayName: "GPT 5.1 Codex Max",
|
|
||||||
Description: "Stable version of GPT 5.1 Codex Max",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gpt-5.2",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1765440000,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5.2",
|
|
||||||
DisplayName: "GPT 5.2",
|
|
||||||
Description: "Stable version of GPT 5.2",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "gpt-5.2-codex",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1765440000,
|
|
||||||
OwnedBy: "openai",
|
|
||||||
Type: "openai",
|
|
||||||
Version: "gpt-5.2",
|
|
||||||
DisplayName: "GPT 5.2 Codex",
|
|
||||||
Description: "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
|
||||||
ContextLength: 400000,
|
|
||||||
MaxCompletionTokens: 128000,
|
|
||||||
SupportedParameters: []string{"tools"},
|
|
||||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetQwenModels returns the standard Qwen model definitions
|
|
||||||
func GetQwenModels() []*ModelInfo {
|
|
||||||
return []*ModelInfo{
|
|
||||||
{
|
|
||||||
ID: "qwen3-coder-plus",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1753228800,
|
|
||||||
OwnedBy: "qwen",
|
|
||||||
Type: "qwen",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Qwen3 Coder Plus",
|
|
||||||
Description: "Advanced code generation and understanding model",
|
|
||||||
ContextLength: 32768,
|
|
||||||
MaxCompletionTokens: 8192,
|
|
||||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "qwen3-coder-flash",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1753228800,
|
|
||||||
OwnedBy: "qwen",
|
|
||||||
Type: "qwen",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Qwen3 Coder Flash",
|
|
||||||
Description: "Fast code generation model",
|
|
||||||
ContextLength: 8192,
|
|
||||||
MaxCompletionTokens: 2048,
|
|
||||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "vision-model",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1758672000,
|
|
||||||
OwnedBy: "qwen",
|
|
||||||
Type: "qwen",
|
|
||||||
Version: "3.0",
|
|
||||||
DisplayName: "Qwen3 Vision Model",
|
|
||||||
Description: "Vision model model",
|
|
||||||
ContextLength: 32768,
|
|
||||||
MaxCompletionTokens: 2048,
|
|
||||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
|
|
||||||
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
|
|
||||||
// Uses level-based configuration so standard normalization flows apply before conversion.
|
|
||||||
var iFlowThinkingSupport = &ThinkingSupport{
|
|
||||||
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
|
||||||
func GetIFlowModels() []*ModelInfo {
|
|
||||||
entries := []struct {
|
|
||||||
ID string
|
|
||||||
DisplayName string
|
|
||||||
Description string
|
|
||||||
Created int64
|
|
||||||
Thinking *ThinkingSupport
|
|
||||||
}{
|
|
||||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
|
||||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
|
||||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
|
||||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
|
||||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
|
||||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
|
||||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
|
||||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
|
||||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
|
||||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
|
||||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
|
||||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200},
|
|
||||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
|
||||||
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
|
||||||
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
|
||||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
|
||||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
|
||||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
|
||||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
|
||||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
|
||||||
}
|
|
||||||
models := make([]*ModelInfo, 0, len(entries))
|
|
||||||
for _, entry := range entries {
|
|
||||||
models = append(models, &ModelInfo{
|
|
||||||
ID: entry.ID,
|
|
||||||
Object: "model",
|
|
||||||
Created: entry.Created,
|
|
||||||
OwnedBy: "iflow",
|
|
||||||
Type: "iflow",
|
|
||||||
DisplayName: entry.DisplayName,
|
|
||||||
Description: entry.Description,
|
|
||||||
Thinking: entry.Thinking,
|
|
||||||
})
|
})
|
||||||
}
|
return models
|
||||||
return models
|
default:
|
||||||
}
|
return nil
|
||||||
|
|
||||||
// AntigravityModelConfig captures static antigravity model overrides, including
|
|
||||||
// Thinking budget limits and provider max completion tokens.
|
|
||||||
type AntigravityModelConfig struct {
|
|
||||||
Thinking *ThinkingSupport
|
|
||||||
MaxCompletionTokens int
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
|
||||||
// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup.
|
|
||||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
|
||||||
return map[string]*AntigravityModelConfig{
|
|
||||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
|
||||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
|
||||||
"gemini-2.5-computer-use-preview-10-2025": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
|
||||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
|
|
||||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
|
|
||||||
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
|
|
||||||
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
|
||||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -788,6 +73,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
allModels := [][]*ModelInfo{
|
allModels := [][]*ModelInfo{
|
||||||
GetClaudeModels(),
|
GetClaudeModels(),
|
||||||
GetGeminiModels(),
|
GetGeminiModels(),
|
||||||
@@ -805,5 +91,15 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check Antigravity static config
|
||||||
|
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
||||||
|
return &ModelInfo{
|
||||||
|
ID: modelID,
|
||||||
|
Thinking: cfg.Thinking,
|
||||||
|
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
846
internal/registry/model_definitions_static_data.go
Normal file
846
internal/registry/model_definitions_static_data.go
Normal file
@@ -0,0 +1,846 @@
|
|||||||
|
// Package registry provides model definitions for various AI service providers.
|
||||||
|
// This file stores the static model metadata catalog.
|
||||||
|
package registry
|
||||||
|
|
||||||
|
// GetClaudeModels returns the standard Claude model definitions
|
||||||
|
func GetClaudeModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
|
||||||
|
{
|
||||||
|
ID: "claude-haiku-4-5-20251001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1759276800, // 2025-10-01
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4.5 Haiku",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
// Thinking: not supported for Haiku models
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4-5-20250929",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1759104000, // 2025-09-29
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4.5 Sonnet",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-opus-4-5-20251101",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1761955200, // 2025-11-01
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4.5 Opus",
|
||||||
|
Description: "Premium model combining maximum intelligence with practical performance",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-opus-4-1-20250805",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1722945600, // 2025-08-05
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4.1 Opus",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-opus-4-20250514",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1715644800, // 2025-05-14
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4 Opus",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4-20250514",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1715644800, // 2025-05-14
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4 Sonnet",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-3-7-sonnet-20250219",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1708300800, // 2025-02-19
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 3.7 Sonnet",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 8192,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-3-5-haiku-20241022",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1729555200, // 2024-10-22
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 3.5 Haiku",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 8192,
|
||||||
|
// Thinking: not supported for Haiku models
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiModels returns the standard Gemini model definitions
|
||||||
|
func GetGeminiModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-pro",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-pro",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Pro",
|
||||||
|
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash",
|
||||||
|
Version: "001",
|
||||||
|
DisplayName: "Gemini 2.5 Flash",
|
||||||
|
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash-lite",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1753142400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash-lite",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Flash Lite",
|
||||||
|
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1737158400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-pro-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Pro Preview",
|
||||||
|
Description: "Gemini 3 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-flash-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765929600,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-flash-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Flash Preview",
|
||||||
|
Description: "Gemini 3 Flash Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-pro-image-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1737158400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-pro-image-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Pro Image Preview",
|
||||||
|
Description: "Gemini 3 Pro Image Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetGeminiVertexModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-pro",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-pro",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Pro",
|
||||||
|
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash",
|
||||||
|
Version: "001",
|
||||||
|
DisplayName: "Gemini 2.5 Flash",
|
||||||
|
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash-lite",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1753142400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash-lite",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Flash Lite",
|
||||||
|
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1737158400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-pro-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Pro Preview",
|
||||||
|
Description: "Gemini 3 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-flash-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765929600,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-flash-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Flash Preview",
|
||||||
|
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-pro-image-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1737158400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-pro-image-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Pro Image Preview",
|
||||||
|
Description: "Gemini 3 Pro Image Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
|
// Imagen image generation models - use :predict action
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Generate",
|
||||||
|
Description: "Imagen 4.0 image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-ultra-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-ultra-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Ultra Generate",
|
||||||
|
Description: "Imagen 4.0 Ultra high-quality image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-3.0-generate-002",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1740000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-3.0-generate-002",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Imagen 3.0 Generate",
|
||||||
|
Description: "Imagen 3.0 image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-3.0-fast-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1740000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-3.0-fast-generate-001",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Imagen 3.0 Fast Generate",
|
||||||
|
Description: "Imagen 3.0 fast image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-fast-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-fast-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Fast Generate",
|
||||||
|
Description: "Imagen 4.0 fast image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiCLIModels returns the standard Gemini model definitions
|
||||||
|
func GetGeminiCLIModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-pro",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-pro",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Pro",
|
||||||
|
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash",
|
||||||
|
Version: "001",
|
||||||
|
DisplayName: "Gemini 2.5 Flash",
|
||||||
|
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash-lite",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1753142400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash-lite",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Flash Lite",
|
||||||
|
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1737158400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-pro-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Pro Preview",
|
||||||
|
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-flash-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765929600,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-flash-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Flash Preview",
|
||||||
|
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAIStudioModels returns the Gemini model definitions for AI Studio integrations
|
||||||
|
func GetAIStudioModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-pro",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-pro",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Pro",
|
||||||
|
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash",
|
||||||
|
Version: "001",
|
||||||
|
DisplayName: "Gemini 2.5 Flash",
|
||||||
|
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash-lite",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1753142400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash-lite",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Flash Lite",
|
||||||
|
Description: "Our smallest and most cost effective model, built for at scale usage.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1737158400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-pro-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Pro Preview",
|
||||||
|
Description: "Gemini 3 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-flash-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765929600,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3-flash-preview",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Gemini 3 Flash Preview",
|
||||||
|
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-pro-latest",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-pro-latest",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini Pro Latest",
|
||||||
|
Description: "Latest release of Gemini Pro",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-flash-latest",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750118400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-flash-latest",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini Flash Latest",
|
||||||
|
Description: "Latest release of Gemini Flash",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-flash-lite-latest",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1753142400,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-flash-lite-latest",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini Flash-Lite Latest",
|
||||||
|
Description: "Latest release of Gemini Flash-Lite",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 512, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
// {
|
||||||
|
// ID: "gemini-2.5-flash-image-preview",
|
||||||
|
// Object: "model",
|
||||||
|
// Created: 1756166400,
|
||||||
|
// OwnedBy: "google",
|
||||||
|
// Type: "gemini",
|
||||||
|
// Name: "models/gemini-2.5-flash-image-preview",
|
||||||
|
// Version: "2.5",
|
||||||
|
// DisplayName: "Gemini 2.5 Flash Image Preview",
|
||||||
|
// Description: "State-of-the-art image generation and editing model.",
|
||||||
|
// InputTokenLimit: 1048576,
|
||||||
|
// OutputTokenLimit: 8192,
|
||||||
|
// SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
// // image models don't support thinkingConfig; leave Thinking nil
|
||||||
|
// },
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-flash-image",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1759363200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-2.5-flash-image",
|
||||||
|
Version: "2.5",
|
||||||
|
DisplayName: "Gemini 2.5 Flash Image",
|
||||||
|
Description: "State-of-the-art image generation and editing model.",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 8192,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
// image models don't support thinkingConfig; leave Thinking nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIModels returns the standard OpenAI model definitions
|
||||||
|
func GetOpenAIModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "gpt-5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1754524800,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5-2025-08-07",
|
||||||
|
DisplayName: "GPT 5",
|
||||||
|
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"minimal", "low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5-codex",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1757894400,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5-2025-09-15",
|
||||||
|
DisplayName: "GPT 5 Codex",
|
||||||
|
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5-codex-mini",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1762473600,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5-2025-11-07",
|
||||||
|
DisplayName: "GPT 5 Codex Mini",
|
||||||
|
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1762905600,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5.1-2025-11-12",
|
||||||
|
DisplayName: "GPT 5",
|
||||||
|
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.1-codex",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1762905600,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5.1-2025-11-12",
|
||||||
|
DisplayName: "GPT 5.1 Codex",
|
||||||
|
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.1-codex-mini",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1762905600,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5.1-2025-11-12",
|
||||||
|
DisplayName: "GPT 5.1 Codex Mini",
|
||||||
|
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.1-codex-max",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1763424000,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5.1-max",
|
||||||
|
DisplayName: "GPT 5.1 Codex Max",
|
||||||
|
Description: "Stable version of GPT 5.1 Codex Max",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.2",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765440000,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5.2",
|
||||||
|
DisplayName: "GPT 5.2",
|
||||||
|
Description: "Stable version of GPT 5.2",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.2-codex",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1765440000,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "openai",
|
||||||
|
Version: "gpt-5.2",
|
||||||
|
DisplayName: "GPT 5.2 Codex",
|
||||||
|
Description: "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
|
ContextLength: 400000,
|
||||||
|
MaxCompletionTokens: 128000,
|
||||||
|
SupportedParameters: []string{"tools"},
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQwenModels returns the standard Qwen model definitions
|
||||||
|
func GetQwenModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "qwen3-coder-plus",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1753228800,
|
||||||
|
OwnedBy: "qwen",
|
||||||
|
Type: "qwen",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Qwen3 Coder Plus",
|
||||||
|
Description: "Advanced code generation and understanding model",
|
||||||
|
ContextLength: 32768,
|
||||||
|
MaxCompletionTokens: 8192,
|
||||||
|
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "qwen3-coder-flash",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1753228800,
|
||||||
|
OwnedBy: "qwen",
|
||||||
|
Type: "qwen",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Qwen3 Coder Flash",
|
||||||
|
Description: "Fast code generation model",
|
||||||
|
ContextLength: 8192,
|
||||||
|
MaxCompletionTokens: 2048,
|
||||||
|
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "vision-model",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1758672000,
|
||||||
|
OwnedBy: "qwen",
|
||||||
|
Type: "qwen",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Qwen3 Vision Model",
|
||||||
|
Description: "Vision model model",
|
||||||
|
ContextLength: 32768,
|
||||||
|
MaxCompletionTokens: 2048,
|
||||||
|
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
|
||||||
|
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
|
||||||
|
// Uses level-based configuration so standard normalization flows apply before conversion.
|
||||||
|
var iFlowThinkingSupport = &ThinkingSupport{
|
||||||
|
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||||
|
func GetIFlowModels() []*ModelInfo {
|
||||||
|
entries := []struct {
|
||||||
|
ID string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
Created int64
|
||||||
|
Thinking *ThinkingSupport
|
||||||
|
}{
|
||||||
|
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
||||||
|
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||||
|
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||||
|
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||||
|
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||||
|
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||||
|
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||||
|
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||||
|
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||||
|
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||||
|
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
||||||
|
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
||||||
|
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||||
|
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||||
|
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||||
|
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||||
|
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||||
|
}
|
||||||
|
models := make([]*ModelInfo, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
models = append(models, &ModelInfo{
|
||||||
|
ID: entry.ID,
|
||||||
|
Object: "model",
|
||||||
|
Created: entry.Created,
|
||||||
|
OwnedBy: "iflow",
|
||||||
|
Type: "iflow",
|
||||||
|
DisplayName: entry.DisplayName,
|
||||||
|
Description: entry.Description,
|
||||||
|
Thinking: entry.Thinking,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityModelConfig captures static antigravity model overrides, including
|
||||||
|
// Thinking budget limits and provider max completion tokens.
|
||||||
|
type AntigravityModelConfig struct {
|
||||||
|
Thinking *ThinkingSupport
|
||||||
|
MaxCompletionTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||||
|
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||||
|
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||||
|
return map[string]*AntigravityModelConfig{
|
||||||
|
// "rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||||
|
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
|
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
|
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
|
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
|
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||||
|
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||||
|
"gpt-oss-120b-medium": {},
|
||||||
|
"tab_flash_lite_preview": {},
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -50,6 +51,11 @@ type ModelInfo struct {
|
|||||||
// Thinking holds provider-specific reasoning/thinking budget capabilities.
|
// Thinking holds provider-specific reasoning/thinking budget capabilities.
|
||||||
// This is optional and currently used for Gemini thinking budget normalization.
|
// This is optional and currently used for Gemini thinking budget normalization.
|
||||||
Thinking *ThinkingSupport `json:"thinking,omitempty"`
|
Thinking *ThinkingSupport `json:"thinking,omitempty"`
|
||||||
|
|
||||||
|
// UserDefined indicates this model was defined through config file's models[]
|
||||||
|
// array (e.g., openai-compatibility.*.models[], *-api-key.models[]).
|
||||||
|
// UserDefined models have thinking configuration passed through without validation.
|
||||||
|
UserDefined bool `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
||||||
@@ -72,6 +78,8 @@ type ThinkingSupport struct {
|
|||||||
type ModelRegistration struct {
|
type ModelRegistration struct {
|
||||||
// Info contains the model metadata
|
// Info contains the model metadata
|
||||||
Info *ModelInfo
|
Info *ModelInfo
|
||||||
|
// InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities.
|
||||||
|
InfoByProvider map[string]*ModelInfo
|
||||||
// Count is the number of active clients that can provide this model
|
// Count is the number of active clients that can provide this model
|
||||||
Count int
|
Count int
|
||||||
// LastUpdated tracks when this registration was last modified
|
// LastUpdated tracks when this registration was last modified
|
||||||
@@ -84,6 +92,13 @@ type ModelRegistration struct {
|
|||||||
SuspendedClients map[string]string
|
SuspendedClients map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelRegistryHook provides optional callbacks for external integrations to track model list changes.
|
||||||
|
// Hook implementations must be non-blocking and resilient; calls are executed asynchronously and panics are recovered.
|
||||||
|
type ModelRegistryHook interface {
|
||||||
|
OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo)
|
||||||
|
OnModelsUnregistered(ctx context.Context, provider, clientID string)
|
||||||
|
}
|
||||||
|
|
||||||
// ModelRegistry manages the global registry of available models
|
// ModelRegistry manages the global registry of available models
|
||||||
type ModelRegistry struct {
|
type ModelRegistry struct {
|
||||||
// models maps model ID to registration information
|
// models maps model ID to registration information
|
||||||
@@ -97,6 +112,8 @@ type ModelRegistry struct {
|
|||||||
clientProviders map[string]string
|
clientProviders map[string]string
|
||||||
// mutex ensures thread-safe access to the registry
|
// mutex ensures thread-safe access to the registry
|
||||||
mutex *sync.RWMutex
|
mutex *sync.RWMutex
|
||||||
|
// hook is an optional callback sink for model registration changes
|
||||||
|
hook ModelRegistryHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global model registry instance
|
// Global model registry instance
|
||||||
@@ -117,6 +134,71 @@ func GetGlobalRegistry() *ModelRegistry {
|
|||||||
return globalRegistry
|
return globalRegistry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||||
|
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p := ""
|
||||||
|
if len(provider) > 0 {
|
||||||
|
p = strings.ToLower(strings.TrimSpace(provider[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
return LookupStaticModelInfo(modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHook sets an optional hook for observing model registration changes.
|
||||||
|
func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
r.hook = hook
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultModelRegistryHookTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
|
||||||
|
hook := r.hook
|
||||||
|
if hook == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelsCopy := cloneModelInfosUnique(models)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if recovered := recover(); recovered != nil {
|
||||||
|
log.Errorf("model registry hook OnModelsRegistered panic: %v", recovered)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
|
||||||
|
defer cancel()
|
||||||
|
hook.OnModelsRegistered(ctx, provider, clientID, modelsCopy)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
|
||||||
|
hook := r.hook
|
||||||
|
if hook == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if recovered := recover(); recovered != nil {
|
||||||
|
log.Errorf("model registry hook OnModelsUnregistered panic: %v", recovered)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
|
||||||
|
defer cancel()
|
||||||
|
hook.OnModelsUnregistered(ctx, provider, clientID)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterClient registers a client and its supported models
|
// RegisterClient registers a client and its supported models
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - clientID: Unique identifier for the client
|
// - clientID: Unique identifier for the client
|
||||||
@@ -177,6 +259,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
} else {
|
} else {
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
|
r.triggerModelsRegistered(provider, clientID, models)
|
||||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
return
|
return
|
||||||
@@ -219,6 +302,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
if count, okProv := reg.Providers[oldProvider]; okProv {
|
if count, okProv := reg.Providers[oldProvider]; okProv {
|
||||||
if count <= toRemove {
|
if count <= toRemove {
|
||||||
delete(reg.Providers, oldProvider)
|
delete(reg.Providers, oldProvider)
|
||||||
|
if reg.InfoByProvider != nil {
|
||||||
|
delete(reg.InfoByProvider, oldProvider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
reg.Providers[oldProvider] = count - toRemove
|
reg.Providers[oldProvider] = count - toRemove
|
||||||
}
|
}
|
||||||
@@ -268,6 +354,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
model := newModels[id]
|
model := newModels[id]
|
||||||
if reg, ok := r.models[id]; ok {
|
if reg, ok := r.models[id]; ok {
|
||||||
reg.Info = cloneModelInfo(model)
|
reg.Info = cloneModelInfo(model)
|
||||||
|
if provider != "" {
|
||||||
|
if reg.InfoByProvider == nil {
|
||||||
|
reg.InfoByProvider = make(map[string]*ModelInfo)
|
||||||
|
}
|
||||||
|
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
|
}
|
||||||
reg.LastUpdated = now
|
reg.LastUpdated = now
|
||||||
if reg.QuotaExceededClients != nil {
|
if reg.QuotaExceededClients != nil {
|
||||||
delete(reg.QuotaExceededClients, clientID)
|
delete(reg.QuotaExceededClients, clientID)
|
||||||
@@ -310,6 +402,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.triggerModelsRegistered(provider, clientID, models)
|
||||||
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||||
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||||
return
|
return
|
||||||
@@ -330,11 +423,15 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
if existing.SuspendedClients == nil {
|
if existing.SuspendedClients == nil {
|
||||||
existing.SuspendedClients = make(map[string]string)
|
existing.SuspendedClients = make(map[string]string)
|
||||||
}
|
}
|
||||||
|
if existing.InfoByProvider == nil {
|
||||||
|
existing.InfoByProvider = make(map[string]*ModelInfo)
|
||||||
|
}
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
if existing.Providers == nil {
|
if existing.Providers == nil {
|
||||||
existing.Providers = make(map[string]int)
|
existing.Providers = make(map[string]int)
|
||||||
}
|
}
|
||||||
existing.Providers[provider]++
|
existing.Providers[provider]++
|
||||||
|
existing.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
}
|
}
|
||||||
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
||||||
return
|
return
|
||||||
@@ -342,6 +439,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
|
|
||||||
registration := &ModelRegistration{
|
registration := &ModelRegistration{
|
||||||
Info: cloneModelInfo(model),
|
Info: cloneModelInfo(model),
|
||||||
|
InfoByProvider: make(map[string]*ModelInfo),
|
||||||
Count: 1,
|
Count: 1,
|
||||||
LastUpdated: now,
|
LastUpdated: now,
|
||||||
QuotaExceededClients: make(map[string]*time.Time),
|
QuotaExceededClients: make(map[string]*time.Time),
|
||||||
@@ -349,6 +447,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
}
|
}
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
registration.Providers = map[string]int{provider: 1}
|
registration.Providers = map[string]int{provider: 1}
|
||||||
|
registration.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
}
|
}
|
||||||
r.models[modelID] = registration
|
r.models[modelID] = registration
|
||||||
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
||||||
@@ -374,6 +473,9 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
|
|||||||
if count, ok := registration.Providers[provider]; ok {
|
if count, ok := registration.Providers[provider]; ok {
|
||||||
if count <= 1 {
|
if count <= 1 {
|
||||||
delete(registration.Providers, provider)
|
delete(registration.Providers, provider)
|
||||||
|
if registration.InfoByProvider != nil {
|
||||||
|
delete(registration.InfoByProvider, provider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
registration.Providers[provider] = count - 1
|
registration.Providers[provider] = count - 1
|
||||||
}
|
}
|
||||||
@@ -400,6 +502,25 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
|||||||
return ©Model
|
return ©Model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneModelInfosUnique(models []*ModelInfo) []*ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := make([]*ModelInfo, 0, len(models))
|
||||||
|
seen := make(map[string]struct{}, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil || model.ID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[model.ID]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[model.ID] = struct{}{}
|
||||||
|
cloned = append(cloned, cloneModelInfo(model))
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
// UnregisterClient removes a client and decrements counts for its models
|
// UnregisterClient removes a client and decrements counts for its models
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - clientID: Unique identifier for the client to remove
|
// - clientID: Unique identifier for the client to remove
|
||||||
@@ -436,6 +557,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
|||||||
if count, ok := registration.Providers[provider]; ok {
|
if count, ok := registration.Providers[provider]; ok {
|
||||||
if count <= 1 {
|
if count <= 1 {
|
||||||
delete(registration.Providers, provider)
|
delete(registration.Providers, provider)
|
||||||
|
if registration.InfoByProvider != nil {
|
||||||
|
delete(registration.InfoByProvider, provider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
registration.Providers[provider] = count - 1
|
registration.Providers[provider] = count - 1
|
||||||
}
|
}
|
||||||
@@ -460,6 +584,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
|||||||
log.Debugf("Unregistered client %s", clientID)
|
log.Debugf("Unregistered client %s", clientID)
|
||||||
// Separator line after completing client unregistration (after the summary line)
|
// Separator line after completing client unregistration (after the summary line)
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
|
r.triggerModelsUnregistered(provider, clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
|
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
|
||||||
@@ -841,12 +966,22 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelInfo returns the registered ModelInfo for the given model ID, if present.
|
// GetModelInfo returns ModelInfo, prioritizing provider-specific definition if available.
|
||||||
// Returns nil if the model is unknown to the registry.
|
func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
||||||
func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo {
|
|
||||||
r.mutex.RLock()
|
r.mutex.RLock()
|
||||||
defer r.mutex.RUnlock()
|
defer r.mutex.RUnlock()
|
||||||
if reg, ok := r.models[modelID]; ok && reg != nil {
|
if reg, ok := r.models[modelID]; ok && reg != nil {
|
||||||
|
// Try provider specific definition first
|
||||||
|
if provider != "" && reg.InfoByProvider != nil {
|
||||||
|
if reg.Providers != nil {
|
||||||
|
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||||
|
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback to global info (last registered)
|
||||||
return reg.Info
|
return reg.Info
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -898,10 +1033,10 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
"owned_by": model.OwnedBy,
|
"owned_by": model.OwnedBy,
|
||||||
}
|
}
|
||||||
if model.Created > 0 {
|
if model.Created > 0 {
|
||||||
result["created"] = model.Created
|
result["created_at"] = model.Created
|
||||||
}
|
}
|
||||||
if model.Type != "" {
|
if model.Type != "" {
|
||||||
result["type"] = model.Type
|
result["type"] = "model"
|
||||||
}
|
}
|
||||||
if model.DisplayName != "" {
|
if model.DisplayName != "" {
|
||||||
result["display_name"] = model.DisplayName
|
result["display_name"] = model.DisplayName
|
||||||
|
|||||||
204
internal/registry/model_registry_hook_test.go
Normal file
204
internal/registry/model_registry_hook_test.go
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestModelRegistry() *ModelRegistry {
|
||||||
|
return &ModelRegistry{
|
||||||
|
models: make(map[string]*ModelRegistration),
|
||||||
|
clientModels: make(map[string][]string),
|
||||||
|
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||||
|
clientProviders: make(map[string]string),
|
||||||
|
mutex: &sync.RWMutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type registeredCall struct {
|
||||||
|
provider string
|
||||||
|
clientID string
|
||||||
|
models []*ModelInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type unregisteredCall struct {
|
||||||
|
provider string
|
||||||
|
clientID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type capturingHook struct {
|
||||||
|
registeredCh chan registeredCall
|
||||||
|
unregisteredCh chan unregisteredCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
||||||
|
h.registeredCh <- registeredCall{provider: provider, clientID: clientID, models: models}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
|
||||||
|
h.unregisteredCh <- unregisteredCall{provider: provider, clientID: clientID}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
hook := &capturingHook{
|
||||||
|
registeredCh: make(chan registeredCall, 1),
|
||||||
|
unregisteredCh: make(chan unregisteredCall, 1),
|
||||||
|
}
|
||||||
|
r.SetHook(hook)
|
||||||
|
|
||||||
|
inputModels := []*ModelInfo{
|
||||||
|
{ID: "m1", DisplayName: "Model One"},
|
||||||
|
{ID: "m2", DisplayName: "Model Two"},
|
||||||
|
}
|
||||||
|
r.RegisterClient("client-1", "OpenAI", inputModels)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case call := <-hook.registeredCh:
|
||||||
|
if call.provider != "openai" {
|
||||||
|
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
|
||||||
|
}
|
||||||
|
if call.clientID != "client-1" {
|
||||||
|
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
|
||||||
|
}
|
||||||
|
if len(call.models) != 2 {
|
||||||
|
t.Fatalf("models length mismatch: got %d, want %d", len(call.models), 2)
|
||||||
|
}
|
||||||
|
if call.models[0] == nil || call.models[0].ID != "m1" {
|
||||||
|
t.Fatalf("models[0] mismatch: got %#v, want ID=%q", call.models[0], "m1")
|
||||||
|
}
|
||||||
|
if call.models[1] == nil || call.models[1].ID != "m2" {
|
||||||
|
t.Fatalf("models[1] mismatch: got %#v, want ID=%q", call.models[1], "m2")
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
hook := &capturingHook{
|
||||||
|
registeredCh: make(chan registeredCall, 1),
|
||||||
|
unregisteredCh: make(chan unregisteredCall, 1),
|
||||||
|
}
|
||||||
|
r.SetHook(hook)
|
||||||
|
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
||||||
|
select {
|
||||||
|
case <-hook.registeredCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.UnregisterClient("client-1")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case call := <-hook.unregisteredCh:
|
||||||
|
if call.provider != "openai" {
|
||||||
|
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
|
||||||
|
}
|
||||||
|
if call.clientID != "client-1" {
|
||||||
|
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type blockingHook struct {
|
||||||
|
started chan struct{}
|
||||||
|
unblock chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
||||||
|
select {
|
||||||
|
case <-h.started:
|
||||||
|
default:
|
||||||
|
close(h.started)
|
||||||
|
}
|
||||||
|
<-h.unblock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {}
|
||||||
|
|
||||||
|
func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
hook := &blockingHook{
|
||||||
|
started: make(chan struct{}),
|
||||||
|
unblock: make(chan struct{}),
|
||||||
|
}
|
||||||
|
r.SetHook(hook)
|
||||||
|
defer close(hook.unblock)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-hook.started:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for hook to start")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("RegisterClient appears to be blocked by hook")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !r.ClientSupportsModel("client-1", "m1") {
|
||||||
|
t.Fatal("model registration failed; expected client to support model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type panicHook struct {
|
||||||
|
registeredCalled chan struct{}
|
||||||
|
unregisteredCalled chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
||||||
|
if h.registeredCalled != nil {
|
||||||
|
h.registeredCalled <- struct{}{}
|
||||||
|
}
|
||||||
|
panic("boom")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
|
||||||
|
if h.unregisteredCalled != nil {
|
||||||
|
h.unregisteredCalled <- struct{}{}
|
||||||
|
}
|
||||||
|
panic("boom")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
hook := &panicHook{
|
||||||
|
registeredCalled: make(chan struct{}, 1),
|
||||||
|
unregisteredCalled: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
r.SetHook(hook)
|
||||||
|
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-hook.registeredCalled:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !r.ClientSupportsModel("client-1", "m1") {
|
||||||
|
t.Fatal("model registration failed; expected client to support model")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.UnregisterClient("client-1")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-hook.unregisteredCalled:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,12 +8,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -50,9 +51,71 @@ func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HttpRequest forwards an arbitrary HTTP request through the websocket relay.
|
||||||
|
func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("aistudio executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
if e.relay == nil {
|
||||||
|
return nil, fmt.Errorf("aistudio executor: ws relay is nil")
|
||||||
|
}
|
||||||
|
if auth == nil || auth.ID == "" {
|
||||||
|
return nil, fmt.Errorf("aistudio executor: missing auth")
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
|
||||||
|
return nil, fmt.Errorf("aistudio executor: request URL is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var body []byte
|
||||||
|
if httpReq.Body != nil {
|
||||||
|
b, errRead := io.ReadAll(httpReq.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
return nil, errRead
|
||||||
|
}
|
||||||
|
body = b
|
||||||
|
httpReq.Body = io.NopCloser(bytes.NewReader(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
wsReq := &wsrelay.HTTPRequest{
|
||||||
|
Method: httpReq.Method,
|
||||||
|
URL: httpReq.URL.String(),
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: body,
|
||||||
|
}
|
||||||
|
wsResp, errRelay := e.relay.NonStream(ctx, auth.ID, wsReq)
|
||||||
|
if errRelay != nil {
|
||||||
|
return nil, errRelay
|
||||||
|
}
|
||||||
|
if wsResp == nil {
|
||||||
|
return nil, fmt.Errorf("aistudio executor: ws response is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
statusText := http.StatusText(wsResp.Status)
|
||||||
|
if statusText == "" {
|
||||||
|
statusText = "Unknown"
|
||||||
|
}
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: wsResp.Status,
|
||||||
|
Status: fmt.Sprintf("%d %s", wsResp.Status, statusText),
|
||||||
|
Header: wsResp.Headers.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(wsResp.Body)),
|
||||||
|
ContentLength: int64(len(wsResp.Body)),
|
||||||
|
Request: httpReq,
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Execute performs a non-streaming request to the AI Studio API.
|
// Execute performs a non-streaming request to the AI Studio API.
|
||||||
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
if opts.Alt == "responses/compact" {
|
||||||
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
translatedReq, body, err := e.translateRequest(req, opts, false)
|
translatedReq, body, err := e.translateRequest(req, opts, false)
|
||||||
@@ -60,7 +123,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
endpoint := e.buildEndpoint(baseModel, body.action, opts.Alt)
|
||||||
wsReq := &wsrelay.HTTPRequest{
|
wsReq := &wsrelay.HTTPRequest{
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
URL: endpoint,
|
URL: endpoint,
|
||||||
@@ -107,7 +170,11 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
|
|
||||||
// ExecuteStream performs a streaming request to the AI Studio API.
|
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
if opts.Alt == "responses/compact" {
|
||||||
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
translatedReq, body, err := e.translateRequest(req, opts, true)
|
translatedReq, body, err := e.translateRequest(req, opts, true)
|
||||||
@@ -115,7 +182,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
endpoint := e.buildEndpoint(baseModel, body.action, opts.Alt)
|
||||||
wsReq := &wsrelay.HTTPRequest{
|
wsReq := &wsrelay.HTTPRequest{
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
URL: endpoint,
|
URL: endpoint,
|
||||||
@@ -256,6 +323,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
|
|
||||||
// CountTokens counts tokens for the given request using the AI Studio API.
|
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||||
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
_, body, err := e.translateRequest(req, opts, false)
|
_, body, err := e.translateRequest(req, opts, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
@@ -265,7 +333,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
|||||||
body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
|
body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
|
||||||
body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings")
|
body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings")
|
||||||
|
|
||||||
endpoint := e.buildEndpoint(req.Model, "countTokens", "")
|
endpoint := e.buildEndpoint(baseModel, "countTokens", "")
|
||||||
wsReq := &wsrelay.HTTPRequest{
|
wsReq := &wsrelay.HTTPRequest{
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
URL: endpoint,
|
URL: endpoint,
|
||||||
@@ -321,17 +389,23 @@ type translatedPayload struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
|
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
}
|
||||||
payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
|
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
if err != nil {
|
||||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
return nil, translatedPayload{}, err
|
||||||
|
}
|
||||||
|
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,68 @@
|
|||||||
package executor
|
package executor
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type codexCache struct {
|
type codexCache struct {
|
||||||
ID string
|
ID string
|
||||||
Expire time.Time
|
Expire time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
var codexCacheMap = map[string]codexCache{}
|
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
|
||||||
|
// Protected by codexCacheMu. Entries expire after 1 hour.
|
||||||
|
var (
|
||||||
|
codexCacheMap = make(map[string]codexCache)
|
||||||
|
codexCacheMu sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// codexCacheCleanupInterval controls how often expired entries are purged.
|
||||||
|
const codexCacheCleanupInterval = 15 * time.Minute
|
||||||
|
|
||||||
|
// codexCacheCleanupOnce ensures the background cleanup goroutine starts only once.
|
||||||
|
var codexCacheCleanupOnce sync.Once
|
||||||
|
|
||||||
|
// startCodexCacheCleanup launches a background goroutine that periodically
|
||||||
|
// removes expired entries from codexCacheMap to prevent memory leaks.
|
||||||
|
func startCodexCacheCleanup() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(codexCacheCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
purgeExpiredCodexCache()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// purgeExpiredCodexCache removes entries that have expired.
|
||||||
|
func purgeExpiredCodexCache() {
|
||||||
|
now := time.Now()
|
||||||
|
codexCacheMu.Lock()
|
||||||
|
defer codexCacheMu.Unlock()
|
||||||
|
for key, cache := range codexCacheMap {
|
||||||
|
if cache.Expire.Before(now) {
|
||||||
|
delete(codexCacheMap, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired.
|
||||||
|
func getCodexCache(key string) (codexCache, bool) {
|
||||||
|
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||||
|
codexCacheMu.RLock()
|
||||||
|
cache, ok := codexCacheMap[key]
|
||||||
|
codexCacheMu.RUnlock()
|
||||||
|
if !ok || cache.Expire.Before(time.Now()) {
|
||||||
|
return codexCache{}, false
|
||||||
|
}
|
||||||
|
return cache, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCodexCache stores a cache entry.
|
||||||
|
func setCodexCache(key string, cache codexCache) {
|
||||||
|
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
|
||||||
|
codexCacheMu.Lock()
|
||||||
|
codexCacheMap[key] = cache
|
||||||
|
codexCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|||||||
258
internal/runtime/executor/caching_verify_test.go
Normal file
258
internal/runtime/executor/caching_verify_test.go
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnsureCacheControl(t *testing.T) {
|
||||||
|
// Test case 1: System prompt as string
|
||||||
|
t.Run("String System Prompt", func(t *testing.T) {
|
||||||
|
input := []byte(`{"model": "claude-3-5-sonnet", "system": "This is a long system prompt", "messages": []}`)
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
res := gjson.GetBytes(output, "system.0.cache_control.type")
|
||||||
|
if res.String() != "ephemeral" {
|
||||||
|
t.Errorf("cache_control not found in system string. Output: %s", string(output))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 2: System prompt as array
|
||||||
|
t.Run("Array System Prompt", func(t *testing.T) {
|
||||||
|
input := []byte(`{"model": "claude-3-5-sonnet", "system": [{"type": "text", "text": "Part 1"}, {"type": "text", "text": "Part 2"}], "messages": []}`)
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
// cache_control should only be on the LAST element
|
||||||
|
res0 := gjson.GetBytes(output, "system.0.cache_control")
|
||||||
|
res1 := gjson.GetBytes(output, "system.1.cache_control.type")
|
||||||
|
|
||||||
|
if res0.Exists() {
|
||||||
|
t.Errorf("cache_control should NOT be on the first element")
|
||||||
|
}
|
||||||
|
if res1.String() != "ephemeral" {
|
||||||
|
t.Errorf("cache_control not found on last system element. Output: %s", string(output))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 3: Tools are cached
|
||||||
|
t.Run("Tools Caching", func(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet",
|
||||||
|
"tools": [
|
||||||
|
{"name": "tool1", "description": "First tool", "input_schema": {"type": "object"}},
|
||||||
|
{"name": "tool2", "description": "Second tool", "input_schema": {"type": "object"}}
|
||||||
|
],
|
||||||
|
"system": "System prompt",
|
||||||
|
"messages": []
|
||||||
|
}`)
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
// cache_control should only be on the LAST tool
|
||||||
|
tool0Cache := gjson.GetBytes(output, "tools.0.cache_control")
|
||||||
|
tool1Cache := gjson.GetBytes(output, "tools.1.cache_control.type")
|
||||||
|
|
||||||
|
if tool0Cache.Exists() {
|
||||||
|
t.Errorf("cache_control should NOT be on the first tool")
|
||||||
|
}
|
||||||
|
if tool1Cache.String() != "ephemeral" {
|
||||||
|
t.Errorf("cache_control not found on last tool. Output: %s", string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// System should also have cache_control
|
||||||
|
systemCache := gjson.GetBytes(output, "system.0.cache_control.type")
|
||||||
|
if systemCache.String() != "ephemeral" {
|
||||||
|
t.Errorf("cache_control not found in system. Output: %s", string(output))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 4: Tools and system are INDEPENDENT breakpoints
|
||||||
|
// Per Anthropic docs: Up to 4 breakpoints allowed, tools and system are cached separately
|
||||||
|
t.Run("Independent Cache Breakpoints", func(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet",
|
||||||
|
"tools": [
|
||||||
|
{"name": "tool1", "description": "First tool", "input_schema": {"type": "object"}, "cache_control": {"type": "ephemeral"}}
|
||||||
|
],
|
||||||
|
"system": [{"type": "text", "text": "System"}],
|
||||||
|
"messages": []
|
||||||
|
}`)
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
// Tool already has cache_control - should not be changed
|
||||||
|
tool0Cache := gjson.GetBytes(output, "tools.0.cache_control.type")
|
||||||
|
if tool0Cache.String() != "ephemeral" {
|
||||||
|
t.Errorf("existing cache_control was incorrectly removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// System SHOULD get cache_control because it is an INDEPENDENT breakpoint
|
||||||
|
// Tools and system are separate cache levels in the hierarchy
|
||||||
|
systemCache := gjson.GetBytes(output, "system.0.cache_control.type")
|
||||||
|
if systemCache.String() != "ephemeral" {
|
||||||
|
t.Errorf("system should have its own cache_control breakpoint (independent of tools)")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 5: Only tools, no system
|
||||||
|
t.Run("Only Tools No System", func(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet",
|
||||||
|
"tools": [
|
||||||
|
{"name": "tool1", "description": "Tool", "input_schema": {"type": "object"}}
|
||||||
|
],
|
||||||
|
"messages": [{"role": "user", "content": "Hi"}]
|
||||||
|
}`)
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
toolCache := gjson.GetBytes(output, "tools.0.cache_control.type")
|
||||||
|
if toolCache.String() != "ephemeral" {
|
||||||
|
t.Errorf("cache_control not found on tool. Output: %s", string(output))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 6: Many tools (Claude Code scenario)
|
||||||
|
t.Run("Many Tools (Claude Code Scenario)", func(t *testing.T) {
|
||||||
|
// Simulate Claude Code with many tools
|
||||||
|
toolsJSON := `[`
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
toolsJSON += ","
|
||||||
|
}
|
||||||
|
toolsJSON += fmt.Sprintf(`{"name": "tool%d", "description": "Tool %d", "input_schema": {"type": "object"}}`, i, i)
|
||||||
|
}
|
||||||
|
toolsJSON += `]`
|
||||||
|
|
||||||
|
input := []byte(fmt.Sprintf(`{
|
||||||
|
"model": "claude-3-5-sonnet",
|
||||||
|
"tools": %s,
|
||||||
|
"system": [{"type": "text", "text": "You are Claude Code"}],
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}]
|
||||||
|
}`, toolsJSON))
|
||||||
|
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
// Only the last tool (index 49) should have cache_control
|
||||||
|
for i := 0; i < 49; i++ {
|
||||||
|
path := fmt.Sprintf("tools.%d.cache_control", i)
|
||||||
|
if gjson.GetBytes(output, path).Exists() {
|
||||||
|
t.Errorf("tool %d should NOT have cache_control", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lastToolCache := gjson.GetBytes(output, "tools.49.cache_control.type")
|
||||||
|
if lastToolCache.String() != "ephemeral" {
|
||||||
|
t.Errorf("last tool (49) should have cache_control")
|
||||||
|
}
|
||||||
|
|
||||||
|
// System should also have cache_control
|
||||||
|
systemCache := gjson.GetBytes(output, "system.0.cache_control.type")
|
||||||
|
if systemCache.String() != "ephemeral" {
|
||||||
|
t.Errorf("system should have cache_control")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("test passed: 50 tools - cache_control only on last tool")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 7: Empty tools array
|
||||||
|
t.Run("Empty Tools Array", func(t *testing.T) {
|
||||||
|
input := []byte(`{"model": "claude-3-5-sonnet", "tools": [], "system": "Test", "messages": []}`)
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
// System should still get cache_control
|
||||||
|
systemCache := gjson.GetBytes(output, "system.0.cache_control.type")
|
||||||
|
if systemCache.String() != "ephemeral" {
|
||||||
|
t.Errorf("system should have cache_control even with empty tools array")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 8: Messages caching for multi-turn (second-to-last user)
|
||||||
|
t.Run("Messages Caching Second-To-Last User", func(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "First user"},
|
||||||
|
{"role": "assistant", "content": "Assistant reply"},
|
||||||
|
{"role": "user", "content": "Second user"},
|
||||||
|
{"role": "assistant", "content": "Assistant reply 2"},
|
||||||
|
{"role": "user", "content": "Third user"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
cacheType := gjson.GetBytes(output, "messages.2.content.0.cache_control.type")
|
||||||
|
if cacheType.String() != "ephemeral" {
|
||||||
|
t.Errorf("cache_control not found on second-to-last user turn. Output: %s", string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
lastUserCache := gjson.GetBytes(output, "messages.4.content.0.cache_control")
|
||||||
|
if lastUserCache.Exists() {
|
||||||
|
t.Errorf("last user turn should NOT have cache_control")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 9: Existing message cache_control should skip injection
|
||||||
|
t.Run("Messages Skip When Cache Control Exists", func(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "First user"}]},
|
||||||
|
{"role": "assistant", "content": [{"type": "text", "text": "Assistant reply", "cache_control": {"type": "ephemeral"}}]},
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "Second user"}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
userCache := gjson.GetBytes(output, "messages.0.content.0.cache_control")
|
||||||
|
if userCache.Exists() {
|
||||||
|
t.Errorf("cache_control should NOT be injected when a message already has cache_control")
|
||||||
|
}
|
||||||
|
|
||||||
|
existingCache := gjson.GetBytes(output, "messages.1.content.0.cache_control.type")
|
||||||
|
if existingCache.String() != "ephemeral" {
|
||||||
|
t.Errorf("existing cache_control should be preserved. Output: %s", string(output))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCacheControlOrder verifies the correct order: tools -> system -> messages
|
||||||
|
func TestCacheControlOrder(t *testing.T) {
|
||||||
|
input := []byte(`{
|
||||||
|
"model": "claude-sonnet-4",
|
||||||
|
"tools": [
|
||||||
|
{"name": "Read", "description": "Read file", "input_schema": {"type": "object", "properties": {"path": {"type": "string"}}}},
|
||||||
|
{"name": "Write", "description": "Write file", "input_schema": {"type": "object", "properties": {"path": {"type": "string"}, "content": {"type": "string"}}}}
|
||||||
|
],
|
||||||
|
"system": [
|
||||||
|
{"type": "text", "text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
|
{"type": "text", "text": "Additional instructions here..."}
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ensureCacheControl(input)
|
||||||
|
|
||||||
|
// 1. Last tool has cache_control
|
||||||
|
if gjson.GetBytes(output, "tools.1.cache_control.type").String() != "ephemeral" {
|
||||||
|
t.Error("last tool should have cache_control")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. First tool has NO cache_control
|
||||||
|
if gjson.GetBytes(output, "tools.0.cache_control").Exists() {
|
||||||
|
t.Error("first tool should NOT have cache_control")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Last system element has cache_control
|
||||||
|
if gjson.GetBytes(output, "system.1.cache_control.type").String() != "ephemeral" {
|
||||||
|
t.Error("last system element should have cache_control")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. First system element has NO cache_control
|
||||||
|
if gjson.GetBytes(output, "system.0.cache_control").Exists() {
|
||||||
|
t.Error("first system element should NOT have cache_control")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("cache order correct: tools -> system")
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
63
internal/runtime/executor/claude_executor_test.go
Normal file
63
internal/runtime/executor/claude_executor_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||||
|
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_alpha" {
|
||||||
|
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_alpha")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_bravo" {
|
||||||
|
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_bravo")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "proxy_charlie" {
|
||||||
|
t.Fatalf("tool_choice.name = %q, want %q", got, "proxy_charlie")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_delta" {
|
||||||
|
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_delta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||||
|
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||||
|
t.Fatalf("built-in tool name should not be prefixed: tools.0.name = %q, want %q", got, "web_search")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_my_custom_tool" {
|
||||||
|
t.Fatalf("custom tool should be prefixed: tools.1.name = %q, want %q", got, "proxy_my_custom_tool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||||
|
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||||
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "content.0.name").String(); got != "alpha" {
|
||||||
|
t.Fatalf("content.0.name = %q, want %q", got, "alpha")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "content.1.name").String(); got != "bravo" {
|
||||||
|
t.Fatalf("content.1.name = %q, want %q", got, "bravo")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
|
||||||
|
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
|
||||||
|
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||||
|
|
||||||
|
payload := bytes.TrimSpace(out)
|
||||||
|
if bytes.HasPrefix(payload, []byte("data:")) {
|
||||||
|
payload = bytes.TrimSpace(payload[len("data:"):])
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(payload, "content_block.name").String(); got != "alpha" {
|
||||||
|
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
|
||||||
|
}
|
||||||
|
}
|
||||||
176
internal/runtime/executor/cloak_obfuscate.go
Normal file
176
internal/runtime/executor/cloak_obfuscate.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// zeroWidthSpace is the Unicode zero-width space character used for obfuscation.
|
||||||
|
const zeroWidthSpace = "\u200B"
|
||||||
|
|
||||||
|
// SensitiveWordMatcher holds the compiled regex for matching sensitive words.
|
||||||
|
type SensitiveWordMatcher struct {
|
||||||
|
regex *regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSensitiveWordMatcher compiles a regex from the word list.
|
||||||
|
// Words are sorted by length (longest first) for proper matching.
|
||||||
|
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
|
||||||
|
if len(words) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter and normalize words
|
||||||
|
var validWords []string
|
||||||
|
for _, w := range words {
|
||||||
|
w = strings.TrimSpace(w)
|
||||||
|
if utf8.RuneCountInString(w) >= 2 && !strings.Contains(w, zeroWidthSpace) {
|
||||||
|
validWords = append(validWords, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validWords) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by length (longest first) for proper matching
|
||||||
|
sort.Slice(validWords, func(i, j int) bool {
|
||||||
|
return len(validWords[i]) > len(validWords[j])
|
||||||
|
})
|
||||||
|
|
||||||
|
// Escape and join
|
||||||
|
escaped := make([]string, len(validWords))
|
||||||
|
for i, w := range validWords {
|
||||||
|
escaped[i] = regexp.QuoteMeta(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
pattern := "(?i)" + strings.Join(escaped, "|")
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SensitiveWordMatcher{regex: re}
|
||||||
|
}
|
||||||
|
|
||||||
|
// obfuscateWord inserts a zero-width space after the first grapheme.
|
||||||
|
func obfuscateWord(word string) string {
|
||||||
|
if strings.Contains(word, zeroWidthSpace) {
|
||||||
|
return word
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get first rune
|
||||||
|
r, size := utf8.DecodeRuneInString(word)
|
||||||
|
if r == utf8.RuneError || size >= len(word) {
|
||||||
|
return word
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(r) + zeroWidthSpace + word[size:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// obfuscateText replaces all sensitive words in the text.
|
||||||
|
func (m *SensitiveWordMatcher) obfuscateText(text string) string {
|
||||||
|
if m == nil || m.regex == nil {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
|
||||||
|
}
|
||||||
|
|
||||||
|
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words
|
||||||
|
// in system blocks and message content.
|
||||||
|
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||||
|
if matcher == nil || matcher.regex == nil {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obfuscate in system blocks
|
||||||
|
payload = obfuscateSystemBlocks(payload, matcher)
|
||||||
|
|
||||||
|
// Obfuscate in messages
|
||||||
|
payload = obfuscateMessages(payload, matcher)
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// obfuscateSystemBlocks obfuscates sensitive words in system blocks.
|
||||||
|
func obfuscateSystemBlocks(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||||
|
system := gjson.GetBytes(payload, "system")
|
||||||
|
if !system.Exists() {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
if system.IsArray() {
|
||||||
|
modified := false
|
||||||
|
system.ForEach(func(key, value gjson.Result) bool {
|
||||||
|
if value.Get("type").String() == "text" {
|
||||||
|
text := value.Get("text").String()
|
||||||
|
obfuscated := matcher.obfuscateText(text)
|
||||||
|
if obfuscated != text {
|
||||||
|
path := "system." + key.String() + ".text"
|
||||||
|
payload, _ = sjson.SetBytes(payload, path, obfuscated)
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if modified {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
} else if system.Type == gjson.String {
|
||||||
|
text := system.String()
|
||||||
|
obfuscated := matcher.obfuscateText(text)
|
||||||
|
if obfuscated != text {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "system", obfuscated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// obfuscateMessages obfuscates sensitive words in message content.
|
||||||
|
func obfuscateMessages(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||||
|
messages := gjson.GetBytes(payload, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.ForEach(func(msgKey, msg gjson.Result) bool {
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msgPath := "messages." + msgKey.String()
|
||||||
|
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
// Simple string content
|
||||||
|
text := content.String()
|
||||||
|
obfuscated := matcher.obfuscateText(text)
|
||||||
|
if obfuscated != text {
|
||||||
|
payload, _ = sjson.SetBytes(payload, msgPath+".content", obfuscated)
|
||||||
|
}
|
||||||
|
} else if content.IsArray() {
|
||||||
|
// Array of content blocks
|
||||||
|
content.ForEach(func(blockKey, block gjson.Result) bool {
|
||||||
|
if block.Get("type").String() == "text" {
|
||||||
|
text := block.Get("text").String()
|
||||||
|
obfuscated := matcher.obfuscateText(text)
|
||||||
|
if obfuscated != text {
|
||||||
|
path := msgPath + ".content." + blockKey.String() + ".text"
|
||||||
|
payload, _ = sjson.SetBytes(payload, path, obfuscated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return payload
|
||||||
|
}
|
||||||
47
internal/runtime/executor/cloak_utils.go
Normal file
47
internal/runtime/executor/cloak_utils.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userIDPattern matches Claude Code format: user_[64-hex]_account__session_[uuid-v4]
|
||||||
|
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||||
|
|
||||||
|
// generateFakeUserID generates a fake user ID in Claude Code format.
|
||||||
|
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
|
||||||
|
func generateFakeUserID() string {
|
||||||
|
hexBytes := make([]byte, 32)
|
||||||
|
_, _ = rand.Read(hexBytes)
|
||||||
|
hexPart := hex.EncodeToString(hexBytes)
|
||||||
|
uuidPart := uuid.New().String()
|
||||||
|
return "user_" + hexPart + "_account__session_" + uuidPart
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidUserID checks if a user ID matches Claude Code format.
|
||||||
|
func isValidUserID(userID string) bool {
|
||||||
|
return userIDPattern.MatchString(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldCloak determines if request should be cloaked based on config and client User-Agent.
|
||||||
|
// Returns true if cloaking should be applied.
|
||||||
|
func shouldCloak(cloakMode string, userAgent string) bool {
|
||||||
|
switch strings.ToLower(cloakMode) {
|
||||||
|
case "always":
|
||||||
|
return true
|
||||||
|
case "never":
|
||||||
|
return false
|
||||||
|
default: // "auto" or empty
|
||||||
|
// If client is Claude Code, don't cloak
|
||||||
|
return !strings.HasPrefix(userAgent, "claude-cli")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isClaudeCodeClient checks if the User-Agent indicates a Claude Code client.
|
||||||
|
func isClaudeCodeClient(userAgent string) bool {
|
||||||
|
return strings.HasPrefix(userAgent, "claude-cli")
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -38,41 +39,88 @@ func NewCodexExecutor(cfg *config.Config) *CodexExecutor { return &CodexExecutor
|
|||||||
|
|
||||||
func (e *CodexExecutor) Identifier() string { return "codex" }
|
func (e *CodexExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
func (e *CodexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
// PrepareRequest injects Codex credentials into the outgoing HTTP request.
|
||||||
|
func (e *CodexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiKey, _ := codexCreds(auth)
|
||||||
|
if strings.TrimSpace(apiKey) != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects Codex credentials into the request and executes it.
|
||||||
|
func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("codex executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
apiKey, baseURL := codexCreds(auth)
|
if opts.Alt == "responses/compact" {
|
||||||
|
return e.executeCompact(ctx, auth, req, opts)
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
apiKey, baseURL := codexCreds(auth)
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||||
}
|
}
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
||||||
defer reporter.trackFailure(ctx, &err)
|
|
||||||
|
|
||||||
model := req.Model
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
defer reporter.trackFailure(ctx, &err)
|
||||||
model = override
|
|
||||||
}
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("codex")
|
to := sdktranslator.FromString("codex")
|
||||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
userAgent := codexUserAgent(ctx)
|
||||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
body = NormalizeThinkingConfig(body, model, false)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
return resp, errValidate
|
|
||||||
}
|
}
|
||||||
body = applyPayloadConfig(e.cfg, model, body)
|
originalPayload = misc.InjectCodexUserAgent(originalPayload, userAgent)
|
||||||
body, _ = sjson.SetBytes(body, "model", model)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
|
||||||
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||||
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", true)
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
|
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||||
|
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||||
|
}
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyCodexHeaders(httpReq, auth, apiKey)
|
applyCodexHeaders(httpReq, auth, apiKey, true)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -105,7 +153,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -132,7 +180,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
}
|
}
|
||||||
|
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, line, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, line, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -140,39 +188,140 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
apiKey, baseURL := codexCreds(auth)
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
apiKey, baseURL := codexCreds(auth)
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://chatgpt.com/backend-api/codex"
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||||
}
|
}
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
model := req.Model
|
from := opts.SourceFormat
|
||||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
to := sdktranslator.FromString("openai-response")
|
||||||
model = override
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
}
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
body, _ = sjson.DeleteBytes(body, "stream")
|
||||||
|
|
||||||
|
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
||||||
|
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
applyCodexHeaders(httpReq, auth, apiKey, false)
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: body,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("codex executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
var param any
|
||||||
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, data, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
apiKey, baseURL := codexCreds(auth)
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
||||||
|
}
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("codex")
|
to := sdktranslator.FromString("codex")
|
||||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
userAgent := codexUserAgent(ctx)
|
||||||
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
body = NormalizeThinkingConfig(body, model, false)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
|
||||||
return nil, errValidate
|
|
||||||
}
|
}
|
||||||
body = applyPayloadConfig(e.cfg, model, body)
|
originalPayload = misc.InjectCodexUserAgent(originalPayload, userAgent)
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
|
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
|
||||||
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, true)
|
||||||
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
body, _ = sjson.SetBytes(body, "model", model)
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||||
|
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||||
|
}
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyCodexHeaders(httpReq, auth, apiKey)
|
applyCodexHeaders(httpReq, auth, apiKey, true)
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
authID = auth.ID
|
||||||
@@ -208,7 +357,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return nil, readErr
|
return nil, readErr
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -237,7 +386,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, bytes.Clone(line), ¶m)
|
||||||
for i := range chunks {
|
for i := range chunks {
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||||
}
|
}
|
||||||
@@ -252,21 +401,30 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
model := req.Model
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
|
||||||
model = override
|
|
||||||
}
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("codex")
|
to := sdktranslator.FromString("codex")
|
||||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
userAgent := codexUserAgent(ctx)
|
||||||
|
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
|
||||||
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||||
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
body, _ = sjson.SetBytes(body, "model", model)
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
body, _ = sjson.SetBytes(body, "stream", false)
|
body, _ = sjson.SetBytes(body, "stream", false)
|
||||||
|
if !gjson.GetBytes(body, "instructions").Exists() {
|
||||||
|
body, _ = sjson.SetBytes(body, "instructions", "")
|
||||||
|
}
|
||||||
|
|
||||||
enc, err := tokenizerForCodexModel(model)
|
enc, err := tokenizerForCodexModel(baseModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
|
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -447,14 +605,14 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
|||||||
if from == "claude" {
|
if from == "claude" {
|
||||||
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
||||||
if userIDResult.Exists() {
|
if userIDResult.Exists() {
|
||||||
var hasKey bool
|
|
||||||
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
||||||
if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) {
|
var ok bool
|
||||||
|
if cache, ok = getCodexCache(key); !ok {
|
||||||
cache = codexCache{
|
cache = codexCache{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Expire: time.Now().Add(1 * time.Hour),
|
Expire: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
codexCacheMap[key] = cache
|
setCodexCache(key, cache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if from == "openai-response" {
|
} else if from == "openai-response" {
|
||||||
@@ -464,17 +622,21 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
if cache.ID != "" {
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
||||||
|
}
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rawJSON))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rawJSON))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
httpReq.Header.Set("Conversation_id", cache.ID)
|
if cache.ID != "" {
|
||||||
httpReq.Header.Set("Session_id", cache.ID)
|
httpReq.Header.Set("Conversation_id", cache.ID)
|
||||||
|
httpReq.Header.Set("Session_id", cache.ID)
|
||||||
|
}
|
||||||
return httpReq, nil
|
return httpReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string) {
|
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool) {
|
||||||
r.Header.Set("Content-Type", "application/json")
|
r.Header.Set("Content-Type", "application/json")
|
||||||
r.Header.Set("Authorization", "Bearer "+token)
|
r.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
|
||||||
@@ -488,7 +650,11 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string) {
|
|||||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "codex_cli_rs/0.50.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464")
|
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "codex_cli_rs/0.50.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464")
|
||||||
|
|
||||||
r.Header.Set("Accept", "text/event-stream")
|
if stream {
|
||||||
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
|
} else {
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
}
|
||||||
r.Header.Set("Connection", "Keep-Alive")
|
r.Header.Set("Connection", "Keep-Alive")
|
||||||
|
|
||||||
isAPIKey := false
|
isAPIKey := false
|
||||||
@@ -512,6 +678,16 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string) {
|
|||||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func codexUserAgent(ctx context.Context) string {
|
||||||
|
if ctx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||||
|
return strings.TrimSpace(ginCtx.Request.UserAgent())
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
@@ -528,51 +704,6 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
|
||||||
trimmed := strings.TrimSpace(alias)
|
|
||||||
if trimmed == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := e.resolveCodexConfig(auth)
|
|
||||||
if entry == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
|
||||||
|
|
||||||
// Candidate names to match against configured aliases/names.
|
|
||||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
|
||||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
|
||||||
candidates = append(candidates, trimmed)
|
|
||||||
}
|
|
||||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
|
||||||
candidates = append(candidates, original)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range entry.Models {
|
|
||||||
model := entry.Models[i]
|
|
||||||
name := strings.TrimSpace(model.Name)
|
|
||||||
modelAlias := strings.TrimSpace(model.Alias)
|
|
||||||
|
|
||||||
for _, candidate := range candidates {
|
|
||||||
if candidate == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
|
||||||
if name != "" {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
return candidate
|
|
||||||
}
|
|
||||||
if name != "" && strings.EqualFold(name, candidate) {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
|
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
|
||||||
if auth == nil || e.cfg == nil {
|
if auth == nil || e.cfg == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -63,28 +64,76 @@ func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor {
|
|||||||
// Identifier returns the executor identifier.
|
// Identifier returns the executor identifier.
|
||||||
func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" }
|
func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" }
|
||||||
|
|
||||||
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini CLI).
|
// PrepareRequest injects Gemini CLI credentials into the outgoing HTTP request.
|
||||||
func (e *GeminiCLIExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tokenSource, _, errSource := prepareGeminiCLITokenSource(req.Context(), e.cfg, auth)
|
||||||
|
if errSource != nil {
|
||||||
|
return errSource
|
||||||
|
}
|
||||||
|
tok, errTok := tokenSource.Token()
|
||||||
|
if errTok != nil {
|
||||||
|
return errTok
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(tok.AccessToken) == "" {
|
||||||
|
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
|
||||||
|
applyGeminiCLIHeaders(req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects Gemini CLI credentials into the request and executes it.
|
||||||
|
func (e *GeminiCLIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("gemini-cli executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
// Execute performs a non-streaming request to the Gemini CLI API.
|
// Execute performs a non-streaming request to the Gemini CLI API.
|
||||||
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini-cli")
|
to := sdktranslator.FromString("gemini-cli")
|
||||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
||||||
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
}
|
||||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
|
||||||
|
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
@@ -94,9 +143,9 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
}
|
}
|
||||||
|
|
||||||
projectID := resolveGeminiProjectID(auth)
|
projectID := resolveGeminiProjectID(auth)
|
||||||
models := cliPreviewFallbackOrder(req.Model)
|
models := cliPreviewFallbackOrder(baseModel)
|
||||||
if len(models) == 0 || models[0] != req.Model {
|
if len(models) == 0 || models[0] != baseModel {
|
||||||
models = append([]string{req.Model}, models...)
|
models = append([]string{baseModel}, models...)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
@@ -181,7 +230,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
if httpResp.StatusCode == 429 {
|
if httpResp.StatusCode == 429 {
|
||||||
if idx+1 < len(models) {
|
if idx+1 < len(models) {
|
||||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||||
@@ -207,29 +256,43 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini-cli")
|
to := sdktranslator.FromString("gemini-cli")
|
||||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
|
||||||
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
}
|
||||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
|
||||||
|
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||||
|
|
||||||
projectID := resolveGeminiProjectID(auth)
|
projectID := resolveGeminiProjectID(auth)
|
||||||
|
|
||||||
models := cliPreviewFallbackOrder(req.Model)
|
models := cliPreviewFallbackOrder(baseModel)
|
||||||
if len(models) == 0 || models[0] != req.Model {
|
if len(models) == 0 || models[0] != baseModel {
|
||||||
models = append([]string{req.Model}, models...)
|
models = append([]string{baseModel}, models...)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
@@ -303,7 +366,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
if httpResp.StatusCode == 429 {
|
if httpResp.StatusCode == 429 {
|
||||||
if idx+1 < len(models) {
|
if idx+1 < len(models) {
|
||||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||||
@@ -391,6 +454,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
|
|
||||||
// CountTokens counts tokens for the given request using the Gemini CLI API.
|
// CountTokens counts tokens for the given request using the Gemini CLI API.
|
||||||
func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
@@ -399,9 +464,9 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini-cli")
|
to := sdktranslator.FromString("gemini-cli")
|
||||||
|
|
||||||
models := cliPreviewFallbackOrder(req.Model)
|
models := cliPreviewFallbackOrder(baseModel)
|
||||||
if len(models) == 0 || models[0] != req.Model {
|
if len(models) == 0 || models[0] != baseModel {
|
||||||
models = append([]string{req.Model}, models...)
|
models = append([]string{baseModel}, models...)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
@@ -419,15 +484,18 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
|
|
||||||
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
||||||
// Gemini CLI endpoint when iterating fallback variants.
|
// Gemini CLI endpoint when iterating fallback variants.
|
||||||
for _, attemptModel := range models {
|
for range models {
|
||||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
|
||||||
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
|
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
|
}
|
||||||
|
|
||||||
payload = deleteJSONField(payload, "project")
|
payload = deleteJSONField(payload, "project")
|
||||||
payload = deleteJSONField(payload, "model")
|
payload = deleteJSONField(payload, "model")
|
||||||
payload = deleteJSONField(payload, "request.safetySettings")
|
payload = deleteJSONField(payload, "request.safetySettings")
|
||||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
payload = fixGeminiCLIImageAspectRatio(baseModel, payload)
|
||||||
payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
|
|
||||||
|
|
||||||
tok, errTok := tokenSource.Token()
|
tok, errTok := tokenSource.Token()
|
||||||
if errTok != nil {
|
if errTok != nil {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -55,8 +56,38 @@ func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor {
|
|||||||
// Identifier returns the executor identifier.
|
// Identifier returns the executor identifier.
|
||||||
func (e *GeminiExecutor) Identifier() string { return "gemini" }
|
func (e *GeminiExecutor) Identifier() string { return "gemini" }
|
||||||
|
|
||||||
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini).
|
// PrepareRequest injects Gemini credentials into the outgoing HTTP request.
|
||||||
func (e *GeminiExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
func (e *GeminiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiKey, bearer := geminiCreds(auth)
|
||||||
|
if apiKey != "" {
|
||||||
|
req.Header.Set("x-goog-api-key", apiKey)
|
||||||
|
req.Header.Del("Authorization")
|
||||||
|
} else if bearer != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+bearer)
|
||||||
|
req.Header.Del("x-goog-api-key")
|
||||||
|
}
|
||||||
|
applyGeminiHeaders(req, auth)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects Gemini credentials into the request and executes it.
|
||||||
|
func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("gemini executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
// Execute performs a non-streaming request to the Gemini API.
|
// Execute performs a non-streaming request to the Gemini API.
|
||||||
// It translates the request to Gemini format, sends it to the API, and translates
|
// It translates the request to Gemini format, sends it to the API, and translates
|
||||||
@@ -72,27 +103,35 @@ func (e *GeminiExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) e
|
|||||||
// - cliproxyexecutor.Response: The response from the API
|
// - cliproxyexecutor.Response: The response from the API
|
||||||
// - error: An error if the request fails
|
// - error: An error if the request fails
|
||||||
func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
apiKey, bearer := geminiCreds(auth)
|
apiKey, bearer := geminiCreds(auth)
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
model := req.Model
|
|
||||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
|
||||||
model = override
|
|
||||||
}
|
|
||||||
|
|
||||||
// Official Gemini API via API key or OAuth bearer
|
// Official Gemini API via API key or OAuth bearer
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
}
|
||||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
body = fixGeminiImageAspectRatio(model, body)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
body = applyPayloadConfig(e.cfg, model, body)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", model)
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
@@ -101,7 +140,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
baseURL := resolveGeminiBaseURL(auth)
|
baseURL := resolveGeminiBaseURL(auth)
|
||||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
|
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, action)
|
||||||
if opts.Alt != "" && action != "countTokens" {
|
if opts.Alt != "" && action != "countTokens" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
}
|
}
|
||||||
@@ -152,7 +191,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -171,29 +210,37 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Gemini API.
|
// ExecuteStream performs a streaming request to the Gemini API.
|
||||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
apiKey, bearer := geminiCreds(auth)
|
apiKey, bearer := geminiCreds(auth)
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
model := req.Model
|
|
||||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
|
||||||
model = override
|
|
||||||
}
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
}
|
||||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
body = fixGeminiImageAspectRatio(model, body)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
body = applyPayloadConfig(e.cfg, model, body)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", model)
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
baseURL := resolveGeminiBaseURL(auth)
|
baseURL := resolveGeminiBaseURL(auth)
|
||||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "streamGenerateContent")
|
||||||
if opts.Alt == "" {
|
if opts.Alt == "" {
|
||||||
url = url + "?alt=sse"
|
url = url + "?alt=sse"
|
||||||
} else {
|
} else {
|
||||||
@@ -241,7 +288,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("gemini executor: close response body error: %v", errClose)
|
log.Errorf("gemini executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -291,27 +338,28 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
|
|
||||||
// CountTokens counts tokens for the given request using the Gemini API.
|
// CountTokens counts tokens for the given request using the Gemini API.
|
||||||
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
apiKey, bearer := geminiCreds(auth)
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
model := req.Model
|
apiKey, bearer := geminiCreds(auth)
|
||||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
|
||||||
model = override
|
|
||||||
}
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
|
|
||||||
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
|
||||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
|
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
|
||||||
|
|
||||||
baseURL := resolveGeminiBaseURL(auth)
|
baseURL := resolveGeminiBaseURL(auth)
|
||||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
|
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "countTokens")
|
||||||
|
|
||||||
requestBody := bytes.NewReader(translatedReq)
|
requestBody := bytes.NewReader(translatedReq)
|
||||||
|
|
||||||
@@ -360,7 +408,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
log.Debugf("request error, error status: %d, error body: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -410,51 +458,6 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
|
|||||||
return base
|
return base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
|
||||||
trimmed := strings.TrimSpace(alias)
|
|
||||||
if trimmed == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := e.resolveGeminiConfig(auth)
|
|
||||||
if entry == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
|
||||||
|
|
||||||
// Candidate names to match against configured aliases/names.
|
|
||||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
|
||||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
|
||||||
candidates = append(candidates, trimmed)
|
|
||||||
}
|
|
||||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
|
||||||
candidates = append(candidates, original)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range entry.Models {
|
|
||||||
model := entry.Models[i]
|
|
||||||
name := strings.TrimSpace(model.Name)
|
|
||||||
modelAlias := strings.TrimSpace(model.Alias)
|
|
||||||
|
|
||||||
for _, candidate := range candidates {
|
|
||||||
if candidate == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
|
||||||
if name != "" {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
return candidate
|
|
||||||
}
|
|
||||||
if name != "" && strings.EqualFold(name, candidate) {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
|
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
|
||||||
if auth == nil || e.cfg == nil {
|
if auth == nil || e.cfg == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -12,10 +12,11 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -31,6 +32,143 @@ const (
|
|||||||
vertexAPIVersion = "v1"
|
vertexAPIVersion = "v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isImagenModel checks if the model name is an Imagen image generation model.
|
||||||
|
// Imagen models use the :predict action instead of :generateContent.
|
||||||
|
func isImagenModel(model string) bool {
|
||||||
|
lowerModel := strings.ToLower(model)
|
||||||
|
return strings.Contains(lowerModel, "imagen")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getVertexAction returns the appropriate action for the given model.
|
||||||
|
// Imagen models use "predict", while Gemini models use "generateContent".
|
||||||
|
func getVertexAction(model string, isStream bool) string {
|
||||||
|
if isImagenModel(model) {
|
||||||
|
return "predict"
|
||||||
|
}
|
||||||
|
if isStream {
|
||||||
|
return "streamGenerateContent"
|
||||||
|
}
|
||||||
|
return "generateContent"
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertImagenToGeminiResponse converts Imagen API response to Gemini format
|
||||||
|
// so it can be processed by the standard translation pipeline.
|
||||||
|
// This ensures Imagen models return responses in the same format as gemini-3-pro-image-preview.
|
||||||
|
func convertImagenToGeminiResponse(data []byte, model string) []byte {
|
||||||
|
predictions := gjson.GetBytes(data, "predictions")
|
||||||
|
if !predictions.Exists() || !predictions.IsArray() {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Gemini-compatible response with inlineData
|
||||||
|
parts := make([]map[string]any, 0)
|
||||||
|
for _, pred := range predictions.Array() {
|
||||||
|
imageData := pred.Get("bytesBase64Encoded").String()
|
||||||
|
mimeType := pred.Get("mimeType").String()
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "image/png"
|
||||||
|
}
|
||||||
|
if imageData != "" {
|
||||||
|
parts = append(parts, map[string]any{
|
||||||
|
"inlineData": map[string]any{
|
||||||
|
"mimeType": mimeType,
|
||||||
|
"data": imageData,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate unique response ID using timestamp
|
||||||
|
responseId := fmt.Sprintf("imagen-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
response := map[string]any{
|
||||||
|
"candidates": []map[string]any{{
|
||||||
|
"content": map[string]any{
|
||||||
|
"parts": parts,
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
"finishReason": "STOP",
|
||||||
|
}},
|
||||||
|
"responseId": responseId,
|
||||||
|
"modelVersion": model,
|
||||||
|
// Imagen API doesn't return token counts, set to 0 for tracking purposes
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": 0,
|
||||||
|
"candidatesTokenCount": 0,
|
||||||
|
"totalTokenCount": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToImagenRequest converts a Gemini-style request to Imagen API format.
|
||||||
|
// Imagen API uses a different structure: instances[].prompt instead of contents[].
|
||||||
|
func convertToImagenRequest(payload []byte) ([]byte, error) {
|
||||||
|
// Extract prompt from Gemini-style contents
|
||||||
|
prompt := ""
|
||||||
|
|
||||||
|
// Try to get prompt from contents[0].parts[0].text
|
||||||
|
contentsText := gjson.GetBytes(payload, "contents.0.parts.0.text")
|
||||||
|
if contentsText.Exists() {
|
||||||
|
prompt = contentsText.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no contents, try messages format (OpenAI-compatible)
|
||||||
|
if prompt == "" {
|
||||||
|
messagesText := gjson.GetBytes(payload, "messages.#.content")
|
||||||
|
if messagesText.Exists() && messagesText.IsArray() {
|
||||||
|
for _, msg := range messagesText.Array() {
|
||||||
|
if msg.String() != "" {
|
||||||
|
prompt = msg.String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If still no prompt, try direct prompt field
|
||||||
|
if prompt == "" {
|
||||||
|
directPrompt := gjson.GetBytes(payload, "prompt")
|
||||||
|
if directPrompt.Exists() {
|
||||||
|
prompt = directPrompt.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prompt == "" {
|
||||||
|
return nil, fmt.Errorf("imagen: no prompt found in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Imagen API request
|
||||||
|
imagenReq := map[string]any{
|
||||||
|
"instances": []map[string]any{
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"sampleCount": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract optional parameters
|
||||||
|
if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() {
|
||||||
|
imagenReq["parameters"].(map[string]any)["aspectRatio"] = aspectRatio.String()
|
||||||
|
}
|
||||||
|
if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() {
|
||||||
|
imagenReq["parameters"].(map[string]any)["sampleCount"] = int(sampleCount.Int())
|
||||||
|
}
|
||||||
|
if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() {
|
||||||
|
imagenReq["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(imagenReq)
|
||||||
|
}
|
||||||
|
|
||||||
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
||||||
type GeminiVertexExecutor struct {
|
type GeminiVertexExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -50,13 +188,54 @@ func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
|
|||||||
// Identifier returns the executor identifier.
|
// Identifier returns the executor identifier.
|
||||||
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
||||||
|
|
||||||
// PrepareRequest prepares the HTTP request for execution (no-op for Vertex).
|
// PrepareRequest injects Vertex credentials into the outgoing HTTP request.
|
||||||
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
func (e *GeminiVertexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiKey, _ := vertexAPICreds(auth)
|
||||||
|
if strings.TrimSpace(apiKey) != "" {
|
||||||
|
req.Header.Set("x-goog-api-key", apiKey)
|
||||||
|
req.Header.Del("Authorization")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, _, saJSON, errCreds := vertexCreds(auth)
|
||||||
|
if errCreds != nil {
|
||||||
|
return errCreds
|
||||||
|
}
|
||||||
|
token, errToken := vertexAccessToken(req.Context(), e.cfg, auth, saJSON)
|
||||||
|
if errToken != nil {
|
||||||
|
return errToken
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(token) == "" {
|
||||||
|
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Del("x-goog-api-key")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects Vertex credentials into the request and executes it.
|
||||||
|
func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("vertex executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
// Execute performs a non-streaming request to the Vertex AI API.
|
// Execute performs a non-streaming request to the Vertex AI API.
|
||||||
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
// Try API key authentication first
|
// Try API key authentication first
|
||||||
apiKey, baseURL := vertexAPICreds(auth)
|
apiKey, baseURL := vertexAPICreds(auth)
|
||||||
|
|
||||||
@@ -75,6 +254,9 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Vertex AI API.
|
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
// Try API key authentication first
|
// Try API key authentication first
|
||||||
apiKey, baseURL := vertexAPICreds(auth)
|
apiKey, baseURL := vertexAPICreds(auth)
|
||||||
|
|
||||||
@@ -117,34 +299,51 @@ func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Aut
|
|||||||
// executeWithServiceAccount handles authentication using service account credentials.
|
// executeWithServiceAccount handles authentication using service account credentials.
|
||||||
// This method contains the original service account authentication logic.
|
// This method contains the original service account authentication logic.
|
||||||
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
var body []byte
|
||||||
to := sdktranslator.FromString("gemini")
|
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
||||||
budgetOverride = &norm
|
|
||||||
}
|
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
||||||
}
|
|
||||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
|
||||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
|
||||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
|
||||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
|
||||||
|
|
||||||
action := "generateContent"
|
// Handle Imagen models with special request format
|
||||||
|
if isImagenModel(baseModel) {
|
||||||
|
imagenBody, errImagen := convertToImagenRequest(req.Payload)
|
||||||
|
if errImagen != nil {
|
||||||
|
return resp, errImagen
|
||||||
|
}
|
||||||
|
body = imagenBody
|
||||||
|
} else {
|
||||||
|
// Standard Gemini translation flow
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
|
||||||
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, false)
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
action = "countTokens"
|
action = "countTokens"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
baseURL := vertexBaseURL(location)
|
baseURL := vertexBaseURL(location)
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
|
||||||
if opts.Alt != "" && action != "countTokens" {
|
if opts.Alt != "" && action != "countTokens" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
}
|
}
|
||||||
@@ -196,7 +395,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -207,6 +406,16 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
|
|
||||||
|
// For Imagen models, convert response to Gemini format before translation
|
||||||
|
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
|
||||||
|
if isImagenModel(baseModel) {
|
||||||
|
data = convertImagenToGeminiResponse(data, baseModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard Gemini translation (works for both Gemini and converted Imagen responses)
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
@@ -215,32 +424,32 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
|
|
||||||
// executeWithAPIKey handles authentication using API key credentials.
|
// executeWithAPIKey handles authentication using API key credentials.
|
||||||
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
defer reporter.trackFailure(ctx, &err)
|
|
||||||
|
|
||||||
model := req.Model
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
defer reporter.trackFailure(ctx, &err)
|
||||||
model = override
|
|
||||||
}
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
|
||||||
budgetOverride = &norm
|
|
||||||
}
|
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
||||||
}
|
|
||||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
|
||||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
|
||||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
|
||||||
body = fixGeminiImageAspectRatio(model, body)
|
|
||||||
body = applyPayloadConfig(e.cfg, model, body)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", model)
|
|
||||||
|
|
||||||
action := "generateContent"
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, false)
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
action = "countTokens"
|
action = "countTokens"
|
||||||
@@ -251,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||||
if opts.Alt != "" && action != "countTokens" {
|
if opts.Alt != "" && action != "countTokens" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
}
|
}
|
||||||
@@ -300,7 +509,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -319,32 +528,41 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
|
|
||||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
||||||
budgetOverride = &norm
|
|
||||||
}
|
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
||||||
}
|
|
||||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
|
||||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
|
||||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
|
||||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
|
||||||
|
|
||||||
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, true)
|
||||||
baseURL := vertexBaseURL(location)
|
baseURL := vertexBaseURL(location)
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
|
||||||
if opts.Alt == "" {
|
// Imagen models don't support streaming, skip SSE params
|
||||||
url = url + "?alt=sse"
|
if !isImagenModel(baseModel) {
|
||||||
} else {
|
if opts.Alt == "" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
@@ -389,7 +607,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -434,40 +652,44 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
|
|
||||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
defer reporter.trackFailure(ctx, &err)
|
|
||||||
|
|
||||||
model := req.Model
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
defer reporter.trackFailure(ctx, &err)
|
||||||
model = override
|
|
||||||
}
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
|
||||||
budgetOverride = &norm
|
|
||||||
}
|
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
||||||
}
|
|
||||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
|
||||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
|
||||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
|
||||||
body = fixGeminiImageAspectRatio(model, body)
|
|
||||||
body = applyPayloadConfig(e.cfg, model, body)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", model)
|
|
||||||
|
|
||||||
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, true)
|
||||||
// For API key auth, use simpler URL format without project/location
|
// For API key auth, use simpler URL format without project/location
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||||
if opts.Alt == "" {
|
// Imagen models don't support streaming, skip SSE params
|
||||||
url = url + "?alt=sse"
|
if !isImagenModel(baseModel) {
|
||||||
} else {
|
if opts.Alt == "" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
@@ -509,7 +731,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -554,26 +776,27 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
|
|
||||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||||
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
budgetOverride = &norm
|
if err != nil {
|
||||||
}
|
return cliproxyexecutor.Response{}, err
|
||||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
|
||||||
}
|
}
|
||||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
|
||||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
|
||||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
|
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
|
||||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||||
|
|
||||||
baseURL := vertexBaseURL(location)
|
baseURL := vertexBaseURL(location)
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, "countTokens")
|
||||||
|
|
||||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||||
if errNewReq != nil {
|
if errNewReq != nil {
|
||||||
@@ -621,7 +844,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
@@ -630,10 +853,6 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
return cliproxyexecutor.Response{}, errRead
|
return cliproxyexecutor.Response{}, errRead
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
||||||
}
|
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||||
@@ -641,24 +860,20 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
|
|
||||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||||
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
||||||
model := req.Model
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
|
||||||
model = override
|
|
||||||
}
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
budgetOverride = &norm
|
if err != nil {
|
||||||
}
|
return cliproxyexecutor.Response{}, err
|
||||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
|
||||||
}
|
}
|
||||||
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
|
|
||||||
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
|
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
|
||||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
|
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
|
||||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||||
@@ -668,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
|
||||||
|
|
||||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||||
if errNewReq != nil {
|
if errNewReq != nil {
|
||||||
@@ -713,7 +928,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
@@ -722,10 +937,6 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
return cliproxyexecutor.Response{}, errRead
|
return cliproxyexecutor.Response{}, errRead
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
||||||
}
|
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||||
@@ -812,53 +1023,6 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
|
|||||||
return tok.AccessToken, nil
|
return tok.AccessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
|
|
||||||
// It matches the requested model alias against configured models and returns the actual upstream name.
|
|
||||||
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
|
||||||
trimmed := strings.TrimSpace(alias)
|
|
||||||
if trimmed == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := e.resolveVertexConfig(auth)
|
|
||||||
if entry == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
|
||||||
|
|
||||||
// Candidate names to match against configured aliases/names.
|
|
||||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
|
||||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
|
||||||
candidates = append(candidates, trimmed)
|
|
||||||
}
|
|
||||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
|
||||||
candidates = append(candidates, original)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range entry.Models {
|
|
||||||
model := entry.Models[i]
|
|
||||||
name := strings.TrimSpace(model.Name)
|
|
||||||
modelAlias := strings.TrimSpace(model.Alias)
|
|
||||||
|
|
||||||
for _, candidate := range candidates {
|
|
||||||
if candidate == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
|
||||||
if name != "" {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
return candidate
|
|
||||||
}
|
|
||||||
if name != "" && strings.EqualFold(name, candidate) {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
|
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
|
||||||
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
|
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
|
||||||
if auth == nil || e.cfg == nil {
|
if auth == nil || e.cfg == nil {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -37,11 +38,41 @@ func NewIFlowExecutor(cfg *config.Config) *IFlowExecutor { return &IFlowExecutor
|
|||||||
// Identifier returns the provider key.
|
// Identifier returns the provider key.
|
||||||
func (e *IFlowExecutor) Identifier() string { return "iflow" }
|
func (e *IFlowExecutor) Identifier() string { return "iflow" }
|
||||||
|
|
||||||
// PrepareRequest implements ProviderExecutor but requires no preprocessing.
|
// PrepareRequest injects iFlow credentials into the outgoing HTTP request.
|
||||||
func (e *IFlowExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
func (e *IFlowExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiKey, _ := iflowCreds(auth)
|
||||||
|
if strings.TrimSpace(apiKey) != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects iFlow credentials into the request and executes it.
|
||||||
|
func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("iflow executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
// Execute performs a non-streaming chat completion request.
|
// Execute performs a non-streaming chat completion request.
|
||||||
func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
apiKey, baseURL := iflowCreds(auth)
|
apiKey, baseURL := iflowCreds(auth)
|
||||||
if strings.TrimSpace(apiKey) == "" {
|
if strings.TrimSpace(apiKey) == "" {
|
||||||
err = fmt.Errorf("iflow executor: missing api key")
|
err = fmt.Errorf("iflow executor: missing api key")
|
||||||
@@ -51,21 +82,27 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
baseURL = iflowauth.DefaultAPIBaseURL
|
baseURL = iflowauth.DefaultAPIBaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
|
||||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
|
||||||
return resp, errValidate
|
|
||||||
}
|
}
|
||||||
body = applyIFlowThinkingConfig(body)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
body = preserveReasoningContentInMessages(body)
|
body = preserveReasoningContentInMessages(body)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||||
|
|
||||||
@@ -108,7 +145,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("iflow request error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -124,6 +161,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
|
|
||||||
var param any
|
var param any
|
||||||
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@@ -131,6 +170,11 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
|
|
||||||
// ExecuteStream performs a streaming chat completion request.
|
// ExecuteStream performs a streaming chat completion request.
|
||||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
apiKey, baseURL := iflowCreds(auth)
|
apiKey, baseURL := iflowCreds(auth)
|
||||||
if strings.TrimSpace(apiKey) == "" {
|
if strings.TrimSpace(apiKey) == "" {
|
||||||
err = fmt.Errorf("iflow executor: missing api key")
|
err = fmt.Errorf("iflow executor: missing api key")
|
||||||
@@ -140,27 +184,32 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
baseURL = iflowauth.DefaultAPIBaseURL
|
baseURL = iflowauth.DefaultAPIBaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
|
||||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
|
||||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
|
||||||
return nil, errValidate
|
|
||||||
}
|
}
|
||||||
body = applyIFlowThinkingConfig(body)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
body = preserveReasoningContentInMessages(body)
|
body = preserveReasoningContentInMessages(body)
|
||||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||||
toolsResult := gjson.GetBytes(body, "tools")
|
toolsResult := gjson.GetBytes(body, "tools")
|
||||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||||
body = ensureToolsArray(body)
|
body = ensureToolsArray(body)
|
||||||
}
|
}
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||||
|
|
||||||
@@ -201,7 +250,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
log.Errorf("iflow executor: close response body error: %v", errClose)
|
log.Errorf("iflow executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
log.Debugf("iflow streaming error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -243,11 +292,13 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
enc, err := tokenizerForModel(req.Model)
|
enc, err := tokenizerForModel(baseModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
|
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -485,41 +536,3 @@ func preserveReasoningContentInMessages(body []byte) []byte {
|
|||||||
|
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
|
|
||||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
|
||||||
//
|
|
||||||
// Model-specific handling:
|
|
||||||
// - GLM-4.6/4.7: Uses chat_template_kwargs.enable_thinking (boolean) and chat_template_kwargs.clear_thinking=false
|
|
||||||
// - MiniMax M2/M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
|
|
||||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
|
||||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
|
||||||
if !effort.Exists() {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
model := strings.ToLower(gjson.GetBytes(body, "model").String())
|
|
||||||
val := strings.ToLower(strings.TrimSpace(effort.String()))
|
|
||||||
enableThinking := val != "none" && val != ""
|
|
||||||
|
|
||||||
// Remove reasoning_effort as we'll convert to model-specific format
|
|
||||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
|
||||||
body, _ = sjson.DeleteBytes(body, "thinking")
|
|
||||||
|
|
||||||
// GLM-4.6/4.7: Use chat_template_kwargs
|
|
||||||
if strings.HasPrefix(model, "glm-4") {
|
|
||||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
|
||||||
if enableThinking {
|
|
||||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.clear_thinking", false)
|
|
||||||
}
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
// MiniMax M2/M2.1: Use reasoning_split
|
|
||||||
if strings.HasPrefix(model, "minimax-m2") {
|
|
||||||
body, _ = sjson.SetBytes(body, "reasoning_split", enableThinking)
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|||||||
67
internal/runtime/executor/iflow_executor_test.go
Normal file
67
internal/runtime/executor/iflow_executor_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIFlowExecutorParseSuffix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
wantBase string
|
||||||
|
wantLevel string
|
||||||
|
}{
|
||||||
|
{"no suffix", "glm-4", "glm-4", ""},
|
||||||
|
{"glm with suffix", "glm-4.1-flash(high)", "glm-4.1-flash", "high"},
|
||||||
|
{"minimax no suffix", "minimax-m2", "minimax-m2", ""},
|
||||||
|
{"minimax with suffix", "minimax-m2.1(medium)", "minimax-m2.1", "medium"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := thinking.ParseSuffix(tt.model)
|
||||||
|
if result.ModelName != tt.wantBase {
|
||||||
|
t.Errorf("ParseSuffix(%q).ModelName = %q, want %q", tt.model, result.ModelName, tt.wantBase)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreserveReasoningContentInMessages(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
want []byte // nil means output should equal input
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"non-glm model passthrough",
|
||||||
|
[]byte(`{"model":"gpt-4","messages":[]}`),
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"glm model with empty messages",
|
||||||
|
[]byte(`{"model":"glm-4","messages":[]}`),
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"glm model preserves existing reasoning_content",
|
||||||
|
[]byte(`{"model":"glm-4","messages":[{"role":"assistant","content":"hi","reasoning_content":"thinking..."}]}`),
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := preserveReasoningContentInMessages(tt.input)
|
||||||
|
want := tt.want
|
||||||
|
if want == nil {
|
||||||
|
want = tt.input
|
||||||
|
}
|
||||||
|
if string(got) != string(want) {
|
||||||
|
t.Errorf("preserveReasoningContentInMessages() = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,7 +12,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -304,11 +307,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
|
|||||||
parts = append(parts, "type=api_key")
|
parts = append(parts, "type=api_key")
|
||||||
}
|
}
|
||||||
case "oauth":
|
case "oauth":
|
||||||
if authValue != "" {
|
parts = append(parts, "type=oauth")
|
||||||
parts = append(parts, fmt.Sprintf("type=oauth account=%s", authValue))
|
|
||||||
} else {
|
|
||||||
parts = append(parts, "type=oauth")
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
if authType != "" {
|
if authType != "" {
|
||||||
if authValue != "" {
|
if authValue != "" {
|
||||||
@@ -336,6 +335,12 @@ func summarizeErrorBody(contentType string, body []byte) string {
|
|||||||
}
|
}
|
||||||
return "[html body omitted]"
|
return "[html body omitted]"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to extract error message from JSON response
|
||||||
|
if message := extractJSONErrorMessage(body); message != "" {
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
return string(body)
|
return string(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,3 +367,25 @@ func extractHTMLTitle(body []byte) string {
|
|||||||
}
|
}
|
||||||
return strings.Join(strings.Fields(title), " ")
|
return strings.Join(strings.Fields(title), " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractJSONErrorMessage attempts to extract error.message from JSON error responses
|
||||||
|
func extractJSONErrorMessage(body []byte) string {
|
||||||
|
result := gjson.GetBytes(body, "error.message")
|
||||||
|
if result.Exists() && result.String() != "" {
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// logWithRequestID returns a logrus Entry with request_id field populated from context.
|
||||||
|
// If no request ID is found in context, it returns the standard logger.
|
||||||
|
func logWithRequestID(ctx context.Context) *log.Entry {
|
||||||
|
if ctx == nil {
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
requestID := logging.GetRequestID(ctx)
|
||||||
|
if requestID == "" {
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
return log.WithField("request_id", requestID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -35,13 +36,43 @@ func NewOpenAICompatExecutor(provider string, cfg *config.Config) *OpenAICompatE
|
|||||||
// Identifier implements cliproxyauth.ProviderExecutor.
|
// Identifier implements cliproxyauth.ProviderExecutor.
|
||||||
func (e *OpenAICompatExecutor) Identifier() string { return e.provider }
|
func (e *OpenAICompatExecutor) Identifier() string { return e.provider }
|
||||||
|
|
||||||
// PrepareRequest is a no-op for now (credentials are added via headers at execution time).
|
// PrepareRequest injects OpenAI-compatible credentials into the outgoing HTTP request.
|
||||||
func (e *OpenAICompatExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
func (e *OpenAICompatExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, apiKey := e.resolveCredentials(auth)
|
||||||
|
if strings.TrimSpace(apiKey) != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects OpenAI-compatible credentials into the request and executes it.
|
||||||
|
func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("openai compat executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
baseURL, apiKey := e.resolveCredentials(auth)
|
baseURL, apiKey := e.resolveCredentials(auth)
|
||||||
@@ -50,23 +81,33 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Translate inbound request to OpenAI format
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
endpoint := "/chat/completions"
|
||||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
if opts.Alt == "responses/compact" {
|
||||||
if modelOverride != "" {
|
to = sdktranslator.FromString("openai-response")
|
||||||
translated = e.overrideModel(translated, modelOverride)
|
endpoint = "/responses/compact"
|
||||||
}
|
}
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
}
|
||||||
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||||
return resp, errValidate
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||||
|
if opts.Alt == "responses/compact" {
|
||||||
|
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
|
||||||
|
translated = updated
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
url := strings.TrimSuffix(baseURL, "/") + endpoint
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -114,7 +155,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -135,7 +176,9 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
baseURL, apiKey := e.resolveCredentials(auth)
|
baseURL, apiKey := e.resolveCredentials(auth)
|
||||||
@@ -143,19 +186,21 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
|
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
modelOverride := e.resolveUpstreamModel(req.Model, auth)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
if modelOverride != "" {
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
translated = e.overrideModel(translated, modelOverride)
|
|
||||||
}
|
}
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||||
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
|
|
||||||
return nil, errValidate
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
@@ -203,7 +248,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -231,6 +276,11 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
if len(line) == 0 {
|
if len(line) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
|
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
|
||||||
// Pass through translator; it yields one or more chunks for the target schema.
|
// Pass through translator; it yields one or more chunks for the target schema.
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m)
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m)
|
||||||
@@ -250,14 +300,17 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
modelForCounting := req.Model
|
modelForCounting := baseModel
|
||||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
|
||||||
translated = e.overrideModel(translated, modelOverride)
|
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
modelForCounting = modelOverride
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
enc, err := tokenizerForModel(modelForCounting)
|
enc, err := tokenizerForModel(modelForCounting)
|
||||||
@@ -293,53 +346,6 @@ func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (base
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
|
||||||
if alias == "" || auth == nil || e.cfg == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
compat := e.resolveCompatConfig(auth)
|
|
||||||
if compat == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
for i := range compat.Models {
|
|
||||||
model := compat.Models[i]
|
|
||||||
if model.Alias != "" {
|
|
||||||
if strings.EqualFold(model.Alias, alias) {
|
|
||||||
if model.Name != "" {
|
|
||||||
return model.Name
|
|
||||||
}
|
|
||||||
return alias
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.EqualFold(model.Name, alias) {
|
|
||||||
return model.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) allowCompatReasoningEffort(model string, auth *cliproxyauth.Auth) bool {
|
|
||||||
trimmed := strings.TrimSpace(model)
|
|
||||||
if trimmed == "" || e == nil || e.cfg == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
compat := e.resolveCompatConfig(auth)
|
|
||||||
if compat == nil || len(compat.Models) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := range compat.Models {
|
|
||||||
entry := compat.Models[i]
|
|
||||||
if strings.EqualFold(strings.TrimSpace(entry.Alias), trimmed) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if strings.EqualFold(strings.TrimSpace(entry.Name), trimmed) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
|
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
|
||||||
if auth == nil || e.cfg == nil {
|
if auth == nil || e.cfg == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenAICompatExecutorCompactPassthrough(t *testing.T) {
|
||||||
|
var gotPath string
|
||||||
|
var gotBody []byte
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
gotBody = body
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"base_url": server.URL + "/v1",
|
||||||
|
"api_key": "test",
|
||||||
|
}}
|
||||||
|
payload := []byte(`{"model":"gpt-5.1-codex-max","input":[{"role":"user","content":"hi"}]}`)
|
||||||
|
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "gpt-5.1-codex-max",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||||
|
Alt: "responses/compact",
|
||||||
|
Stream: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute error: %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1/responses/compact" {
|
||||||
|
t.Fatalf("path = %q, want %q", gotPath, "/v1/responses/compact")
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(gotBody, "input").Exists() {
|
||||||
|
t.Fatalf("expected input in body")
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "messages").Exists() {
|
||||||
|
t.Fatalf("unexpected messages in body")
|
||||||
|
}
|
||||||
|
if string(resp.Payload) != `{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}` {
|
||||||
|
t.Fatalf("payload = %s", string(resp.Payload))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,136 +1,45 @@
|
|||||||
package executor
|
package executor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"encoding/json"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
|
||||||
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
|
|
||||||
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
|
||||||
// Use the alias from metadata if available, as it's registered in the global registry
|
|
||||||
// with thinking metadata; the upstream model name may not be registered.
|
|
||||||
lookupModel := util.ResolveOriginalModel(model, metadata)
|
|
||||||
|
|
||||||
// Determine which model to use for thinking support check.
|
|
||||||
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
|
|
||||||
thinkingModel := lookupModel
|
|
||||||
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
|
|
||||||
thinkingModel = model
|
|
||||||
}
|
|
||||||
|
|
||||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
|
|
||||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
if !util.ModelSupportsThinking(thinkingModel) {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
|
|
||||||
budgetOverride = &norm
|
|
||||||
}
|
|
||||||
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
|
|
||||||
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
|
|
||||||
func ApplyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
|
|
||||||
// Use the alias from metadata if available, as it's registered in the global registry
|
|
||||||
// with thinking metadata; the upstream model name may not be registered.
|
|
||||||
lookupModel := util.ResolveOriginalModel(model, metadata)
|
|
||||||
|
|
||||||
// Determine which model to use for thinking support check.
|
|
||||||
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
|
|
||||||
thinkingModel := lookupModel
|
|
||||||
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
|
|
||||||
thinkingModel = model
|
|
||||||
}
|
|
||||||
|
|
||||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
|
|
||||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
if !util.ModelSupportsThinking(thinkingModel) {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
|
|
||||||
budgetOverride = &norm
|
|
||||||
}
|
|
||||||
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
|
||||||
// Metadata values take precedence over any existing field when the model supports thinking, intentionally
|
|
||||||
// overwriting caller-provided values to honor suffix/default metadata priority.
|
|
||||||
func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string, allowCompat bool) []byte {
|
|
||||||
if len(metadata) == 0 {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
if field == "" {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
baseModel := util.ResolveOriginalModel(model, metadata)
|
|
||||||
if baseModel == "" {
|
|
||||||
baseModel = model
|
|
||||||
}
|
|
||||||
if !util.ModelSupportsThinking(baseModel) && !allowCompat {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
if effort, ok := util.ReasoningEffortFromMetadata(metadata); ok && effort != "" {
|
|
||||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
|
||||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
|
||||||
return updated
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Fallback: numeric thinking_budget suffix for level-based (OpenAI-style) models.
|
|
||||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
|
||||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
|
||||||
if effort, ok := util.ThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
|
||||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
|
||||||
return updated
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyPayloadConfig applies payload default and override rules from configuration
|
|
||||||
// to the given JSON payload for the specified model.
|
|
||||||
// Defaults only fill missing fields, while overrides always overwrite existing values.
|
|
||||||
func applyPayloadConfig(cfg *config.Config, model string, payload []byte) []byte {
|
|
||||||
return applyPayloadConfigWithRoot(cfg, model, "", "", payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||||
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
||||||
// and restricts matches to the given protocol when supplied.
|
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload []byte) []byte {
|
// against the original payload when provided. requestedModel carries the client-visible
|
||||||
|
// model name before alias resolution so payload rules can target aliases precisely.
|
||||||
|
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||||
if cfg == nil || len(payload) == 0 {
|
if cfg == nil || len(payload) == 0 {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
rules := cfg.Payload
|
rules := cfg.Payload
|
||||||
if len(rules.Default) == 0 && len(rules.Override) == 0 {
|
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
model = strings.TrimSpace(model)
|
model = strings.TrimSpace(model)
|
||||||
if model == "" {
|
requestedModel = strings.TrimSpace(requestedModel)
|
||||||
|
if model == "" && requestedModel == "" {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
candidates := payloadModelCandidates(model, requestedModel)
|
||||||
out := payload
|
out := payload
|
||||||
|
source := original
|
||||||
|
if len(source) == 0 {
|
||||||
|
source = payload
|
||||||
|
}
|
||||||
|
appliedDefaults := make(map[string]struct{})
|
||||||
// Apply default rules: first write wins per field across all matching rules.
|
// Apply default rules: first write wins per field across all matching rules.
|
||||||
for i := range rules.Default {
|
for i := range rules.Default {
|
||||||
rule := &rules.Default[i]
|
rule := &rules.Default[i]
|
||||||
if !payloadRuleMatchesModel(rule, model, protocol) {
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for path, value := range rule.Params {
|
for path, value := range rule.Params {
|
||||||
@@ -138,7 +47,58 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
if fullPath == "" {
|
if fullPath == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if gjson.GetBytes(out, fullPath).Exists() {
|
if gjson.GetBytes(source, fullPath).Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := appliedDefaults[fullPath]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||||
|
if errSet != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = updated
|
||||||
|
appliedDefaults[fullPath] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Apply default raw rules: first write wins per field across all matching rules.
|
||||||
|
for i := range rules.DefaultRaw {
|
||||||
|
rule := &rules.DefaultRaw[i]
|
||||||
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for path, value := range rule.Params {
|
||||||
|
fullPath := buildPayloadPath(root, path)
|
||||||
|
if fullPath == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(source, fullPath).Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := appliedDefaults[fullPath]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rawValue, ok := payloadRawValue(value)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||||
|
if errSet != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = updated
|
||||||
|
appliedDefaults[fullPath] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Apply override rules: last write wins per field across all matching rules.
|
||||||
|
for i := range rules.Override {
|
||||||
|
rule := &rules.Override[i]
|
||||||
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for path, value := range rule.Params {
|
||||||
|
fullPath := buildPayloadPath(root, path)
|
||||||
|
if fullPath == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||||
@@ -148,10 +108,10 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
out = updated
|
out = updated
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Apply override rules: last write wins per field across all matching rules.
|
// Apply override raw rules: last write wins per field across all matching rules.
|
||||||
for i := range rules.Override {
|
for i := range rules.OverrideRaw {
|
||||||
rule := &rules.Override[i]
|
rule := &rules.OverrideRaw[i]
|
||||||
if !payloadRuleMatchesModel(rule, model, protocol) {
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for path, value := range rule.Params {
|
for path, value := range rule.Params {
|
||||||
@@ -159,38 +119,95 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
if fullPath == "" {
|
if fullPath == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
rawValue, ok := payloadRawValue(value)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
out = updated
|
out = updated
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Apply filter rules: remove matching paths from payload.
|
||||||
|
for i := range rules.Filter {
|
||||||
|
rule := &rules.Filter[i]
|
||||||
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, path := range rule.Params {
|
||||||
|
fullPath := buildPayloadPath(root, path)
|
||||||
|
if fullPath == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||||
|
if errDel != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = updated
|
||||||
|
}
|
||||||
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) bool {
|
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool {
|
||||||
if rule == nil {
|
if len(rules) == 0 || len(models) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(rule.Models) == 0 {
|
for _, model := range models {
|
||||||
return false
|
for _, entry := range rules {
|
||||||
}
|
name := strings.TrimSpace(entry.Name)
|
||||||
for _, entry := range rule.Models {
|
if name == "" {
|
||||||
name := strings.TrimSpace(entry.Name)
|
continue
|
||||||
if name == "" {
|
}
|
||||||
continue
|
if ep := strings.TrimSpace(entry.Protocol); ep != "" && protocol != "" && !strings.EqualFold(ep, protocol) {
|
||||||
}
|
continue
|
||||||
if ep := strings.TrimSpace(entry.Protocol); ep != "" && protocol != "" && !strings.EqualFold(ep, protocol) {
|
}
|
||||||
continue
|
if matchModelPattern(name, model) {
|
||||||
}
|
return true
|
||||||
if matchModelPattern(name, model) {
|
}
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func payloadModelCandidates(model, requestedModel string) []string {
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
requestedModel = strings.TrimSpace(requestedModel)
|
||||||
|
if model == "" && requestedModel == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
candidates := make([]string, 0, 3)
|
||||||
|
seen := make(map[string]struct{}, 3)
|
||||||
|
addCandidate := func(value string) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := strings.ToLower(value)
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
candidates = append(candidates, value)
|
||||||
|
}
|
||||||
|
if model != "" {
|
||||||
|
addCandidate(model)
|
||||||
|
}
|
||||||
|
if requestedModel != "" {
|
||||||
|
parsed := thinking.ParseSuffix(requestedModel)
|
||||||
|
base := strings.TrimSpace(parsed.ModelName)
|
||||||
|
if base != "" {
|
||||||
|
addCandidate(base)
|
||||||
|
}
|
||||||
|
if parsed.HasSuffix {
|
||||||
|
addCandidate(requestedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return candidates
|
||||||
|
}
|
||||||
|
|
||||||
// buildPayloadPath combines an optional root path with a relative parameter path.
|
// buildPayloadPath combines an optional root path with a relative parameter path.
|
||||||
// When root is empty, the parameter path is used as-is. When root is non-empty,
|
// When root is empty, the parameter path is used as-is. When root is non-empty,
|
||||||
// the parameter path is treated as relative to root.
|
// the parameter path is treated as relative to root.
|
||||||
@@ -209,6 +226,53 @@ func buildPayloadPath(root, path string) string {
|
|||||||
return r + "." + p
|
return r + "." + p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func payloadRawValue(value any) ([]byte, bool) {
|
||||||
|
if value == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case string:
|
||||||
|
return []byte(typed), true
|
||||||
|
case []byte:
|
||||||
|
return typed, true
|
||||||
|
default:
|
||||||
|
raw, errMarshal := json.Marshal(typed)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return raw, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||||
|
fallback = strings.TrimSpace(fallback)
|
||||||
|
if len(opts.Metadata) == 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(v) == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case []byte:
|
||||||
|
if len(v) == 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(string(v))
|
||||||
|
if trimmed == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return trimmed
|
||||||
|
default:
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
||||||
// Examples:
|
// Examples:
|
||||||
//
|
//
|
||||||
@@ -253,102 +317,3 @@ func matchModelPattern(pattern, model string) bool {
|
|||||||
}
|
}
|
||||||
return pi == len(pattern)
|
return pi == len(pattern)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NormalizeThinkingConfig normalizes thinking-related fields in the payload
|
|
||||||
// based on model capabilities. For models without thinking support, it strips
|
|
||||||
// reasoning fields. For models with level-based thinking, it validates and
|
|
||||||
// normalizes the reasoning effort level. For models with numeric budget thinking,
|
|
||||||
// it strips the effort string fields.
|
|
||||||
func NormalizeThinkingConfig(payload []byte, model string, allowCompat bool) []byte {
|
|
||||||
if len(payload) == 0 || model == "" {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
|
|
||||||
if !util.ModelSupportsThinking(model) {
|
|
||||||
if allowCompat {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
return StripThinkingFields(payload, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
if util.ModelUsesThinkingLevels(model) {
|
|
||||||
return NormalizeReasoningEffortLevel(payload, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Model supports thinking but uses numeric budgets, not levels.
|
|
||||||
// Strip effort string fields since they are not applicable.
|
|
||||||
return StripThinkingFields(payload, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// StripThinkingFields removes thinking-related fields from the payload for
|
|
||||||
// models that do not support thinking. If effortOnly is true, only removes
|
|
||||||
// effort string fields (for models using numeric budgets).
|
|
||||||
func StripThinkingFields(payload []byte, effortOnly bool) []byte {
|
|
||||||
fieldsToRemove := []string{
|
|
||||||
"reasoning_effort",
|
|
||||||
"reasoning.effort",
|
|
||||||
}
|
|
||||||
if !effortOnly {
|
|
||||||
fieldsToRemove = append([]string{"reasoning", "thinking"}, fieldsToRemove...)
|
|
||||||
}
|
|
||||||
out := payload
|
|
||||||
for _, field := range fieldsToRemove {
|
|
||||||
if gjson.GetBytes(out, field).Exists() {
|
|
||||||
out, _ = sjson.DeleteBytes(out, field)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// NormalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
|
||||||
// or reasoning.effort field for level-based thinking models.
|
|
||||||
func NormalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
|
||||||
out := payload
|
|
||||||
|
|
||||||
if effort := gjson.GetBytes(out, "reasoning_effort"); effort.Exists() {
|
|
||||||
if normalized, ok := util.NormalizeReasoningEffortLevel(model, effort.String()); ok {
|
|
||||||
out, _ = sjson.SetBytes(out, "reasoning_effort", normalized)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if effort := gjson.GetBytes(out, "reasoning.effort"); effort.Exists() {
|
|
||||||
if normalized, ok := util.NormalizeReasoningEffortLevel(model, effort.String()); ok {
|
|
||||||
out, _ = sjson.SetBytes(out, "reasoning.effort", normalized)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
|
||||||
// Returns a statusErr with 400 when an unsupported level is supplied to avoid silently
|
|
||||||
// downgrading requests.
|
|
||||||
func ValidateThinkingConfig(payload []byte, model string) error {
|
|
||||||
if len(payload) == 0 || model == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !util.ModelSupportsThinking(model) || !util.ModelUsesThinkingLevels(model) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
levels := util.GetModelThinkingLevels(model)
|
|
||||||
checkField := func(path string) error {
|
|
||||||
if effort := gjson.GetBytes(payload, path); effort.Exists() {
|
|
||||||
if _, ok := util.NormalizeReasoningEffortLevel(model, effort.String()); !ok {
|
|
||||||
return statusErr{
|
|
||||||
code: http.StatusBadRequest,
|
|
||||||
msg: fmt.Sprintf("unsupported reasoning effort level %q for model %s (supported: %s)", effort.String(), model, strings.Join(levels, ", ")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkField("reasoning_effort"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := checkField("reasoning.effort"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -36,27 +37,65 @@ func NewQwenExecutor(cfg *config.Config) *QwenExecutor { return &QwenExecutor{cf
|
|||||||
|
|
||||||
func (e *QwenExecutor) Identifier() string { return "qwen" }
|
func (e *QwenExecutor) Identifier() string { return "qwen" }
|
||||||
|
|
||||||
func (e *QwenExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
// PrepareRequest injects Qwen credentials into the outgoing HTTP request.
|
||||||
|
func (e *QwenExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
token, _ := qwenCreds(auth)
|
||||||
|
if strings.TrimSpace(token) != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects Qwen credentials into the request and executes it.
|
||||||
|
func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("qwen executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
token, baseURL := qwenCreds(auth)
|
if opts.Alt == "responses/compact" {
|
||||||
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
token, baseURL := qwenCreds(auth)
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://portal.qwen.ai/v1"
|
baseURL = "https://portal.qwen.ai/v1"
|
||||||
}
|
}
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
if len(opts.OriginalRequest) > 0 {
|
||||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
|
||||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
|
||||||
return resp, errValidate
|
|
||||||
}
|
}
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -97,7 +136,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -109,30 +148,42 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||||
var param any
|
var param any
|
||||||
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
token, baseURL := qwenCreds(auth)
|
if opts.Alt == "responses/compact" {
|
||||||
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
|
}
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
token, baseURL := qwenCreds(auth)
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://portal.qwen.ai/v1"
|
baseURL = "https://portal.qwen.ai/v1"
|
||||||
}
|
}
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
|
||||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
|
||||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
|
||||||
return nil, errValidate
|
|
||||||
}
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
toolsResult := gjson.GetBytes(body, "tools")
|
toolsResult := gjson.GetBytes(body, "tools")
|
||||||
// I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response.
|
// I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response.
|
||||||
// This will have no real consequences. It's just to scare Qwen3.
|
// This will have no real consequences. It's just to scare Qwen3.
|
||||||
@@ -140,7 +191,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
||||||
}
|
}
|
||||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
@@ -176,7 +228,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -220,13 +272,15 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
modelName := gjson.GetBytes(body, "model").String()
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
if strings.TrimSpace(modelName) == "" {
|
if strings.TrimSpace(modelName) == "" {
|
||||||
modelName = req.Model
|
modelName = baseModel
|
||||||
}
|
}
|
||||||
|
|
||||||
enc, err := tokenizerForModel(modelName)
|
enc, err := tokenizerForModel(modelName)
|
||||||
|
|||||||
30
internal/runtime/executor/qwen_executor_test.go
Normal file
30
internal/runtime/executor/qwen_executor_test.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQwenExecutorParseSuffix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
wantBase string
|
||||||
|
wantLevel string
|
||||||
|
}{
|
||||||
|
{"no suffix", "qwen-max", "qwen-max", ""},
|
||||||
|
{"with level suffix", "qwen-max(high)", "qwen-max", "high"},
|
||||||
|
{"with budget suffix", "qwen-max(16384)", "qwen-max", "16384"},
|
||||||
|
{"complex model name", "qwen-plus-latest(medium)", "qwen-plus-latest", "medium"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := thinking.ParseSuffix(tt.model)
|
||||||
|
if result.ModelName != tt.wantBase {
|
||||||
|
t.Errorf("ParseSuffix(%q).ModelName = %q, want %q", tt.model, result.ModelName, tt.wantBase)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
11
internal/runtime/executor/thinking_providers.go
Normal file
11
internal/runtime/executor/thinking_providers.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/codex"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
||||||
|
)
|
||||||
@@ -199,15 +199,31 @@ func parseOpenAIUsage(data []byte) usage.Detail {
|
|||||||
if !usageNode.Exists() {
|
if !usageNode.Exists() {
|
||||||
return usage.Detail{}
|
return usage.Detail{}
|
||||||
}
|
}
|
||||||
|
inputNode := usageNode.Get("prompt_tokens")
|
||||||
|
if !inputNode.Exists() {
|
||||||
|
inputNode = usageNode.Get("input_tokens")
|
||||||
|
}
|
||||||
|
outputNode := usageNode.Get("completion_tokens")
|
||||||
|
if !outputNode.Exists() {
|
||||||
|
outputNode = usageNode.Get("output_tokens")
|
||||||
|
}
|
||||||
detail := usage.Detail{
|
detail := usage.Detail{
|
||||||
InputTokens: usageNode.Get("prompt_tokens").Int(),
|
InputTokens: inputNode.Int(),
|
||||||
OutputTokens: usageNode.Get("completion_tokens").Int(),
|
OutputTokens: outputNode.Int(),
|
||||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||||
}
|
}
|
||||||
if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
|
cached := usageNode.Get("prompt_tokens_details.cached_tokens")
|
||||||
|
if !cached.Exists() {
|
||||||
|
cached = usageNode.Get("input_tokens_details.cached_tokens")
|
||||||
|
}
|
||||||
|
if cached.Exists() {
|
||||||
detail.CachedTokens = cached.Int()
|
detail.CachedTokens = cached.Int()
|
||||||
}
|
}
|
||||||
if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens")
|
||||||
|
if !reasoning.Exists() {
|
||||||
|
reasoning = usageNode.Get("output_tokens_details.reasoning_tokens")
|
||||||
|
}
|
||||||
|
if reasoning.Exists() {
|
||||||
detail.ReasoningTokens = reasoning.Int()
|
detail.ReasoningTokens = reasoning.Int()
|
||||||
}
|
}
|
||||||
return detail
|
return detail
|
||||||
|
|||||||
43
internal/runtime/executor/usage_helpers_test.go
Normal file
43
internal/runtime/executor/usage_helpers_test.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
|
||||||
|
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
|
||||||
|
detail := parseOpenAIUsage(data)
|
||||||
|
if detail.InputTokens != 1 {
|
||||||
|
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 1)
|
||||||
|
}
|
||||||
|
if detail.OutputTokens != 2 {
|
||||||
|
t.Fatalf("output tokens = %d, want %d", detail.OutputTokens, 2)
|
||||||
|
}
|
||||||
|
if detail.TotalTokens != 3 {
|
||||||
|
t.Fatalf("total tokens = %d, want %d", detail.TotalTokens, 3)
|
||||||
|
}
|
||||||
|
if detail.CachedTokens != 4 {
|
||||||
|
t.Fatalf("cached tokens = %d, want %d", detail.CachedTokens, 4)
|
||||||
|
}
|
||||||
|
if detail.ReasoningTokens != 5 {
|
||||||
|
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 5)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOpenAIUsageResponses(t *testing.T) {
|
||||||
|
data := []byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`)
|
||||||
|
detail := parseOpenAIUsage(data)
|
||||||
|
if detail.InputTokens != 10 {
|
||||||
|
t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 10)
|
||||||
|
}
|
||||||
|
if detail.OutputTokens != 20 {
|
||||||
|
t.Fatalf("output tokens = %d, want %d", detail.OutputTokens, 20)
|
||||||
|
}
|
||||||
|
if detail.TotalTokens != 30 {
|
||||||
|
t.Fatalf("total tokens = %d, want %d", detail.TotalTokens, 30)
|
||||||
|
}
|
||||||
|
if detail.CachedTokens != 7 {
|
||||||
|
t.Fatalf("cached tokens = %d, want %d", detail.CachedTokens, 7)
|
||||||
|
}
|
||||||
|
if detail.ReasoningTokens != 9 {
|
||||||
|
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -386,11 +386,12 @@ func (s *ObjectTokenStore) syncConfigFromBucket(ctx context.Context, example str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ObjectTokenStore) syncAuthFromBucket(ctx context.Context) error {
|
func (s *ObjectTokenStore) syncAuthFromBucket(ctx context.Context) error {
|
||||||
if err := os.RemoveAll(s.authDir); err != nil {
|
// NOTE: We intentionally do NOT use os.RemoveAll here.
|
||||||
return fmt.Errorf("object store: reset auth directory: %w", err)
|
// Wiping the directory triggers file watcher delete events, which then
|
||||||
}
|
// propagate deletions to the remote object store (race condition).
|
||||||
|
// Instead, we just ensure the directory exists and overwrite files incrementally.
|
||||||
if err := os.MkdirAll(s.authDir, 0o700); err != nil {
|
if err := os.MkdirAll(s.authDir, 0o700); err != nil {
|
||||||
return fmt.Errorf("object store: recreate auth directory: %w", err)
|
return fmt.Errorf("object store: create auth directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := s.prefixedKey(objectStoreAuthPrefix + "/")
|
prefix := s.prefixedKey(objectStoreAuthPrefix + "/")
|
||||||
|
|||||||
487
internal/thinking/apply.go
Normal file
487
internal/thinking/apply.go
Normal file
@@ -0,0 +1,487 @@
|
|||||||
|
// Package thinking provides unified thinking configuration processing.
|
||||||
|
package thinking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// providerAppliers maps provider names to their ProviderApplier implementations.
|
||||||
|
var providerAppliers = map[string]ProviderApplier{
|
||||||
|
"gemini": nil,
|
||||||
|
"gemini-cli": nil,
|
||||||
|
"claude": nil,
|
||||||
|
"openai": nil,
|
||||||
|
"codex": nil,
|
||||||
|
"iflow": nil,
|
||||||
|
"antigravity": nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderApplier returns the ProviderApplier for the given provider name.
|
||||||
|
// Returns nil if the provider is not registered.
|
||||||
|
func GetProviderApplier(provider string) ProviderApplier {
|
||||||
|
return providerAppliers[provider]
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProvider registers a provider applier by name.
|
||||||
|
func RegisterProvider(name string, applier ProviderApplier) {
|
||||||
|
providerAppliers[name] = applier
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUserDefinedModel reports whether the model is a user-defined model that should
|
||||||
|
// have thinking configuration passed through without validation.
|
||||||
|
//
|
||||||
|
// User-defined models are configured via config file's models[] array
|
||||||
|
// (e.g., openai-compatibility.*.models[], *-api-key.models[]). These models
|
||||||
|
// are marked with UserDefined=true at registration time.
|
||||||
|
//
|
||||||
|
// User-defined models should have their thinking configuration applied directly,
|
||||||
|
// letting the upstream service validate the configuration.
|
||||||
|
func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
||||||
|
if modelInfo == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return modelInfo.UserDefined
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyThinking applies thinking configuration to a request body.
|
||||||
|
//
|
||||||
|
// This is the unified entry point for all providers. It follows the processing
|
||||||
|
// order defined in FR25: route check → model capability query → config extraction
|
||||||
|
// → validation → application.
|
||||||
|
//
|
||||||
|
// Suffix Priority: When the model name includes a thinking suffix (e.g., "gemini-2.5-pro(8192)"),
|
||||||
|
// the suffix configuration takes priority over any thinking parameters in the request body.
|
||||||
|
// This enables users to override thinking settings via the model name without modifying their
|
||||||
|
// request payload.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - body: Original request body JSON
|
||||||
|
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
|
||||||
|
// - fromFormat: Source request format (e.g., openai, codex, gemini)
|
||||||
|
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow)
|
||||||
|
// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - Modified request body JSON with thinking configuration applied
|
||||||
|
// - Error if validation fails (ThinkingError). On error, the original body
|
||||||
|
// is returned (not nil) to enable defensive programming patterns.
|
||||||
|
//
|
||||||
|
// Passthrough behavior (returns original body without error):
|
||||||
|
// - Unknown provider (not in providerAppliers map)
|
||||||
|
// - modelInfo.Thinking is nil (model doesn't support thinking)
|
||||||
|
//
|
||||||
|
// Note: Unknown models (modelInfo is nil) are treated as user-defined models: we skip
|
||||||
|
// validation and still apply the thinking config so the upstream can validate it.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// // With suffix - suffix config takes priority
|
||||||
|
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini", "gemini")
|
||||||
|
//
|
||||||
|
// // Without suffix - uses body config
|
||||||
|
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini", "gemini")
|
||||||
|
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string, providerKey string) ([]byte, error) {
|
||||||
|
providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
|
||||||
|
providerKey = strings.ToLower(strings.TrimSpace(providerKey))
|
||||||
|
if providerKey == "" {
|
||||||
|
providerKey = providerFormat
|
||||||
|
}
|
||||||
|
fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
|
||||||
|
if fromFormat == "" {
|
||||||
|
fromFormat = providerFormat
|
||||||
|
}
|
||||||
|
// 1. Route check: Get provider applier
|
||||||
|
applier := GetProviderApplier(providerFormat)
|
||||||
|
if applier == nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": providerFormat,
|
||||||
|
"model": model,
|
||||||
|
}).Debug("thinking: unknown provider, passthrough |")
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Parse suffix and get modelInfo
|
||||||
|
suffixResult := ParseSuffix(model)
|
||||||
|
baseModel := suffixResult.ModelName
|
||||||
|
// Use provider-specific lookup to handle capability differences across providers.
|
||||||
|
modelInfo := registry.LookupModelInfo(baseModel, providerKey)
|
||||||
|
|
||||||
|
// 3. Model capability check
|
||||||
|
// Unknown models are treated as user-defined so thinking config can still be applied.
|
||||||
|
// The upstream service is responsible for validating the configuration.
|
||||||
|
if IsUserDefinedModel(modelInfo) {
|
||||||
|
return applyUserDefinedModel(body, modelInfo, fromFormat, providerFormat, suffixResult)
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
config := extractThinkingConfig(body, providerFormat)
|
||||||
|
if hasThinkingConfig(config) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"model": baseModel,
|
||||||
|
"provider": providerFormat,
|
||||||
|
}).Debug("thinking: model does not support thinking, stripping config |")
|
||||||
|
return StripThinkingConfig(body, providerFormat), nil
|
||||||
|
}
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": providerFormat,
|
||||||
|
"model": baseModel,
|
||||||
|
}).Debug("thinking: model does not support thinking, passthrough |")
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Get config: suffix priority over body
|
||||||
|
var config ThinkingConfig
|
||||||
|
if suffixResult.HasSuffix {
|
||||||
|
config = parseSuffixToConfig(suffixResult.RawSuffix, providerFormat, model)
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": providerFormat,
|
||||||
|
"model": model,
|
||||||
|
"mode": config.Mode,
|
||||||
|
"budget": config.Budget,
|
||||||
|
"level": config.Level,
|
||||||
|
}).Debug("thinking: config from model suffix |")
|
||||||
|
} else {
|
||||||
|
config = extractThinkingConfig(body, providerFormat)
|
||||||
|
if hasThinkingConfig(config) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": providerFormat,
|
||||||
|
"model": modelInfo.ID,
|
||||||
|
"mode": config.Mode,
|
||||||
|
"budget": config.Budget,
|
||||||
|
"level": config.Level,
|
||||||
|
}).Debug("thinking: original config from request |")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasThinkingConfig(config) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": providerFormat,
|
||||||
|
"model": modelInfo.ID,
|
||||||
|
}).Debug("thinking: no config found, passthrough |")
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Validate and normalize configuration
|
||||||
|
validated, err := ValidateConfig(config, modelInfo, fromFormat, providerFormat, suffixResult.HasSuffix)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": providerFormat,
|
||||||
|
"model": modelInfo.ID,
|
||||||
|
"error": err.Error(),
|
||||||
|
}).Warn("thinking: validation failed |")
|
||||||
|
// Return original body on validation failure (defensive programming).
|
||||||
|
// This ensures callers who ignore the error won't receive nil body.
|
||||||
|
// The upstream service will decide how to handle the unmodified request.
|
||||||
|
return body, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Defensive check: ValidateConfig should never return (nil, nil)
|
||||||
|
if validated == nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": providerFormat,
|
||||||
|
"model": modelInfo.ID,
|
||||||
|
}).Warn("thinking: ValidateConfig returned nil config without error, passthrough |")
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": providerFormat,
|
||||||
|
"model": modelInfo.ID,
|
||||||
|
"mode": validated.Mode,
|
||||||
|
"budget": validated.Budget,
|
||||||
|
"level": validated.Level,
|
||||||
|
}).Debug("thinking: processed config to apply |")
|
||||||
|
|
||||||
|
// 6. Apply configuration using provider-specific applier
|
||||||
|
return applier.Apply(body, *validated, modelInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSuffixToConfig converts a raw suffix string to ThinkingConfig.
|
||||||
|
//
|
||||||
|
// Parsing priority:
|
||||||
|
// 1. Special values: "none" → ModeNone, "auto"/"-1" → ModeAuto
|
||||||
|
// 2. Level names: "minimal", "low", "medium", "high", "xhigh" → ModeLevel
|
||||||
|
// 3. Numeric values: positive integers → ModeBudget, 0 → ModeNone
|
||||||
|
//
|
||||||
|
// If none of the above match, returns empty ThinkingConfig (treated as no config).
|
||||||
|
func parseSuffixToConfig(rawSuffix, provider, model string) ThinkingConfig {
|
||||||
|
// 1. Try special values first (none, auto, -1)
|
||||||
|
if mode, ok := ParseSpecialSuffix(rawSuffix); ok {
|
||||||
|
switch mode {
|
||||||
|
case ModeNone:
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
case ModeAuto:
|
||||||
|
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Try level parsing (minimal, low, medium, high, xhigh)
|
||||||
|
if level, ok := ParseLevelSuffix(rawSuffix); ok {
|
||||||
|
return ThinkingConfig{Mode: ModeLevel, Level: level}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Try numeric parsing
|
||||||
|
if budget, ok := ParseNumericSuffix(rawSuffix); ok {
|
||||||
|
if budget == 0 {
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
}
|
||||||
|
return ThinkingConfig{Mode: ModeBudget, Budget: budget}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown suffix format - return empty config
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": provider,
|
||||||
|
"model": model,
|
||||||
|
"raw_suffix": rawSuffix,
|
||||||
|
}).Debug("thinking: unknown suffix format, treating as no config |")
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyUserDefinedModel applies thinking configuration for user-defined models
|
||||||
|
// without ThinkingSupport validation.
|
||||||
|
func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromFormat, toFormat string, suffixResult SuffixResult) ([]byte, error) {
|
||||||
|
// Get model ID for logging
|
||||||
|
modelID := ""
|
||||||
|
if modelInfo != nil {
|
||||||
|
modelID = modelInfo.ID
|
||||||
|
} else {
|
||||||
|
modelID = suffixResult.ModelName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get config: suffix priority over body
|
||||||
|
var config ThinkingConfig
|
||||||
|
if suffixResult.HasSuffix {
|
||||||
|
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
|
||||||
|
} else {
|
||||||
|
config = extractThinkingConfig(body, toFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasThinkingConfig(config) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"model": modelID,
|
||||||
|
"provider": toFormat,
|
||||||
|
}).Debug("thinking: user-defined model, passthrough (no config) |")
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
applier := GetProviderApplier(toFormat)
|
||||||
|
if applier == nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"model": modelID,
|
||||||
|
"provider": toFormat,
|
||||||
|
}).Debug("thinking: user-defined model, passthrough (unknown provider) |")
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": toFormat,
|
||||||
|
"model": modelID,
|
||||||
|
"mode": config.Mode,
|
||||||
|
"budget": config.Budget,
|
||||||
|
"level": config.Level,
|
||||||
|
}).Debug("thinking: applying config for user-defined model (skip validation)")
|
||||||
|
|
||||||
|
config = normalizeUserDefinedConfig(config, fromFormat, toFormat)
|
||||||
|
return applier.Apply(body, config, modelInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat string) ThinkingConfig {
|
||||||
|
if config.Mode != ModeLevel {
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
if !isBudgetBasedProvider(toFormat) || !isLevelBasedProvider(fromFormat) {
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
budget, ok := ConvertLevelToBudget(string(config.Level))
|
||||||
|
if !ok {
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
config.Mode = ModeBudget
|
||||||
|
config.Budget = budget
|
||||||
|
config.Level = ""
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractThinkingConfig extracts provider-specific thinking config from request body.
|
||||||
|
func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch provider {
|
||||||
|
case "claude":
|
||||||
|
return extractClaudeConfig(body)
|
||||||
|
case "gemini", "gemini-cli", "antigravity":
|
||||||
|
return extractGeminiConfig(body, provider)
|
||||||
|
case "openai":
|
||||||
|
return extractOpenAIConfig(body)
|
||||||
|
case "codex":
|
||||||
|
return extractCodexConfig(body)
|
||||||
|
case "iflow":
|
||||||
|
config := extractIFlowConfig(body)
|
||||||
|
if hasThinkingConfig(config) {
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
return extractOpenAIConfig(body)
|
||||||
|
default:
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasThinkingConfig(config ThinkingConfig) bool {
|
||||||
|
return config.Mode != ModeBudget || config.Budget != 0 || config.Level != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractClaudeConfig extracts thinking configuration from Claude format request body.
|
||||||
|
//
|
||||||
|
// Claude API format:
|
||||||
|
// - thinking.type: "enabled" or "disabled"
|
||||||
|
// - thinking.budget_tokens: integer (-1=auto, 0=disabled, >0=budget)
|
||||||
|
//
|
||||||
|
// Priority: thinking.type="disabled" takes precedence over budget_tokens.
|
||||||
|
// When type="enabled" without budget_tokens, returns ModeAuto to indicate
|
||||||
|
// the user wants thinking enabled but didn't specify a budget.
|
||||||
|
func extractClaudeConfig(body []byte) ThinkingConfig {
|
||||||
|
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
||||||
|
if thinkingType == "disabled" {
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check budget_tokens
|
||||||
|
if budget := gjson.GetBytes(body, "thinking.budget_tokens"); budget.Exists() {
|
||||||
|
value := int(budget.Int())
|
||||||
|
switch value {
|
||||||
|
case 0:
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
case -1:
|
||||||
|
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
|
||||||
|
default:
|
||||||
|
return ThinkingConfig{Mode: ModeBudget, Budget: value}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If type="enabled" but no budget_tokens, treat as auto (user wants thinking but no budget specified)
|
||||||
|
if thinkingType == "enabled" {
|
||||||
|
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractGeminiConfig extracts thinking configuration from Gemini format request body.
|
||||||
|
//
|
||||||
|
// Gemini API format:
|
||||||
|
// - generationConfig.thinkingConfig.thinkingLevel: "none", "auto", or level name (Gemini 3)
|
||||||
|
// - generationConfig.thinkingConfig.thinkingBudget: integer (Gemini 2.5)
|
||||||
|
//
|
||||||
|
// For gemini-cli and antigravity providers, the path is prefixed with "request.".
|
||||||
|
//
|
||||||
|
// Priority: thinkingLevel is checked first (Gemini 3 format), then thinkingBudget (Gemini 2.5 format).
|
||||||
|
// This allows newer Gemini 3 level-based configs to take precedence.
|
||||||
|
func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
|
||||||
|
prefix := "generationConfig.thinkingConfig"
|
||||||
|
if provider == "gemini-cli" || provider == "antigravity" {
|
||||||
|
prefix = "request.generationConfig.thinkingConfig"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check thinkingLevel first (Gemini 3 format takes precedence)
|
||||||
|
if level := gjson.GetBytes(body, prefix+".thinkingLevel"); level.Exists() {
|
||||||
|
value := level.String()
|
||||||
|
switch value {
|
||||||
|
case "none":
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
case "auto":
|
||||||
|
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
|
||||||
|
default:
|
||||||
|
return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check thinkingBudget (Gemini 2.5 format)
|
||||||
|
if budget := gjson.GetBytes(body, prefix+".thinkingBudget"); budget.Exists() {
|
||||||
|
value := int(budget.Int())
|
||||||
|
switch value {
|
||||||
|
case 0:
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
case -1:
|
||||||
|
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
|
||||||
|
default:
|
||||||
|
return ThinkingConfig{Mode: ModeBudget, Budget: value}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractOpenAIConfig extracts thinking configuration from OpenAI format request body.
|
||||||
|
//
|
||||||
|
// OpenAI API format:
|
||||||
|
// - reasoning_effort: "none", "low", "medium", "high" (discrete levels)
|
||||||
|
//
|
||||||
|
// OpenAI uses level-based thinking configuration only, no numeric budget support.
|
||||||
|
// The "none" value is treated specially to return ModeNone.
|
||||||
|
func extractOpenAIConfig(body []byte) ThinkingConfig {
|
||||||
|
// Check reasoning_effort (OpenAI Chat Completions format)
|
||||||
|
if effort := gjson.GetBytes(body, "reasoning_effort"); effort.Exists() {
|
||||||
|
value := effort.String()
|
||||||
|
if value == "none" {
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
}
|
||||||
|
return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCodexConfig extracts thinking configuration from Codex format request body.
|
||||||
|
//
|
||||||
|
// Codex API format (OpenAI Responses API):
|
||||||
|
// - reasoning.effort: "none", "low", "medium", "high"
|
||||||
|
//
|
||||||
|
// This is similar to OpenAI but uses nested field "reasoning.effort" instead of "reasoning_effort".
|
||||||
|
func extractCodexConfig(body []byte) ThinkingConfig {
|
||||||
|
// Check reasoning.effort (Codex / OpenAI Responses API format)
|
||||||
|
if effort := gjson.GetBytes(body, "reasoning.effort"); effort.Exists() {
|
||||||
|
value := effort.String()
|
||||||
|
if value == "none" {
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
}
|
||||||
|
return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractIFlowConfig extracts thinking configuration from iFlow format request body.
|
||||||
|
//
|
||||||
|
// iFlow API format (supports multiple model families):
|
||||||
|
// - GLM format: chat_template_kwargs.enable_thinking (boolean)
|
||||||
|
// - MiniMax format: reasoning_split (boolean)
|
||||||
|
//
|
||||||
|
// Returns ModeBudget with Budget=1 as a sentinel value indicating "enabled".
|
||||||
|
// The actual budget/configuration is determined by the iFlow applier based on model capabilities.
|
||||||
|
// Budget=1 is used because iFlow models don't use numeric budgets; they only support on/off.
|
||||||
|
func extractIFlowConfig(body []byte) ThinkingConfig {
|
||||||
|
// GLM format: chat_template_kwargs.enable_thinking
|
||||||
|
if enabled := gjson.GetBytes(body, "chat_template_kwargs.enable_thinking"); enabled.Exists() {
|
||||||
|
if enabled.Bool() {
|
||||||
|
// Budget=1 is a sentinel meaning "enabled" (iFlow doesn't use numeric budgets)
|
||||||
|
return ThinkingConfig{Mode: ModeBudget, Budget: 1}
|
||||||
|
}
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MiniMax format: reasoning_split
|
||||||
|
if split := gjson.GetBytes(body, "reasoning_split"); split.Exists() {
|
||||||
|
if split.Bool() {
|
||||||
|
// Budget=1 is a sentinel meaning "enabled" (iFlow doesn't use numeric budgets)
|
||||||
|
return ThinkingConfig{Mode: ModeBudget, Budget: 1}
|
||||||
|
}
|
||||||
|
return ThinkingConfig{Mode: ModeNone, Budget: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ThinkingConfig{}
|
||||||
|
}
|
||||||
142
internal/thinking/convert.go
Normal file
142
internal/thinking/convert.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package thinking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// levelToBudgetMap defines the standard Level → Budget mapping.
|
||||||
|
// All keys are lowercase; lookups should use strings.ToLower.
|
||||||
|
var levelToBudgetMap = map[string]int{
|
||||||
|
"none": 0,
|
||||||
|
"auto": -1,
|
||||||
|
"minimal": 512,
|
||||||
|
"low": 1024,
|
||||||
|
"medium": 8192,
|
||||||
|
"high": 24576,
|
||||||
|
"xhigh": 32768,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertLevelToBudget converts a thinking level to a budget value.
|
||||||
|
//
|
||||||
|
// This is a semantic conversion that maps discrete levels to numeric budgets.
|
||||||
|
// Level matching is case-insensitive.
|
||||||
|
//
|
||||||
|
// Level → Budget mapping:
|
||||||
|
// - none → 0
|
||||||
|
// - auto → -1
|
||||||
|
// - minimal → 512
|
||||||
|
// - low → 1024
|
||||||
|
// - medium → 8192
|
||||||
|
// - high → 24576
|
||||||
|
// - xhigh → 32768
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - budget: The converted budget value
|
||||||
|
// - ok: true if level is valid, false otherwise
|
||||||
|
func ConvertLevelToBudget(level string) (int, bool) {
|
||||||
|
budget, ok := levelToBudgetMap[strings.ToLower(level)]
|
||||||
|
return budget, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// BudgetThreshold constants define the upper bounds for each thinking level.
|
||||||
|
// These are used by ConvertBudgetToLevel for range-based mapping.
|
||||||
|
const (
|
||||||
|
// ThresholdMinimal is the upper bound for "minimal" level (1-512)
|
||||||
|
ThresholdMinimal = 512
|
||||||
|
// ThresholdLow is the upper bound for "low" level (513-1024)
|
||||||
|
ThresholdLow = 1024
|
||||||
|
// ThresholdMedium is the upper bound for "medium" level (1025-8192)
|
||||||
|
ThresholdMedium = 8192
|
||||||
|
// ThresholdHigh is the upper bound for "high" level (8193-24576)
|
||||||
|
ThresholdHigh = 24576
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertBudgetToLevel converts a budget value to the nearest thinking level.
|
||||||
|
//
|
||||||
|
// This is a semantic conversion that maps numeric budgets to discrete levels.
|
||||||
|
// Uses threshold-based mapping for range conversion.
|
||||||
|
//
|
||||||
|
// Budget → Level thresholds:
|
||||||
|
// - -1 → auto
|
||||||
|
// - 0 → none
|
||||||
|
// - 1-512 → minimal
|
||||||
|
// - 513-1024 → low
|
||||||
|
// - 1025-8192 → medium
|
||||||
|
// - 8193-24576 → high
|
||||||
|
// - 24577+ → xhigh
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - level: The converted thinking level string
|
||||||
|
// - ok: true if budget is valid, false for invalid negatives (< -1)
|
||||||
|
func ConvertBudgetToLevel(budget int) (string, bool) {
|
||||||
|
switch {
|
||||||
|
case budget < -1:
|
||||||
|
// Invalid negative values
|
||||||
|
return "", false
|
||||||
|
case budget == -1:
|
||||||
|
return string(LevelAuto), true
|
||||||
|
case budget == 0:
|
||||||
|
return string(LevelNone), true
|
||||||
|
case budget <= ThresholdMinimal:
|
||||||
|
return string(LevelMinimal), true
|
||||||
|
case budget <= ThresholdLow:
|
||||||
|
return string(LevelLow), true
|
||||||
|
case budget <= ThresholdMedium:
|
||||||
|
return string(LevelMedium), true
|
||||||
|
case budget <= ThresholdHigh:
|
||||||
|
return string(LevelHigh), true
|
||||||
|
default:
|
||||||
|
return string(LevelXHigh), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelCapability describes the thinking format support of a model.
|
||||||
|
type ModelCapability int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// CapabilityUnknown indicates modelInfo is nil (passthrough behavior, internal use).
|
||||||
|
CapabilityUnknown ModelCapability = iota - 1
|
||||||
|
// CapabilityNone indicates model doesn't support thinking (Thinking is nil).
|
||||||
|
CapabilityNone
|
||||||
|
// CapabilityBudgetOnly indicates the model supports numeric budgets only.
|
||||||
|
CapabilityBudgetOnly
|
||||||
|
// CapabilityLevelOnly indicates the model supports discrete levels only.
|
||||||
|
CapabilityLevelOnly
|
||||||
|
// CapabilityHybrid indicates the model supports both budgets and levels.
|
||||||
|
CapabilityHybrid
|
||||||
|
)
|
||||||
|
|
||||||
|
// detectModelCapability determines the thinking format capability of a model.
|
||||||
|
//
|
||||||
|
// This is an internal function used by validation and conversion helpers.
|
||||||
|
// It analyzes the model's ThinkingSupport configuration to classify the model:
|
||||||
|
// - CapabilityNone: modelInfo.Thinking is nil (model doesn't support thinking)
|
||||||
|
// - CapabilityBudgetOnly: Has Min/Max but no Levels (Claude, Gemini 2.5)
|
||||||
|
// - CapabilityLevelOnly: Has Levels but no Min/Max (OpenAI, iFlow)
|
||||||
|
// - CapabilityHybrid: Has both Min/Max and Levels (Gemini 3)
|
||||||
|
//
|
||||||
|
// Note: Returns a special sentinel value when modelInfo itself is nil (unknown model).
|
||||||
|
func detectModelCapability(modelInfo *registry.ModelInfo) ModelCapability {
|
||||||
|
if modelInfo == nil {
|
||||||
|
return CapabilityUnknown // sentinel for "passthrough" behavior
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return CapabilityNone
|
||||||
|
}
|
||||||
|
support := modelInfo.Thinking
|
||||||
|
hasBudget := support.Min > 0 || support.Max > 0
|
||||||
|
hasLevels := len(support.Levels) > 0
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case hasBudget && hasLevels:
|
||||||
|
return CapabilityHybrid
|
||||||
|
case hasBudget:
|
||||||
|
return CapabilityBudgetOnly
|
||||||
|
case hasLevels:
|
||||||
|
return CapabilityLevelOnly
|
||||||
|
default:
|
||||||
|
return CapabilityNone
|
||||||
|
}
|
||||||
|
}
|
||||||
82
internal/thinking/errors.go
Normal file
82
internal/thinking/errors.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
// Package thinking provides unified thinking configuration processing logic.
|
||||||
|
package thinking
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// ErrorCode represents the type of thinking configuration error.
|
||||||
|
type ErrorCode string
|
||||||
|
|
||||||
|
// Error codes for thinking configuration processing.
|
||||||
|
const (
|
||||||
|
// ErrInvalidSuffix indicates the suffix format cannot be parsed.
|
||||||
|
// Example: "model(abc" (missing closing parenthesis)
|
||||||
|
ErrInvalidSuffix ErrorCode = "INVALID_SUFFIX"
|
||||||
|
|
||||||
|
// ErrUnknownLevel indicates the level value is not in the valid list.
|
||||||
|
// Example: "model(ultra)" where "ultra" is not a valid level
|
||||||
|
ErrUnknownLevel ErrorCode = "UNKNOWN_LEVEL"
|
||||||
|
|
||||||
|
// ErrThinkingNotSupported indicates the model does not support thinking.
|
||||||
|
// Example: claude-haiku-4-5 does not have thinking capability
|
||||||
|
ErrThinkingNotSupported ErrorCode = "THINKING_NOT_SUPPORTED"
|
||||||
|
|
||||||
|
// ErrLevelNotSupported indicates the model does not support level mode.
|
||||||
|
// Example: using level with a budget-only model
|
||||||
|
ErrLevelNotSupported ErrorCode = "LEVEL_NOT_SUPPORTED"
|
||||||
|
|
||||||
|
// ErrBudgetOutOfRange indicates the budget value is outside model range.
|
||||||
|
// Example: budget 64000 exceeds max 20000
|
||||||
|
ErrBudgetOutOfRange ErrorCode = "BUDGET_OUT_OF_RANGE"
|
||||||
|
|
||||||
|
// ErrProviderMismatch indicates the provider does not match the model.
|
||||||
|
// Example: applying Claude format to a Gemini model
|
||||||
|
ErrProviderMismatch ErrorCode = "PROVIDER_MISMATCH"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ThinkingError represents an error that occurred during thinking configuration processing.
|
||||||
|
//
|
||||||
|
// This error type provides structured information about the error, including:
|
||||||
|
// - Code: A machine-readable error code for programmatic handling
|
||||||
|
// - Message: A human-readable description of the error
|
||||||
|
// - Model: The model name related to the error (optional)
|
||||||
|
// - Details: Additional context information (optional)
|
||||||
|
type ThinkingError struct {
|
||||||
|
// Code is the machine-readable error code
|
||||||
|
Code ErrorCode
|
||||||
|
// Message is the human-readable error description.
|
||||||
|
// Should be lowercase, no trailing period, with context if applicable.
|
||||||
|
Message string
|
||||||
|
// Model is the model name related to this error (optional)
|
||||||
|
Model string
|
||||||
|
// Details contains additional context information (optional)
|
||||||
|
Details map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface.
|
||||||
|
// Returns the message directly without code prefix.
|
||||||
|
// Use Code field for programmatic error handling.
|
||||||
|
func (e *ThinkingError) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewThinkingError creates a new ThinkingError with the given code and message.
|
||||||
|
func NewThinkingError(code ErrorCode, message string) *ThinkingError {
|
||||||
|
return &ThinkingError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewThinkingErrorWithModel creates a new ThinkingError with model context.
|
||||||
|
func NewThinkingErrorWithModel(code ErrorCode, message, model string) *ThinkingError {
|
||||||
|
return &ThinkingError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusCode implements a portable status code interface for HTTP handlers.
|
||||||
|
func (e *ThinkingError) StatusCode() int {
|
||||||
|
return http.StatusBadRequest
|
||||||
|
}
|
||||||
201
internal/thinking/provider/antigravity/apply.go
Normal file
201
internal/thinking/provider/antigravity/apply.go
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
// Package antigravity implements thinking configuration for Antigravity API format.
|
||||||
|
//
|
||||||
|
// Antigravity uses request.generationConfig.thinkingConfig.* path (same as gemini-cli)
|
||||||
|
// but requires additional normalization for Claude models:
|
||||||
|
// - Ensure thinking budget < max_tokens
|
||||||
|
// - Remove thinkingConfig if budget < minimum allowed
|
||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Applier applies thinking configuration for Antigravity API format.
|
||||||
|
type Applier struct{}
|
||||||
|
|
||||||
|
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||||
|
|
||||||
|
// NewApplier creates a new Antigravity thinking applier.
|
||||||
|
func NewApplier() *Applier {
|
||||||
|
return &Applier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
thinking.RegisterProvider("antigravity", NewApplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies thinking configuration to Antigravity request body.
|
||||||
|
//
|
||||||
|
// For Claude models, additional constraints are applied:
|
||||||
|
// - Ensure thinking budget < max_tokens
|
||||||
|
// - Remove thinkingConfig if budget < minimum allowed
|
||||||
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
|
return a.applyCompatible(body, config, modelInfo)
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
isClaude := strings.Contains(strings.ToLower(modelInfo.ID), "claude")
|
||||||
|
|
||||||
|
// ModeAuto: Always use Budget format with thinkingBudget=-1
|
||||||
|
if config.Mode == thinking.ModeAuto {
|
||||||
|
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
|
||||||
|
}
|
||||||
|
if config.Mode == thinking.ModeBudget {
|
||||||
|
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-auto modes, choose format based on model capabilities
|
||||||
|
support := modelInfo.Thinking
|
||||||
|
if len(support.Levels) > 0 {
|
||||||
|
return a.applyLevelFormat(body, config)
|
||||||
|
}
|
||||||
|
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
isClaude := false
|
||||||
|
if modelInfo != nil {
|
||||||
|
isClaude = strings.Contains(strings.ToLower(modelInfo.ID), "claude")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeAuto {
|
||||||
|
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeLevel || (config.Mode == thinking.ModeNone && config.Level != "") {
|
||||||
|
return a.applyLevelFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||||
|
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||||
|
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||||
|
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeNone {
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
|
||||||
|
if config.Level != "" {
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", string(config.Level))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only handle ModeLevel - budget conversion should be done by upper layer
|
||||||
|
if config.Mode != thinking.ModeLevel {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(config.Level)
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo, isClaude bool) ([]byte, error) {
|
||||||
|
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||||
|
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||||
|
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||||
|
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
|
||||||
|
budget := config.Budget
|
||||||
|
includeThoughts := false
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeNone:
|
||||||
|
includeThoughts = false
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
includeThoughts = true
|
||||||
|
default:
|
||||||
|
includeThoughts = budget > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply Claude-specific constraints
|
||||||
|
if isClaude && modelInfo != nil {
|
||||||
|
budget, result = a.normalizeClaudeBudget(budget, result, modelInfo)
|
||||||
|
// Check if budget was removed entirely
|
||||||
|
if budget == -2 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeClaudeBudget applies Claude-specific constraints to thinking budget.
|
||||||
|
//
|
||||||
|
// It handles:
|
||||||
|
// - Ensuring thinking budget < max_tokens
|
||||||
|
// - Removing thinkingConfig if budget < minimum allowed
|
||||||
|
//
|
||||||
|
// Returns the normalized budget and updated payload.
|
||||||
|
// Returns budget=-2 as a sentinel indicating thinkingConfig was removed entirely.
|
||||||
|
func (a *Applier) normalizeClaudeBudget(budget int, payload []byte, modelInfo *registry.ModelInfo) (int, []byte) {
|
||||||
|
if modelInfo == nil {
|
||||||
|
return budget, payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get effective max tokens
|
||||||
|
effectiveMax, setDefaultMax := a.effectiveMaxTokens(payload, modelInfo)
|
||||||
|
if effectiveMax > 0 && budget >= effectiveMax {
|
||||||
|
budget = effectiveMax - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check minimum budget
|
||||||
|
minBudget := 0
|
||||||
|
if modelInfo.Thinking != nil {
|
||||||
|
minBudget = modelInfo.Thinking.Min
|
||||||
|
}
|
||||||
|
if minBudget > 0 && budget >= 0 && budget < minBudget {
|
||||||
|
// Budget is below minimum, remove thinking config entirely
|
||||||
|
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.thinkingConfig")
|
||||||
|
return -2, payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default max tokens if needed
|
||||||
|
if setDefaultMax && effectiveMax > 0 {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax)
|
||||||
|
}
|
||||||
|
|
||||||
|
return budget, payload
|
||||||
|
}
|
||||||
|
|
||||||
|
// effectiveMaxTokens returns the max tokens to cap thinking:
|
||||||
|
// prefer request-provided maxOutputTokens; otherwise fall back to model default.
|
||||||
|
// The boolean indicates whether the value came from the model default (and thus should be written back).
|
||||||
|
func (a *Applier) effectiveMaxTokens(payload []byte, modelInfo *registry.ModelInfo) (max int, fromModel bool) {
|
||||||
|
if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 {
|
||||||
|
return int(maxTok.Int()), false
|
||||||
|
}
|
||||||
|
if modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
|
||||||
|
return modelInfo.MaxCompletionTokens, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
166
internal/thinking/provider/claude/apply.go
Normal file
166
internal/thinking/provider/claude/apply.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
// Package claude implements thinking configuration scaffolding for Claude models.
|
||||||
|
//
|
||||||
|
// Claude models use the thinking.budget_tokens format with values in the range
|
||||||
|
// 1024-128000. Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5),
|
||||||
|
// while older models do not.
|
||||||
|
// See: _bmad-output/planning-artifacts/architecture.md#Epic-6
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Applier implements thinking.ProviderApplier for Claude models.
|
||||||
|
// This applier is stateless and holds no configuration.
|
||||||
|
type Applier struct{}
|
||||||
|
|
||||||
|
// NewApplier creates a new Claude thinking applier.
|
||||||
|
func NewApplier() *Applier {
|
||||||
|
return &Applier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
thinking.RegisterProvider("claude", NewApplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies thinking configuration to Claude request body.
|
||||||
|
//
|
||||||
|
// IMPORTANT: This method expects config to be pre-validated by thinking.ValidateConfig.
|
||||||
|
// ValidateConfig handles:
|
||||||
|
// - Mode conversion (Level→Budget, Auto→Budget)
|
||||||
|
// - Budget clamping to model range
|
||||||
|
// - ZeroAllowed constraint enforcement
|
||||||
|
//
|
||||||
|
// Apply only processes ModeBudget and ModeNone; other modes are passed through unchanged.
|
||||||
|
//
|
||||||
|
// Expected output format when enabled:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "thinking": {
|
||||||
|
// "type": "enabled",
|
||||||
|
// "budget_tokens": 16384
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Expected output format when disabled:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "thinking": {
|
||||||
|
// "type": "disabled"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
|
return applyCompatibleClaude(body, config)
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process ModeBudget and ModeNone; other modes pass through
|
||||||
|
// (caller should use ValidateConfig first to normalize modes)
|
||||||
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced)
|
||||||
|
// Decide enabled/disabled based on budget value
|
||||||
|
if config.Budget == 0 {
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
|
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||||
|
|
||||||
|
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
|
||||||
|
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens.
|
||||||
|
// Anthropic API requires this constraint; violating it returns a 400 error.
|
||||||
|
func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo *registry.ModelInfo) []byte {
|
||||||
|
if budgetTokens <= 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the request satisfies Claude constraints:
|
||||||
|
// 1) Determine effective max_tokens (request overrides model default)
|
||||||
|
// 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1
|
||||||
|
// 3) If the adjusted budget falls below the model minimum, leave the request unchanged
|
||||||
|
// 4) If max_tokens came from model default, write it back into the request
|
||||||
|
|
||||||
|
effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo)
|
||||||
|
if setDefaultMax && effectiveMax > 0 {
|
||||||
|
body, _ = sjson.SetBytes(body, "max_tokens", effectiveMax)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the budget we would apply after enforcing budget_tokens < max_tokens.
|
||||||
|
adjustedBudget := budgetTokens
|
||||||
|
if effectiveMax > 0 && adjustedBudget >= effectiveMax {
|
||||||
|
adjustedBudget = effectiveMax - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
minBudget := 0
|
||||||
|
if modelInfo != nil && modelInfo.Thinking != nil {
|
||||||
|
minBudget = modelInfo.Thinking.Min
|
||||||
|
}
|
||||||
|
if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget {
|
||||||
|
// If enforcing the max_tokens constraint would push the budget below the model minimum,
|
||||||
|
// leave the request unchanged.
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if adjustedBudget != budgetTokens {
|
||||||
|
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", adjustedBudget)
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// effectiveMaxTokens returns the max tokens to cap thinking:
|
||||||
|
// prefer request-provided max_tokens; otherwise fall back to model default.
|
||||||
|
// The boolean indicates whether the value came from the model default (and thus should be written back).
|
||||||
|
func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo) (max int, fromModel bool) {
|
||||||
|
if maxTok := gjson.GetBytes(body, "max_tokens"); maxTok.Exists() && maxTok.Int() > 0 {
|
||||||
|
return int(maxTok.Int()), false
|
||||||
|
}
|
||||||
|
if modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
|
||||||
|
return modelInfo.MaxCompletionTokens, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeNone:
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
return result, nil
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
|
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||||
|
return result, nil
|
||||||
|
default:
|
||||||
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
|
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
131
internal/thinking/provider/codex/apply.go
Normal file
131
internal/thinking/provider/codex/apply.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
// Package codex implements thinking configuration for Codex (OpenAI Responses API) models.
|
||||||
|
//
|
||||||
|
// Codex models use the reasoning.effort format with discrete levels
|
||||||
|
// (low/medium/high). This is similar to OpenAI but uses nested field
|
||||||
|
// "reasoning.effort" instead of "reasoning_effort".
|
||||||
|
// See: _bmad-output/planning-artifacts/architecture.md#Epic-8
|
||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Applier implements thinking.ProviderApplier for Codex models.
|
||||||
|
//
|
||||||
|
// Codex-specific behavior:
|
||||||
|
// - Output format: reasoning.effort (string: low/medium/high/xhigh)
|
||||||
|
// - Level-only mode: no numeric budget support
|
||||||
|
// - Some models support ZeroAllowed (gpt-5.1, gpt-5.2)
|
||||||
|
type Applier struct{}
|
||||||
|
|
||||||
|
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||||
|
|
||||||
|
// NewApplier creates a new Codex thinking applier.
|
||||||
|
func NewApplier() *Applier {
|
||||||
|
return &Applier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
thinking.RegisterProvider("codex", NewApplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies thinking configuration to Codex request body.
|
||||||
|
//
|
||||||
|
// Expected output format:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "reasoning": {
|
||||||
|
// "effort": "high"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
|
return applyCompatibleCodex(body, config)
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only handle ModeLevel and ModeNone; other modes pass through unchanged.
|
||||||
|
if config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeLevel {
|
||||||
|
result, _ := sjson.SetBytes(body, "reasoning.effort", string(config.Level))
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
effort := ""
|
||||||
|
support := modelInfo.Thinking
|
||||||
|
if config.Budget == 0 {
|
||||||
|
if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) {
|
||||||
|
effort = string(thinking.LevelNone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if effort == "" && config.Level != "" {
|
||||||
|
effort = string(config.Level)
|
||||||
|
}
|
||||||
|
if effort == "" && len(support.Levels) > 0 {
|
||||||
|
effort = support.Levels[0]
|
||||||
|
}
|
||||||
|
if effort == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := sjson.SetBytes(body, "reasoning.effort", effort)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyCompatibleCodex(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
var effort string
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeLevel:
|
||||||
|
if config.Level == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
effort = string(config.Level)
|
||||||
|
case thinking.ModeNone:
|
||||||
|
effort = string(thinking.LevelNone)
|
||||||
|
if config.Level != "" {
|
||||||
|
effort = string(config.Level)
|
||||||
|
}
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
// Auto mode for user-defined models: pass through as "auto"
|
||||||
|
effort = string(thinking.LevelAuto)
|
||||||
|
case thinking.ModeBudget:
|
||||||
|
// Budget mode: convert budget to level using threshold mapping
|
||||||
|
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||||
|
if !ok {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
effort = level
|
||||||
|
default:
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := sjson.SetBytes(body, "reasoning.effort", effort)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasLevel(levels []string, target string) bool {
|
||||||
|
for _, level := range levels {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(level), target) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
169
internal/thinking/provider/gemini/apply.go
Normal file
169
internal/thinking/provider/gemini/apply.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
// Package gemini implements thinking configuration for Gemini models.
|
||||||
|
//
|
||||||
|
// Gemini models have two formats:
|
||||||
|
// - Gemini 2.5: Uses thinkingBudget (numeric)
|
||||||
|
// - Gemini 3.x: Uses thinkingLevel (string: minimal/low/medium/high)
|
||||||
|
// or thinkingBudget=-1 for auto/dynamic mode
|
||||||
|
//
|
||||||
|
// Output format is determined by ThinkingConfig.Mode and ThinkingSupport.Levels:
|
||||||
|
// - ModeAuto: Always uses thinkingBudget=-1 (both Gemini 2.5 and 3.x)
|
||||||
|
// - len(Levels) > 0: Uses thinkingLevel (Gemini 3.x discrete levels)
|
||||||
|
// - len(Levels) == 0: Uses thinkingBudget (Gemini 2.5)
|
||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Applier applies thinking configuration for Gemini models.
|
||||||
|
//
|
||||||
|
// Gemini-specific behavior:
|
||||||
|
// - Gemini 2.5: thinkingBudget format, flash series supports ZeroAllowed
|
||||||
|
// - Gemini 3.x: thinkingLevel format, cannot be disabled
|
||||||
|
// - Use ThinkingSupport.Levels to decide output format
|
||||||
|
type Applier struct{}
|
||||||
|
|
||||||
|
// NewApplier creates a new Gemini thinking applier.
|
||||||
|
func NewApplier() *Applier {
|
||||||
|
return &Applier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
thinking.RegisterProvider("gemini", NewApplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies thinking configuration to Gemini request body.
|
||||||
|
//
|
||||||
|
// Expected output format (Gemini 2.5):
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "generationConfig": {
|
||||||
|
// "thinkingConfig": {
|
||||||
|
// "thinkingBudget": 8192,
|
||||||
|
// "includeThoughts": true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Expected output format (Gemini 3.x):
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "generationConfig": {
|
||||||
|
// "thinkingConfig": {
|
||||||
|
// "thinkingLevel": "high",
|
||||||
|
// "includeThoughts": true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
|
return a.applyCompatible(body, config)
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choose format based on config.Mode and model capabilities:
|
||||||
|
// - ModeLevel: use Level format (validation will reject unsupported levels)
|
||||||
|
// - ModeNone: use Level format if model has Levels, else Budget format
|
||||||
|
// - ModeBudget/ModeAuto: use Budget format
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeLevel:
|
||||||
|
return a.applyLevelFormat(body, config)
|
||||||
|
case thinking.ModeNone:
|
||||||
|
// ModeNone: route based on model capability (has Levels or not)
|
||||||
|
if len(modelInfo.Thinking.Levels) > 0 {
|
||||||
|
return a.applyLevelFormat(body, config)
|
||||||
|
}
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
default:
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeAuto {
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeLevel || (config.Mode == thinking.ModeNone && config.Level != "") {
|
||||||
|
return a.applyLevelFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
// ModeNone semantics:
|
||||||
|
// - ModeNone + Budget=0: completely disable thinking (not possible for Level-only models)
|
||||||
|
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
|
||||||
|
// ValidateConfig sets config.Level to the lowest level when ModeNone + Budget > 0.
|
||||||
|
|
||||||
|
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||||
|
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingBudget")
|
||||||
|
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||||
|
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeNone {
|
||||||
|
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", false)
|
||||||
|
if config.Level != "" {
|
||||||
|
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingLevel", string(config.Level))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only handle ModeLevel - budget conversion should be done by upper layer
|
||||||
|
if config.Mode != thinking.ModeLevel {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(config.Level)
|
||||||
|
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingLevel", level)
|
||||||
|
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||||
|
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingLevel")
|
||||||
|
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||||
|
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
|
||||||
|
budget := config.Budget
|
||||||
|
// ModeNone semantics:
|
||||||
|
// - ModeNone + Budget=0: completely disable thinking
|
||||||
|
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
|
||||||
|
// When ZeroAllowed=false, ValidateConfig clamps Budget to Min while preserving ModeNone.
|
||||||
|
includeThoughts := false
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeNone:
|
||||||
|
includeThoughts = false
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
includeThoughts = true
|
||||||
|
default:
|
||||||
|
includeThoughts = budget > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
|
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
126
internal/thinking/provider/geminicli/apply.go
Normal file
126
internal/thinking/provider/geminicli/apply.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
// Package geminicli implements thinking configuration for Gemini CLI API format.
|
||||||
|
//
|
||||||
|
// Gemini CLI uses request.generationConfig.thinkingConfig.* path instead of
|
||||||
|
// generationConfig.thinkingConfig.* used by standard Gemini API.
|
||||||
|
package geminicli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Applier applies thinking configuration for Gemini CLI API format.
|
||||||
|
type Applier struct{}
|
||||||
|
|
||||||
|
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||||
|
|
||||||
|
// NewApplier creates a new Gemini CLI thinking applier.
|
||||||
|
func NewApplier() *Applier {
|
||||||
|
return &Applier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
thinking.RegisterProvider("gemini-cli", NewApplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies thinking configuration to Gemini CLI request body.
|
||||||
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
|
return a.applyCompatible(body, config)
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModeAuto: Always use Budget format with thinkingBudget=-1
|
||||||
|
if config.Mode == thinking.ModeAuto {
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
}
|
||||||
|
if config.Mode == thinking.ModeBudget {
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-auto modes, choose format based on model capabilities
|
||||||
|
support := modelInfo.Thinking
|
||||||
|
if len(support.Levels) > 0 {
|
||||||
|
return a.applyLevelFormat(body, config)
|
||||||
|
}
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeAuto {
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeLevel || (config.Mode == thinking.ModeNone && config.Level != "") {
|
||||||
|
return a.applyLevelFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.applyBudgetFormat(body, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||||
|
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||||
|
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||||
|
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeNone {
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
|
||||||
|
if config.Level != "" {
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", string(config.Level))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only handle ModeLevel - budget conversion should be done by upper layer
|
||||||
|
if config.Mode != thinking.ModeLevel {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(config.Level)
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||||
|
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||||
|
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||||
|
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
|
||||||
|
budget := config.Budget
|
||||||
|
includeThoughts := false
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeNone:
|
||||||
|
includeThoughts = false
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
includeThoughts = true
|
||||||
|
default:
|
||||||
|
includeThoughts = budget > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
|
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
173
internal/thinking/provider/iflow/apply.go
Normal file
173
internal/thinking/provider/iflow/apply.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
// Package iflow implements thinking configuration for iFlow models.
|
||||||
|
//
|
||||||
|
// iFlow models use boolean toggle semantics:
|
||||||
|
// - Models using chat_template_kwargs.enable_thinking (boolean toggle)
|
||||||
|
// - MiniMax models: reasoning_split (boolean)
|
||||||
|
//
|
||||||
|
// Level values are converted to boolean: none=false, all others=true
|
||||||
|
// See: _bmad-output/planning-artifacts/architecture.md#Epic-9
|
||||||
|
package iflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Applier implements thinking.ProviderApplier for iFlow models.
|
||||||
|
//
|
||||||
|
// iFlow-specific behavior:
|
||||||
|
// - enable_thinking toggle models: enable_thinking boolean
|
||||||
|
// - GLM models: enable_thinking boolean + clear_thinking=false
|
||||||
|
// - MiniMax models: reasoning_split boolean
|
||||||
|
// - Level to boolean: none=false, others=true
|
||||||
|
// - No quantized support (only on/off)
|
||||||
|
type Applier struct{}
|
||||||
|
|
||||||
|
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||||
|
|
||||||
|
// NewApplier creates a new iFlow thinking applier.
|
||||||
|
func NewApplier() *Applier {
|
||||||
|
return &Applier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
thinking.RegisterProvider("iflow", NewApplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies thinking configuration to iFlow request body.
|
||||||
|
//
|
||||||
|
// Expected output format (GLM):
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "chat_template_kwargs": {
|
||||||
|
// "enable_thinking": true,
|
||||||
|
// "clear_thinking": false
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Expected output format (MiniMax):
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "reasoning_split": true
|
||||||
|
// }
|
||||||
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isEnableThinkingModel(modelInfo.ID) {
|
||||||
|
return applyEnableThinking(body, config, isGLMModel(modelInfo.ID)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isMiniMaxModel(modelInfo.ID) {
|
||||||
|
return applyMiniMax(body, config), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configToBoolean converts ThinkingConfig to boolean for iFlow models.
|
||||||
|
//
|
||||||
|
// Conversion rules:
|
||||||
|
// - ModeNone: false
|
||||||
|
// - ModeAuto: true
|
||||||
|
// - ModeBudget + Budget=0: false
|
||||||
|
// - ModeBudget + Budget>0: true
|
||||||
|
// - ModeLevel + Level="none": false
|
||||||
|
// - ModeLevel + any other level: true
|
||||||
|
// - Default (unknown mode): true
|
||||||
|
func configToBoolean(config thinking.ThinkingConfig) bool {
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeNone:
|
||||||
|
return false
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
return true
|
||||||
|
case thinking.ModeBudget:
|
||||||
|
return config.Budget > 0
|
||||||
|
case thinking.ModeLevel:
|
||||||
|
return config.Level != thinking.LevelNone
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyEnableThinking applies thinking configuration for models that use
|
||||||
|
// chat_template_kwargs.enable_thinking format.
|
||||||
|
//
|
||||||
|
// Output format when enabled:
|
||||||
|
//
|
||||||
|
// {"chat_template_kwargs": {"enable_thinking": true, "clear_thinking": false}}
|
||||||
|
//
|
||||||
|
// Output format when disabled:
|
||||||
|
//
|
||||||
|
// {"chat_template_kwargs": {"enable_thinking": false}}
|
||||||
|
//
|
||||||
|
// Note: clear_thinking is only set for GLM models when thinking is enabled.
|
||||||
|
func applyEnableThinking(body []byte, config thinking.ThinkingConfig, setClearThinking bool) []byte {
|
||||||
|
enableThinking := configToBoolean(config)
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||||
|
|
||||||
|
// clear_thinking is a GLM-only knob, strip it for other models.
|
||||||
|
result, _ = sjson.DeleteBytes(result, "chat_template_kwargs.clear_thinking")
|
||||||
|
|
||||||
|
// clear_thinking only needed when thinking is enabled
|
||||||
|
if enableThinking && setClearThinking {
|
||||||
|
result, _ = sjson.SetBytes(result, "chat_template_kwargs.clear_thinking", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyMiniMax applies thinking configuration for MiniMax models.
|
||||||
|
//
|
||||||
|
// Output format:
|
||||||
|
//
|
||||||
|
// {"reasoning_split": true/false}
|
||||||
|
func applyMiniMax(body []byte, config thinking.ThinkingConfig) []byte {
|
||||||
|
reasoningSplit := configToBoolean(config)
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := sjson.SetBytes(body, "reasoning_split", reasoningSplit)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// isEnableThinkingModel determines if the model uses chat_template_kwargs.enable_thinking format.
|
||||||
|
func isEnableThinkingModel(modelID string) bool {
|
||||||
|
if isGLMModel(modelID) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
id := strings.ToLower(modelID)
|
||||||
|
switch id {
|
||||||
|
case "qwen3-max-preview", "deepseek-v3.2", "deepseek-v3.1":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isGLMModel determines if the model is a GLM series model.
|
||||||
|
func isGLMModel(modelID string) bool {
|
||||||
|
return strings.HasPrefix(strings.ToLower(modelID), "glm")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isMiniMaxModel determines if the model is a MiniMax series model.
|
||||||
|
// MiniMax models use reasoning_split format.
|
||||||
|
func isMiniMaxModel(modelID string) bool {
|
||||||
|
return strings.HasPrefix(strings.ToLower(modelID), "minimax")
|
||||||
|
}
|
||||||
128
internal/thinking/provider/openai/apply.go
Normal file
128
internal/thinking/provider/openai/apply.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
// Package openai implements thinking configuration for OpenAI/Codex models.
|
||||||
|
//
|
||||||
|
// OpenAI models use the reasoning_effort format with discrete levels
|
||||||
|
// (low/medium/high). Some models support xhigh and none levels.
|
||||||
|
// See: _bmad-output/planning-artifacts/architecture.md#Epic-8
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||||
|
//
|
||||||
|
// OpenAI-specific behavior:
|
||||||
|
// - Output format: reasoning_effort (string: low/medium/high/xhigh)
|
||||||
|
// - Level-only mode: no numeric budget support
|
||||||
|
// - Some models support ZeroAllowed (gpt-5.1, gpt-5.2)
|
||||||
|
type Applier struct{}
|
||||||
|
|
||||||
|
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||||
|
|
||||||
|
// NewApplier creates a new OpenAI thinking applier.
|
||||||
|
func NewApplier() *Applier {
|
||||||
|
return &Applier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
thinking.RegisterProvider("openai", NewApplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies thinking configuration to OpenAI request body.
|
||||||
|
//
|
||||||
|
// Expected output format:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "reasoning_effort": "high"
|
||||||
|
// }
|
||||||
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
|
return applyCompatibleOpenAI(body, config)
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only handle ModeLevel and ModeNone; other modes pass through unchanged.
|
||||||
|
if config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == thinking.ModeLevel {
|
||||||
|
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
effort := ""
|
||||||
|
support := modelInfo.Thinking
|
||||||
|
if config.Budget == 0 {
|
||||||
|
if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) {
|
||||||
|
effort = string(thinking.LevelNone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if effort == "" && config.Level != "" {
|
||||||
|
effort = string(config.Level)
|
||||||
|
}
|
||||||
|
if effort == "" && len(support.Levels) > 0 {
|
||||||
|
effort = support.Levels[0]
|
||||||
|
}
|
||||||
|
if effort == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
var effort string
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeLevel:
|
||||||
|
if config.Level == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
effort = string(config.Level)
|
||||||
|
case thinking.ModeNone:
|
||||||
|
effort = string(thinking.LevelNone)
|
||||||
|
if config.Level != "" {
|
||||||
|
effort = string(config.Level)
|
||||||
|
}
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
// Auto mode for user-defined models: pass through as "auto"
|
||||||
|
effort = string(thinking.LevelAuto)
|
||||||
|
case thinking.ModeBudget:
|
||||||
|
// Budget mode: convert budget to level using threshold mapping
|
||||||
|
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||||
|
if !ok {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
effort = level
|
||||||
|
default:
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasLevel(levels []string, target string) bool {
|
||||||
|
for _, level := range levels {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(level), target) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
58
internal/thinking/strip.go
Normal file
58
internal/thinking/strip.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
// Package thinking provides unified thinking configuration processing.
|
||||||
|
package thinking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StripThinkingConfig removes thinking configuration fields from request body.
|
||||||
|
//
|
||||||
|
// This function is used when a model doesn't support thinking but the request
|
||||||
|
// contains thinking configuration. The configuration is silently removed to
|
||||||
|
// prevent upstream API errors.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - body: Original request body JSON
|
||||||
|
// - provider: Provider name (determines which fields to strip)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - Modified request body JSON with thinking configuration removed
|
||||||
|
// - Original body is returned unchanged if:
|
||||||
|
// - body is empty or invalid JSON
|
||||||
|
// - provider is unknown
|
||||||
|
// - no thinking configuration found
|
||||||
|
func StripThinkingConfig(body []byte, provider string) []byte {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
var paths []string
|
||||||
|
switch provider {
|
||||||
|
case "claude":
|
||||||
|
paths = []string{"thinking"}
|
||||||
|
case "gemini":
|
||||||
|
paths = []string{"generationConfig.thinkingConfig"}
|
||||||
|
case "gemini-cli", "antigravity":
|
||||||
|
paths = []string{"request.generationConfig.thinkingConfig"}
|
||||||
|
case "openai":
|
||||||
|
paths = []string{"reasoning_effort"}
|
||||||
|
case "codex":
|
||||||
|
paths = []string{"reasoning.effort"}
|
||||||
|
case "iflow":
|
||||||
|
paths = []string{
|
||||||
|
"chat_template_kwargs.enable_thinking",
|
||||||
|
"chat_template_kwargs.clear_thinking",
|
||||||
|
"reasoning_split",
|
||||||
|
"reasoning_effort",
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
result := body
|
||||||
|
for _, path := range paths {
|
||||||
|
result, _ = sjson.DeleteBytes(result, path)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
146
internal/thinking/suffix.go
Normal file
146
internal/thinking/suffix.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
// Package thinking provides unified thinking configuration processing.
|
||||||
|
//
|
||||||
|
// This file implements suffix parsing functionality for extracting
|
||||||
|
// thinking configuration from model names in the format model(value).
|
||||||
|
package thinking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseSuffix extracts thinking suffix from a model name.
|
||||||
|
//
|
||||||
|
// The suffix format is: model-name(value)
|
||||||
|
// Examples:
|
||||||
|
// - "claude-sonnet-4-5(16384)" -> ModelName="claude-sonnet-4-5", RawSuffix="16384"
|
||||||
|
// - "gpt-5.2(high)" -> ModelName="gpt-5.2", RawSuffix="high"
|
||||||
|
// - "gemini-2.5-pro" -> ModelName="gemini-2.5-pro", HasSuffix=false
|
||||||
|
//
|
||||||
|
// This function only extracts the suffix; it does not validate or interpret
|
||||||
|
// the suffix content. Use ParseNumericSuffix, ParseLevelSuffix, etc. for
|
||||||
|
// content interpretation.
|
||||||
|
func ParseSuffix(model string) SuffixResult {
|
||||||
|
// Find the last opening parenthesis
|
||||||
|
lastOpen := strings.LastIndex(model, "(")
|
||||||
|
if lastOpen == -1 {
|
||||||
|
return SuffixResult{ModelName: model, HasSuffix: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the string ends with a closing parenthesis
|
||||||
|
if !strings.HasSuffix(model, ")") {
|
||||||
|
return SuffixResult{ModelName: model, HasSuffix: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract components
|
||||||
|
modelName := model[:lastOpen]
|
||||||
|
rawSuffix := model[lastOpen+1 : len(model)-1]
|
||||||
|
|
||||||
|
return SuffixResult{
|
||||||
|
ModelName: modelName,
|
||||||
|
HasSuffix: true,
|
||||||
|
RawSuffix: rawSuffix,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseNumericSuffix attempts to parse a raw suffix as a numeric budget value.
|
||||||
|
//
|
||||||
|
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as an integer.
|
||||||
|
// Only non-negative integers are considered valid numeric suffixes.
|
||||||
|
//
|
||||||
|
// Platform note: The budget value uses Go's int type, which is 32-bit on 32-bit
|
||||||
|
// systems and 64-bit on 64-bit systems. Values exceeding the platform's int range
|
||||||
|
// will return ok=false.
|
||||||
|
//
|
||||||
|
// Leading zeros are accepted: "08192" parses as 8192.
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// - "8192" -> budget=8192, ok=true
|
||||||
|
// - "0" -> budget=0, ok=true (represents ModeNone)
|
||||||
|
// - "08192" -> budget=8192, ok=true (leading zeros accepted)
|
||||||
|
// - "-1" -> budget=0, ok=false (negative numbers are not valid numeric suffixes)
|
||||||
|
// - "high" -> budget=0, ok=false (not a number)
|
||||||
|
// - "9223372036854775808" -> budget=0, ok=false (overflow on 64-bit systems)
|
||||||
|
//
|
||||||
|
// For special handling of -1 as auto mode, use ParseSpecialSuffix instead.
|
||||||
|
func ParseNumericSuffix(rawSuffix string) (budget int, ok bool) {
|
||||||
|
if rawSuffix == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := strconv.Atoi(rawSuffix)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Negative numbers are not valid numeric suffixes
|
||||||
|
// -1 should be handled by special value parsing as "auto"
|
||||||
|
if value < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseSpecialSuffix attempts to parse a raw suffix as a special thinking mode value.
|
||||||
|
//
|
||||||
|
// This function handles special strings that represent a change in thinking mode:
|
||||||
|
// - "none" -> ModeNone (disables thinking)
|
||||||
|
// - "auto" -> ModeAuto (automatic/dynamic thinking)
|
||||||
|
// - "-1" -> ModeAuto (numeric representation of auto mode)
|
||||||
|
//
|
||||||
|
// String values are case-insensitive.
|
||||||
|
func ParseSpecialSuffix(rawSuffix string) (mode ThinkingMode, ok bool) {
|
||||||
|
if rawSuffix == "" {
|
||||||
|
return ModeBudget, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case-insensitive matching
|
||||||
|
switch strings.ToLower(rawSuffix) {
|
||||||
|
case "none":
|
||||||
|
return ModeNone, true
|
||||||
|
case "auto", "-1":
|
||||||
|
return ModeAuto, true
|
||||||
|
default:
|
||||||
|
return ModeBudget, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseLevelSuffix attempts to parse a raw suffix as a discrete thinking level.
|
||||||
|
//
|
||||||
|
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as a level.
|
||||||
|
// Only discrete effort levels are valid: minimal, low, medium, high, xhigh.
|
||||||
|
// Level matching is case-insensitive.
|
||||||
|
//
|
||||||
|
// Special values (none, auto) are NOT handled by this function; use ParseSpecialSuffix
|
||||||
|
// instead. This separation allows callers to prioritize special value handling.
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// - "high" -> level=LevelHigh, ok=true
|
||||||
|
// - "HIGH" -> level=LevelHigh, ok=true (case insensitive)
|
||||||
|
// - "medium" -> level=LevelMedium, ok=true
|
||||||
|
// - "none" -> level="", ok=false (special value, use ParseSpecialSuffix)
|
||||||
|
// - "auto" -> level="", ok=false (special value, use ParseSpecialSuffix)
|
||||||
|
// - "8192" -> level="", ok=false (numeric, use ParseNumericSuffix)
|
||||||
|
// - "ultra" -> level="", ok=false (unknown level)
|
||||||
|
func ParseLevelSuffix(rawSuffix string) (level ThinkingLevel, ok bool) {
|
||||||
|
if rawSuffix == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case-insensitive matching
|
||||||
|
switch strings.ToLower(rawSuffix) {
|
||||||
|
case "minimal":
|
||||||
|
return LevelMinimal, true
|
||||||
|
case "low":
|
||||||
|
return LevelLow, true
|
||||||
|
case "medium":
|
||||||
|
return LevelMedium, true
|
||||||
|
case "high":
|
||||||
|
return LevelHigh, true
|
||||||
|
case "xhigh":
|
||||||
|
return LevelXHigh, true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
41
internal/thinking/text.go
Normal file
41
internal/thinking/text.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package thinking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetThinkingText extracts the thinking text from a content part.
|
||||||
|
// Handles various formats:
|
||||||
|
// - Simple string: { "thinking": "text" } or { "text": "text" }
|
||||||
|
// - Wrapped object: { "thinking": { "text": "text", "cache_control": {...} } }
|
||||||
|
// - Gemini-style: { "thought": true, "text": "text" }
|
||||||
|
// Returns the extracted text string.
|
||||||
|
func GetThinkingText(part gjson.Result) string {
|
||||||
|
// Try direct text field first (Gemini-style)
|
||||||
|
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||||
|
return text.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try thinking field
|
||||||
|
thinkingField := part.Get("thinking")
|
||||||
|
if !thinkingField.Exists() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// thinking is a string
|
||||||
|
if thinkingField.Type == gjson.String {
|
||||||
|
return thinkingField.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// thinking is an object with inner text/thinking
|
||||||
|
if thinkingField.IsObject() {
|
||||||
|
if inner := thinkingField.Get("text"); inner.Exists() && inner.Type == gjson.String {
|
||||||
|
return inner.String()
|
||||||
|
}
|
||||||
|
if inner := thinkingField.Get("thinking"); inner.Exists() && inner.Type == gjson.String {
|
||||||
|
return inner.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
116
internal/thinking/types.go
Normal file
116
internal/thinking/types.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
// Package thinking provides unified thinking configuration processing.
|
||||||
|
//
|
||||||
|
// This package offers a unified interface for parsing, validating, and applying
|
||||||
|
// thinking configurations across various AI providers (Claude, Gemini, OpenAI, iFlow).
|
||||||
|
package thinking
|
||||||
|
|
||||||
|
import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
|
||||||
|
// ThinkingMode represents the type of thinking configuration mode.
|
||||||
|
type ThinkingMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ModeBudget indicates using a numeric budget (corresponds to suffix "(1000)" etc.)
|
||||||
|
ModeBudget ThinkingMode = iota
|
||||||
|
// ModeLevel indicates using a discrete level (corresponds to suffix "(high)" etc.)
|
||||||
|
ModeLevel
|
||||||
|
// ModeNone indicates thinking is disabled (corresponds to suffix "(none)" or budget=0)
|
||||||
|
ModeNone
|
||||||
|
// ModeAuto indicates automatic/dynamic thinking (corresponds to suffix "(auto)" or budget=-1)
|
||||||
|
ModeAuto
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns the string representation of ThinkingMode.
|
||||||
|
func (m ThinkingMode) String() string {
|
||||||
|
switch m {
|
||||||
|
case ModeBudget:
|
||||||
|
return "budget"
|
||||||
|
case ModeLevel:
|
||||||
|
return "level"
|
||||||
|
case ModeNone:
|
||||||
|
return "none"
|
||||||
|
case ModeAuto:
|
||||||
|
return "auto"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ThinkingLevel represents a discrete thinking level.
|
||||||
|
type ThinkingLevel string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LevelNone disables thinking
|
||||||
|
LevelNone ThinkingLevel = "none"
|
||||||
|
// LevelAuto enables automatic/dynamic thinking
|
||||||
|
LevelAuto ThinkingLevel = "auto"
|
||||||
|
// LevelMinimal sets minimal thinking effort
|
||||||
|
LevelMinimal ThinkingLevel = "minimal"
|
||||||
|
// LevelLow sets low thinking effort
|
||||||
|
LevelLow ThinkingLevel = "low"
|
||||||
|
// LevelMedium sets medium thinking effort
|
||||||
|
LevelMedium ThinkingLevel = "medium"
|
||||||
|
// LevelHigh sets high thinking effort
|
||||||
|
LevelHigh ThinkingLevel = "high"
|
||||||
|
// LevelXHigh sets extra-high thinking effort
|
||||||
|
LevelXHigh ThinkingLevel = "xhigh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ThinkingConfig represents a unified thinking configuration.
|
||||||
|
//
|
||||||
|
// This struct is used to pass thinking configuration information between components.
|
||||||
|
// Depending on Mode, either Budget or Level field is effective:
|
||||||
|
// - ModeNone: Budget=0, Level is ignored
|
||||||
|
// - ModeAuto: Budget=-1, Level is ignored
|
||||||
|
// - ModeBudget: Budget is a positive integer, Level is ignored
|
||||||
|
// - ModeLevel: Budget is ignored, Level is a valid level
|
||||||
|
type ThinkingConfig struct {
|
||||||
|
// Mode specifies the configuration mode
|
||||||
|
Mode ThinkingMode
|
||||||
|
// Budget is the thinking budget (token count), only effective when Mode is ModeBudget.
|
||||||
|
// Special values: 0 means disabled, -1 means automatic
|
||||||
|
Budget int
|
||||||
|
// Level is the thinking level, only effective when Mode is ModeLevel
|
||||||
|
Level ThinkingLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
// SuffixResult represents the result of parsing a model name for thinking suffix.
|
||||||
|
//
|
||||||
|
// A thinking suffix is specified in the format model-name(value), where value
|
||||||
|
// can be a numeric budget (e.g., "16384") or a level name (e.g., "high").
|
||||||
|
type SuffixResult struct {
|
||||||
|
// ModelName is the model name with the suffix removed.
|
||||||
|
// If no suffix was found, this equals the original input.
|
||||||
|
ModelName string
|
||||||
|
|
||||||
|
// HasSuffix indicates whether a valid suffix was found.
|
||||||
|
HasSuffix bool
|
||||||
|
|
||||||
|
// RawSuffix is the content inside the parentheses, without the parentheses.
|
||||||
|
// Empty string if HasSuffix is false.
|
||||||
|
RawSuffix string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderApplier defines the interface for provider-specific thinking configuration application.
|
||||||
|
//
|
||||||
|
// Types implementing this interface are responsible for converting a unified ThinkingConfig
|
||||||
|
// into provider-specific format and applying it to the request body.
|
||||||
|
//
|
||||||
|
// Implementation requirements:
|
||||||
|
// - Apply method must be idempotent
|
||||||
|
// - Must not modify the input config or modelInfo
|
||||||
|
// - Returns a modified copy of the request body
|
||||||
|
// - Returns appropriate ThinkingError for unsupported configurations
|
||||||
|
type ProviderApplier interface {
|
||||||
|
// Apply applies the thinking configuration to the request body.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - body: Original request body JSON
|
||||||
|
// - config: Unified thinking configuration
|
||||||
|
// - modelInfo: Model registry information containing ThinkingSupport properties
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - Modified request body JSON
|
||||||
|
// - ThinkingError if the configuration is invalid or unsupported
|
||||||
|
Apply(body []byte, config ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error)
|
||||||
|
}
|
||||||
378
internal/thinking/validate.go
Normal file
378
internal/thinking/validate.go
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
// Package thinking provides unified thinking configuration processing logic.
|
||||||
|
package thinking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateConfig validates a thinking configuration against model capabilities.
|
||||||
|
//
|
||||||
|
// This function performs comprehensive validation:
|
||||||
|
// - Checks if the model supports thinking
|
||||||
|
// - Auto-converts between Budget and Level formats based on model capability
|
||||||
|
// - Validates that requested level is in the model's supported levels list
|
||||||
|
// - Clamps budget values to model's allowed range
|
||||||
|
// - When converting Budget -> Level for level-only models, clamps the derived standard level to the nearest supported level
|
||||||
|
// (special values none/auto are preserved)
|
||||||
|
// - When config comes from a model suffix, strict budget validation is disabled (we clamp instead of error)
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - config: The thinking configuration to validate
|
||||||
|
// - support: Model's ThinkingSupport properties (nil means no thinking support)
|
||||||
|
// - fromFormat: Source provider format (used to determine strict validation rules)
|
||||||
|
// - toFormat: Target provider format
|
||||||
|
// - fromSuffix: Whether config was sourced from model suffix
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - Normalized ThinkingConfig with clamped values
|
||||||
|
// - ThinkingError if validation fails (ErrThinkingNotSupported, ErrLevelNotSupported, etc.)
|
||||||
|
//
|
||||||
|
// Auto-conversion behavior:
|
||||||
|
// - Budget-only model + Level config → Level converted to Budget
|
||||||
|
// - Level-only model + Budget config → Budget converted to Level
|
||||||
|
// - Hybrid model → preserve original format
|
||||||
|
func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, fromFormat, toFormat string, fromSuffix bool) (*ThinkingConfig, error) {
|
||||||
|
fromFormat, toFormat = strings.ToLower(strings.TrimSpace(fromFormat)), strings.ToLower(strings.TrimSpace(toFormat))
|
||||||
|
model := "unknown"
|
||||||
|
support := (*registry.ThinkingSupport)(nil)
|
||||||
|
if modelInfo != nil {
|
||||||
|
if modelInfo.ID != "" {
|
||||||
|
model = modelInfo.ID
|
||||||
|
}
|
||||||
|
support = modelInfo.Thinking
|
||||||
|
}
|
||||||
|
|
||||||
|
if support == nil {
|
||||||
|
if config.Mode != ModeNone {
|
||||||
|
return nil, NewThinkingErrorWithModel(ErrThinkingNotSupported, "thinking not supported for this model", model)
|
||||||
|
}
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
allowClampUnsupported := isBudgetBasedProvider(fromFormat) && isLevelBasedProvider(toFormat)
|
||||||
|
strictBudget := !fromSuffix && fromFormat != "" && isSameProviderFamily(fromFormat, toFormat)
|
||||||
|
budgetDerivedFromLevel := false
|
||||||
|
|
||||||
|
capability := detectModelCapability(modelInfo)
|
||||||
|
switch capability {
|
||||||
|
case CapabilityBudgetOnly:
|
||||||
|
if config.Mode == ModeLevel {
|
||||||
|
if config.Level == LevelAuto {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
budget, ok := ConvertLevelToBudget(string(config.Level))
|
||||||
|
if !ok {
|
||||||
|
return nil, NewThinkingError(ErrUnknownLevel, fmt.Sprintf("unknown level: %s", config.Level))
|
||||||
|
}
|
||||||
|
config.Mode = ModeBudget
|
||||||
|
config.Budget = budget
|
||||||
|
config.Level = ""
|
||||||
|
budgetDerivedFromLevel = true
|
||||||
|
}
|
||||||
|
case CapabilityLevelOnly:
|
||||||
|
if config.Mode == ModeBudget {
|
||||||
|
level, ok := ConvertBudgetToLevel(config.Budget)
|
||||||
|
if !ok {
|
||||||
|
return nil, NewThinkingError(ErrUnknownLevel, fmt.Sprintf("budget %d cannot be converted to a valid level", config.Budget))
|
||||||
|
}
|
||||||
|
// When converting Budget -> Level for level-only models, clamp the derived standard level
|
||||||
|
// to the nearest supported level. Special values (none/auto) are preserved.
|
||||||
|
config.Mode = ModeLevel
|
||||||
|
config.Level = clampLevel(ThinkingLevel(level), modelInfo, toFormat)
|
||||||
|
config.Budget = 0
|
||||||
|
}
|
||||||
|
case CapabilityHybrid:
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == ModeLevel && config.Level == LevelNone {
|
||||||
|
config.Mode = ModeNone
|
||||||
|
config.Budget = 0
|
||||||
|
config.Level = ""
|
||||||
|
}
|
||||||
|
if config.Mode == ModeLevel && config.Level == LevelAuto {
|
||||||
|
config.Mode = ModeAuto
|
||||||
|
config.Budget = -1
|
||||||
|
config.Level = ""
|
||||||
|
}
|
||||||
|
if config.Mode == ModeBudget && config.Budget == 0 {
|
||||||
|
config.Mode = ModeNone
|
||||||
|
config.Level = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(support.Levels) > 0 && config.Mode == ModeLevel {
|
||||||
|
if !isLevelSupported(string(config.Level), support.Levels) {
|
||||||
|
if allowClampUnsupported {
|
||||||
|
config.Level = clampLevel(config.Level, modelInfo, toFormat)
|
||||||
|
}
|
||||||
|
if !isLevelSupported(string(config.Level), support.Levels) {
|
||||||
|
// User explicitly specified an unsupported level - return error
|
||||||
|
// (budget-derived levels may be clamped based on source format)
|
||||||
|
validLevels := normalizeLevels(support.Levels)
|
||||||
|
message := fmt.Sprintf("level %q not supported, valid levels: %s", strings.ToLower(string(config.Level)), strings.Join(validLevels, ", "))
|
||||||
|
return nil, NewThinkingError(ErrLevelNotSupported, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strictBudget && config.Mode == ModeBudget && !budgetDerivedFromLevel {
|
||||||
|
min, max := support.Min, support.Max
|
||||||
|
if min != 0 || max != 0 {
|
||||||
|
if config.Budget < min || config.Budget > max || (config.Budget == 0 && !support.ZeroAllowed) {
|
||||||
|
message := fmt.Sprintf("budget %d out of range [%d,%d]", config.Budget, min, max)
|
||||||
|
return nil, NewThinkingError(ErrBudgetOutOfRange, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert ModeAuto to mid-range if dynamic not allowed
|
||||||
|
if config.Mode == ModeAuto && !support.DynamicAllowed {
|
||||||
|
config = convertAutoToMidRange(config, support, toFormat, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Mode == ModeNone && toFormat == "claude" {
|
||||||
|
// Claude supports explicit disable via thinking.type="disabled".
|
||||||
|
// Keep Budget=0 so applier can omit budget_tokens.
|
||||||
|
config.Budget = 0
|
||||||
|
config.Level = ""
|
||||||
|
} else {
|
||||||
|
switch config.Mode {
|
||||||
|
case ModeBudget, ModeAuto, ModeNone:
|
||||||
|
config.Budget = clampBudget(config.Budget, modelInfo, toFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModeNone with clamped Budget > 0: set Level to lowest for Level-only/Hybrid models
|
||||||
|
// This ensures Apply layer doesn't need to access support.Levels
|
||||||
|
if config.Mode == ModeNone && config.Budget > 0 && len(support.Levels) > 0 {
|
||||||
|
config.Level = ThinkingLevel(support.Levels[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertAutoToMidRange converts ModeAuto to a mid-range value when dynamic is not allowed.
|
||||||
|
//
|
||||||
|
// This function handles the case where a model does not support dynamic/auto thinking.
|
||||||
|
// The auto mode is silently converted to a fixed value based on model capability:
|
||||||
|
// - Level-only models: convert to ModeLevel with LevelMedium
|
||||||
|
// - Budget models: convert to ModeBudget with mid = (Min + Max) / 2
|
||||||
|
//
|
||||||
|
// Logging:
|
||||||
|
// - Debug level when conversion occurs
|
||||||
|
// - Fields: original_mode, clamped_to, reason
|
||||||
|
func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupport, provider, model string) ThinkingConfig {
|
||||||
|
// For level-only models (has Levels but no Min/Max range), use ModeLevel with medium
|
||||||
|
if len(support.Levels) > 0 && support.Min == 0 && support.Max == 0 {
|
||||||
|
config.Mode = ModeLevel
|
||||||
|
config.Level = LevelMedium
|
||||||
|
config.Budget = 0
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": provider,
|
||||||
|
"model": model,
|
||||||
|
"original_mode": "auto",
|
||||||
|
"clamped_to": string(LevelMedium),
|
||||||
|
}).Debug("thinking: mode converted, dynamic not allowed, using medium level |")
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// For budget models, use mid-range budget
|
||||||
|
mid := (support.Min + support.Max) / 2
|
||||||
|
if mid <= 0 && support.ZeroAllowed {
|
||||||
|
config.Mode = ModeNone
|
||||||
|
config.Budget = 0
|
||||||
|
} else if mid <= 0 {
|
||||||
|
config.Mode = ModeBudget
|
||||||
|
config.Budget = support.Min
|
||||||
|
} else {
|
||||||
|
config.Mode = ModeBudget
|
||||||
|
config.Budget = mid
|
||||||
|
}
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": provider,
|
||||||
|
"model": model,
|
||||||
|
"original_mode": "auto",
|
||||||
|
"clamped_to": config.Budget,
|
||||||
|
}).Debug("thinking: mode converted, dynamic not allowed |")
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// standardLevelOrder defines the canonical ordering of thinking levels from lowest to highest.
|
||||||
|
var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh}
|
||||||
|
|
||||||
|
// clampLevel clamps the given level to the nearest supported level.
|
||||||
|
// On tie, prefers the lower level.
|
||||||
|
func clampLevel(level ThinkingLevel, modelInfo *registry.ModelInfo, provider string) ThinkingLevel {
|
||||||
|
model := "unknown"
|
||||||
|
var supported []string
|
||||||
|
if modelInfo != nil {
|
||||||
|
if modelInfo.ID != "" {
|
||||||
|
model = modelInfo.ID
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking != nil {
|
||||||
|
supported = modelInfo.Thinking.Levels
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(supported) == 0 || isLevelSupported(string(level), supported) {
|
||||||
|
return level
|
||||||
|
}
|
||||||
|
|
||||||
|
pos := levelIndex(string(level))
|
||||||
|
if pos == -1 {
|
||||||
|
return level
|
||||||
|
}
|
||||||
|
bestIdx, bestDist := -1, len(standardLevelOrder)+1
|
||||||
|
|
||||||
|
for _, s := range supported {
|
||||||
|
if idx := levelIndex(strings.TrimSpace(s)); idx != -1 {
|
||||||
|
if dist := abs(pos - idx); dist < bestDist || (dist == bestDist && idx < bestIdx) {
|
||||||
|
bestIdx, bestDist = idx, dist
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if bestIdx >= 0 {
|
||||||
|
clamped := standardLevelOrder[bestIdx]
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": provider,
|
||||||
|
"model": model,
|
||||||
|
"original_value": string(level),
|
||||||
|
"clamped_to": string(clamped),
|
||||||
|
}).Debug("thinking: level clamped |")
|
||||||
|
return clamped
|
||||||
|
}
|
||||||
|
return level
|
||||||
|
}
|
||||||
|
|
||||||
|
// clampBudget clamps a budget value to the model's supported range.
|
||||||
|
func clampBudget(value int, modelInfo *registry.ModelInfo, provider string) int {
|
||||||
|
model := "unknown"
|
||||||
|
support := (*registry.ThinkingSupport)(nil)
|
||||||
|
if modelInfo != nil {
|
||||||
|
if modelInfo.ID != "" {
|
||||||
|
model = modelInfo.ID
|
||||||
|
}
|
||||||
|
support = modelInfo.Thinking
|
||||||
|
}
|
||||||
|
if support == nil {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto value (-1) passes through without clamping.
|
||||||
|
if value == -1 {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
min, max := support.Min, support.Max
|
||||||
|
if value == 0 && !support.ZeroAllowed {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": provider,
|
||||||
|
"model": model,
|
||||||
|
"original_value": value,
|
||||||
|
"clamped_to": min,
|
||||||
|
"min": min,
|
||||||
|
"max": max,
|
||||||
|
}).Warn("thinking: budget zero not allowed |")
|
||||||
|
return min
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some models are level-only and do not define numeric budget ranges.
|
||||||
|
if min == 0 && max == 0 {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
if value < min {
|
||||||
|
if value == 0 && support.ZeroAllowed {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
logClamp(provider, model, value, min, min, max)
|
||||||
|
return min
|
||||||
|
}
|
||||||
|
if value > max {
|
||||||
|
logClamp(provider, model, value, max, min, max)
|
||||||
|
return max
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLevelSupported(level string, supported []string) bool {
|
||||||
|
for _, s := range supported {
|
||||||
|
if strings.EqualFold(level, strings.TrimSpace(s)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func levelIndex(level string) int {
|
||||||
|
for i, l := range standardLevelOrder {
|
||||||
|
if strings.EqualFold(level, string(l)) {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeLevels(levels []string) []string {
|
||||||
|
out := make([]string, len(levels))
|
||||||
|
for i, l := range levels {
|
||||||
|
out[i] = strings.ToLower(strings.TrimSpace(l))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBudgetBasedProvider(provider string) bool {
|
||||||
|
switch provider {
|
||||||
|
case "gemini", "gemini-cli", "antigravity", "claude":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLevelBasedProvider(provider string) bool {
|
||||||
|
switch provider {
|
||||||
|
case "openai", "openai-response", "codex":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isGeminiFamily(provider string) bool {
|
||||||
|
switch provider {
|
||||||
|
case "gemini", "gemini-cli", "antigravity":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSameProviderFamily(from, to string) bool {
|
||||||
|
if from == to {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return isGeminiFamily(from) && isGeminiFamily(to)
|
||||||
|
}
|
||||||
|
|
||||||
|
func abs(x int) int {
|
||||||
|
if x < 0 {
|
||||||
|
return -x
|
||||||
|
}
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
func logClamp(provider, model string, original, clampedTo, min, max int) {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"provider": provider,
|
||||||
|
"model": model,
|
||||||
|
"original_value": original,
|
||||||
|
"min": min,
|
||||||
|
"max": max,
|
||||||
|
"clamped_to": clampedTo,
|
||||||
|
}).Debug("thinking: budget clamped |")
|
||||||
|
}
|
||||||
@@ -7,41 +7,16 @@ package claude
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// deriveSessionID generates a stable session ID from the request.
|
|
||||||
// Uses the hash of the first user message to identify the conversation.
|
|
||||||
func deriveSessionID(rawJSON []byte) string {
|
|
||||||
messages := gjson.GetBytes(rawJSON, "messages")
|
|
||||||
if !messages.IsArray() {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
for _, msg := range messages.Array() {
|
|
||||||
if msg.Get("role").String() == "user" {
|
|
||||||
content := msg.Get("content").String()
|
|
||||||
if content == "" {
|
|
||||||
// Try to get text from content array
|
|
||||||
content = msg.Get("content.0.text").String()
|
|
||||||
}
|
|
||||||
if content != "" {
|
|
||||||
h := sha256.Sum256([]byte(content))
|
|
||||||
return hex.EncodeToString(h[:16])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
|
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
|
||||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||||
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
|
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
|
||||||
@@ -61,11 +36,9 @@ func deriveSessionID(rawJSON []byte) string {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []byte: The transformed request data in Gemini CLI API format
|
// - []byte: The transformed request data in Gemini CLI API format
|
||||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||||
|
enableThoughtTranslate := true
|
||||||
rawJSON := bytes.Clone(inputRawJSON)
|
rawJSON := bytes.Clone(inputRawJSON)
|
||||||
|
|
||||||
// Derive session ID for signature caching
|
|
||||||
sessionID := deriveSessionID(rawJSON)
|
|
||||||
|
|
||||||
// system instruction
|
// system instruction
|
||||||
systemInstructionJSON := ""
|
systemInstructionJSON := ""
|
||||||
hasSystemInstruction := false
|
hasSystemInstruction := false
|
||||||
@@ -123,43 +96,50 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
contentTypeResult := contentResult.Get("type")
|
contentTypeResult := contentResult.Get("type")
|
||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
||||||
// Use GetThinkingText to handle wrapped thinking objects
|
// Use GetThinkingText to handle wrapped thinking objects
|
||||||
thinkingText := util.GetThinkingText(contentResult)
|
thinkingText := thinking.GetThinkingText(contentResult)
|
||||||
signatureResult := contentResult.Get("signature")
|
|
||||||
clientSignature := ""
|
|
||||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
|
||||||
clientSignature = signatureResult.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Always try cached signature first (more reliable than client-provided)
|
// Always try cached signature first (more reliable than client-provided)
|
||||||
// Client may send stale or invalid signatures from different sessions
|
// Client may send stale or invalid signatures from different sessions
|
||||||
signature := ""
|
signature := ""
|
||||||
if sessionID != "" && thinkingText != "" {
|
if thinkingText != "" {
|
||||||
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
|
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
|
||||||
signature = cachedSig
|
signature = cachedSig
|
||||||
log.Debugf("Using cached signature for thinking block")
|
// log.Debugf("Using cached signature for thinking block")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to client signature only if cache miss and client signature is valid
|
// Fallback to client signature only if cache miss and client signature is valid
|
||||||
if signature == "" && cache.HasValidSignature(clientSignature) {
|
if signature == "" {
|
||||||
signature = clientSignature
|
signatureResult := contentResult.Get("signature")
|
||||||
log.Debugf("Using client-provided signature for thinking block")
|
clientSignature := ""
|
||||||
|
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||||
|
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
|
||||||
|
if len(arrayClientSignatures) == 2 {
|
||||||
|
if modelName == arrayClientSignatures[0] {
|
||||||
|
clientSignature = arrayClientSignatures[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cache.HasValidSignature(modelName, clientSignature) {
|
||||||
|
signature = clientSignature
|
||||||
|
}
|
||||||
|
// log.Debugf("Using client-provided signature for thinking block")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store for subsequent tool_use in the same message
|
// Store for subsequent tool_use in the same message
|
||||||
if cache.HasValidSignature(signature) {
|
if cache.HasValidSignature(modelName, signature) {
|
||||||
currentMessageThinkingSignature = signature
|
currentMessageThinkingSignature = signature
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip trailing unsigned thinking blocks on last assistant message
|
// Skip trailing unsigned thinking blocks on last assistant message
|
||||||
isUnsigned := !cache.HasValidSignature(signature)
|
isUnsigned := !cache.HasValidSignature(modelName, signature)
|
||||||
|
|
||||||
// If unsigned, skip entirely (don't convert to text)
|
// If unsigned, skip entirely (don't convert to text)
|
||||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||||
// Converting to text would break this requirement
|
// Converting to text would break this requirement
|
||||||
if isUnsigned {
|
if isUnsigned {
|
||||||
// TypeScript plugin approach: drop unsigned thinking blocks entirely
|
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||||
log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
enableThoughtTranslate = false
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,15 +155,17 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||||
prompt := contentResult.Get("text").String()
|
prompt := contentResult.Get("text").String()
|
||||||
partJSON := `{}`
|
// Skip empty text parts to avoid Gemini API error:
|
||||||
if prompt != "" {
|
// "required oneof field 'data' must have one initialized field"
|
||||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
if prompt == "" {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
partJSON := `{}`
|
||||||
|
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||||
// Antigravity API validates signatures, so dummy values are rejected.
|
// Antigravity API validates signatures, so dummy values are rejected.
|
||||||
// The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies.
|
|
||||||
|
|
||||||
functionName := contentResult.Get("name").String()
|
functionName := contentResult.Get("name").String()
|
||||||
argsResult := contentResult.Get("input")
|
argsResult := contentResult.Get("input")
|
||||||
@@ -208,7 +190,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||||
// and also works for Claude through Antigravity API
|
// and also works for Claude through Antigravity API
|
||||||
const skipSentinel = "skip_thought_signature_validator"
|
const skipSentinel = "skip_thought_signature_validator"
|
||||||
if cache.HasValidSignature(currentMessageThinkingSignature) {
|
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
|
||||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||||
} else {
|
} else {
|
||||||
// No valid signature - use skip sentinel to bypass validation
|
// No valid signature - use skip sentinel to bypass validation
|
||||||
@@ -306,6 +288,13 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip messages with empty parts array to avoid Gemini API error:
|
||||||
|
// "required oneof field 'data' must have one initialized field"
|
||||||
|
partsCheck := gjson.Get(clientContentJSON, "parts")
|
||||||
|
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||||
hasContents = true
|
hasContents = true
|
||||||
} else if contentsResult.Type == gjson.String {
|
} else if contentsResult.Type == gjson.String {
|
||||||
@@ -388,12 +377,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) {
|
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||||
if t.Get("type").String() == "enabled" {
|
if t.Get("type").String() == "enabled" {
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -73,30 +74,41 @@ func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Valid signature must be at least 50 characters
|
// Valid signature must be at least 50 characters
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Let me think..."
|
||||||
|
|
||||||
|
// Pre-cache the signature (simulating a previous response for the same thinking text)
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||||
{"type": "text", "text": "Answer"}
|
{"type": "text", "text": "Answer"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Check thinking block conversion
|
// Check thinking block conversion (now in contents.1 due to user message)
|
||||||
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
|
firstPart := gjson.Get(outputStr, "request.contents.1.parts.0")
|
||||||
if !firstPart.Get("thought").Bool() {
|
if !firstPart.Get("thought").Bool() {
|
||||||
t.Error("thinking block should have thought: true")
|
t.Error("thinking block should have thought: true")
|
||||||
}
|
}
|
||||||
if firstPart.Get("text").String() != "Let me think..." {
|
if firstPart.Get("text").String() != thinkingText {
|
||||||
t.Error("thinking text mismatch")
|
t.Error("thinking text mismatch")
|
||||||
}
|
}
|
||||||
if firstPart.Get("thoughtSignature").String() != validSignature {
|
if firstPart.Get("thoughtSignature").String() != validSignature {
|
||||||
@@ -105,6 +117,8 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Unsigned thinking blocks should be removed entirely (not converted to text)
|
// Unsigned thinking blocks should be removed entirely (not converted to text)
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
@@ -226,14 +240,22 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Let me think..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||||
{
|
{
|
||||||
"type": "tool_use",
|
"type": "tool_use",
|
||||||
"id": "call_123",
|
"id": "call_123",
|
||||||
@@ -245,11 +267,13 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
|||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Check function call has the signature from the preceding thinking block
|
// Check function call has the signature from the preceding thinking block (now in contents.1)
|
||||||
part := gjson.Get(outputStr, "request.contents.0.parts.1")
|
part := gjson.Get(outputStr, "request.contents.1.parts.1")
|
||||||
if part.Get("functionCall.name").String() != "get_weather" {
|
if part.Get("functionCall.name").String() != "get_weather" {
|
||||||
t.Errorf("Expected functionCall, got %s", part.Raw)
|
t.Errorf("Expected functionCall, got %s", part.Raw)
|
||||||
}
|
}
|
||||||
@@ -259,26 +283,36 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Case: text block followed by thinking block -> should be reordered to thinking first
|
// Case: text block followed by thinking block -> should be reordered to thinking first
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Planning..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": "Here is the plan."},
|
{"type": "text", "text": "Here is the plan."},
|
||||||
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Verify order: Thinking block MUST be first
|
// Verify order: Thinking block MUST be first (now in contents.1 due to user message)
|
||||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||||
}
|
}
|
||||||
@@ -343,8 +377,8 @@ func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
|||||||
if thinkingConfig.Get("thinkingBudget").Int() != 8000 {
|
if thinkingConfig.Get("thinkingBudget").Int() != 8000 {
|
||||||
t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int())
|
t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int())
|
||||||
}
|
}
|
||||||
if !thinkingConfig.Get("include_thoughts").Bool() {
|
if !thinkingConfig.Get("includeThoughts").Bool() {
|
||||||
t.Error("include_thoughts should be true")
|
t.Error("includeThoughts should be true")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.Log("thinkingConfig not present - model may not be registered in test registry")
|
t.Log("thinkingConfig not present - model may not be registered in test registry")
|
||||||
@@ -459,7 +493,12 @@ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
||||||
|
cache.ClearSignatureCache("")
|
||||||
|
|
||||||
// Last assistant message ends with signed thinking block - should be kept
|
// Last assistant message ends with signed thinking block - should be kept
|
||||||
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Valid thinking..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -471,12 +510,14 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin
|
|||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": "Here is my answer"},
|
{"type": "text", "text": "Here is my answer"},
|
||||||
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ type Params struct {
|
|||||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||||
|
|
||||||
// Signature caching support
|
// Signature caching support
|
||||||
SessionID string // Session ID derived from request for signature caching
|
|
||||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,9 +69,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
HasFirstResponse: false,
|
HasFirstResponse: false,
|
||||||
ResponseType: 0,
|
ResponseType: 0,
|
||||||
ResponseIndex: 0,
|
ResponseIndex: 0,
|
||||||
SessionID: deriveSessionID(originalRequestRawJSON),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||||
|
|
||||||
params := (*param).(*Params)
|
params := (*param).(*Params)
|
||||||
|
|
||||||
@@ -136,16 +135,16 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
// Process thinking content (internal reasoning)
|
// Process thinking content (internal reasoning)
|
||||||
if partResult.Get("thought").Bool() {
|
if partResult.Get("thought").Bool() {
|
||||||
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||||
log.Debug("Branch: signature_delta")
|
// log.Debug("Branch: signature_delta")
|
||||||
|
|
||||||
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
|
if params.CurrentThinkingText.Len() > 0 {
|
||||||
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
|
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||||
log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
// log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len())
|
||||||
params.CurrentThinkingText.Reset()
|
params.CurrentThinkingText.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
output = output + "event: content_block_delta\n"
|
output = output + "event: content_block_delta\n"
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
params.HasContent = true
|
params.HasContent = true
|
||||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||||
@@ -372,7 +371,7 @@ func resolveStopReason(params *Params) string {
|
|||||||
// - string: A Claude-compatible JSON response.
|
// - string: A Claude-compatible JSON response.
|
||||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||||
_ = originalRequestRawJSON
|
_ = originalRequestRawJSON
|
||||||
_ = requestRawJSON
|
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||||
|
|
||||||
root := gjson.ParseBytes(rawJSON)
|
root := gjson.ParseBytes(rawJSON)
|
||||||
promptTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
|
promptTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
|
||||||
@@ -437,7 +436,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
|||||||
block := `{"type":"thinking","thinking":""}`
|
block := `{"type":"thinking","thinking":""}`
|
||||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||||
if thinkingSignature != "" {
|
if thinkingSignature != "" {
|
||||||
block, _ = sjson.Set(block, "signature", thinkingSignature)
|
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||||
}
|
}
|
||||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||||
thinkingBuilder.Reset()
|
thinkingBuilder.Reset()
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user