Compare commits

..

135 Commits

Author SHA1 Message Date
Luis Pater
bc3195c8d8 refactor(logger): remove unnecessary request details limit logic 2026-01-10 14:46:59 +08:00
Luis Pater
4d7f389b69 Fixed: #941
fix(translator): ensure fallback to valid originalRequestRawJSON in response handling
2026-01-10 01:01:09 +08:00
Luis Pater
95f87d5669 Merge pull request #947 from pykancha/fix-memory-leak
Resolve memory leaks causing OOM in k8s deployment
2026-01-10 00:40:47 +08:00
Luis Pater
c83365a349 Merge pull request #938 from router-for-me/log
refactor(logging): clean up oauth logs and debugs
2026-01-10 00:02:45 +08:00
Luis Pater
6b3604cf2b Merge pull request #943 from ben-vargas/fix-tool-mappings
Fix Claude OAuth tool name mapping (proxy_)
2026-01-09 23:52:29 +08:00
Luis Pater
af6bdca14f Fixed: #942
fix(executor): ignore non-SSE lines in OpenAI-compatible streams
2026-01-09 23:41:50 +08:00
hemanta212
1c773c428f fix: Remove investigation artifacts 2026-01-09 17:47:59 +05:45
Ben Vargas
e785bfcd12 Use unprefixed Claude request for translation
Keep the upstream payload prefixed for OAuth while passing the unprefixed request body into response translators. This avoids proxy_ leaking into OpenAI Responses echoed tool metadata while preserving the Claude OAuth workaround.
2026-01-09 00:54:35 -07:00
hemanta212
47dacce6ea fix(server): resolve memory leaks causing OOM in k8s deployment
- usage/logger_plugin: cap modelStats.Details at 1000 entries per model
- cache/signature_cache: add background cleanup for expired sessions (10 min)
- management/handler: add background cleanup for stale IP rate-limit entries (1 hr)
- executor/cache_helpers: add mutex protection and TTL cleanup for codexCacheMap (15 min)
- executor/codex_executor: use thread-safe cache accessors

Add reproduction tests demonstrating leak behavior before/after fixes.

Amp-Thread-ID: https://ampcode.com/threads/T-019ba0fc-1d7b-7338-8e1d-ca0520412777
Co-authored-by: Amp <amp@ampcode.com>
2026-01-09 13:33:46 +05:45
Ben Vargas
dcac3407ab Fix Claude OAuth tool name mapping
Prefix tool names with proxy_ for Claude OAuth requests and strip the prefix from streaming and non-streaming responses to restore client-facing names.

Updates the Claude executor to:
- add prefixing for tools, tool_choice, and tool_use messages when using OAuth tokens
- strip the prefix from tool_use events in SSE and non-streaming payloads
- add focused unit tests for prefix/strip helpers
2026-01-09 00:10:38 -07:00
hkfires
7004295e1d build(docker): move stats export execution after image build 2026-01-09 11:24:00 +08:00
hkfires
ee62ef4745 refactor(logging): clean up oauth logs and debugs 2026-01-09 11:20:55 +08:00
Luis Pater
ef6bafbf7e fix(executor): handle context cancellation and deadline errors explicitly 2026-01-09 10:48:29 +08:00
Luis Pater
ed28b71e87 refactor(amp): remove duplicate comments in response rewriter 2026-01-09 08:21:13 +08:00
Luis Pater
d47b7dc79a refactor(response): enhance parameter handling for Codex to Claude conversion 2026-01-09 05:20:19 +08:00
Luis Pater
49b9709ce5 Merge pull request #787 from sususu98/fix/antigravity-429-retry-delay-parsing
fix(antigravity): parse retry-after delay from 429 response body
2026-01-09 04:45:25 +08:00
Luis Pater
a2eba2cdf5 Merge pull request #763 from mvelbaum/feature/improve-oauth-use-logging
feat(logging): disambiguate OAuth credential selection in debug logs
2026-01-09 04:43:21 +08:00
Luis Pater
3d01b3cfe8 Merge pull request #553 from XInTheDark/fix/builtin-tools-web-search
fix(translator): preserve built-in tools (web_search) to Responses API
2026-01-09 04:40:13 +08:00
Luis Pater
af2efa6f7e Merge pull request #605 from soilSpoon/feature/amp-compat
feature: Improves Amp client compatibility
2026-01-09 04:28:17 +08:00
Luis Pater
d73b61d367 Merge pull request #901 from uzhao/vscode-plugin
Vscode plugin
2026-01-08 22:22:27 +08:00
Luis Pater
59a448b645 feat(executor): centralize systemInstruction handling for Claude and Gemini-3-Pro models 2026-01-08 21:05:33 +08:00
Chén Mù
4adb9eed77 Merge pull request #921 from router-for-me/atgy
fix(executor): update gemini model identifier to gemini-3-pro-preview
2026-01-08 19:20:32 +08:00
hkfires
b6a0f7a07f fix(executor): update gemini model identifier to gemini-3-pro-preview
Update the model name check in `buildRequest` to target "gemini-3-pro-preview" instead of "gemini-3-pro" when applying specific system instruction handling.
2026-01-08 19:14:52 +08:00
Luis Pater
1b2f907671 feat(executor): update system instruction handling for Claude and Gemini-3-Pro models 2026-01-08 12:42:26 +08:00
Luis Pater
bda04eed8a feat(executor): add model-specific support for "gemini-3-pro" in execution and payload handling 2026-01-08 12:27:03 +08:00
Luis Pater
67985d8226 feat(executor): enhance Antigravity payload with user role and dynamic system instructions 2026-01-08 10:55:25 +08:00
Jianyang Zhao
cbcb061812 Update README_CN.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-07 20:07:01 -05:00
Jianyang Zhao
9fc2e1b3c8 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-07 20:06:55 -05:00
Jianyang Zhao
3b484aea9e Add Claude Proxy VSCode to README_CN.md
Added information about Claude Proxy VSCode extension.
2026-01-07 20:03:07 -05:00
Jianyang Zhao
963a0950fa Add Claude Proxy VSCode extension to README
Added Claude Proxy VSCode extension to the README.
2026-01-07 20:02:50 -05:00
Luis Pater
f4ba1ab910 fix(executor): remove unused tokenRefreshTimeout constant and pass zero timeout to HTTP client 2026-01-07 18:16:49 +08:00
Luis Pater
2662f91082 feat(management): add PostOAuthCallback handler to token requester interface 2026-01-07 10:47:32 +08:00
Luis Pater
c1db2c7d7c Merge pull request #888 from router-for-me/api-call-TOKEN-fix
fix(management): refresh antigravity token for api-call $TOKEN$
2026-01-07 01:19:24 +08:00
LTbinglingfeng
5e5d8142f9 fix(auth): error when antigravity refresh token missing during refresh 2026-01-07 01:09:50 +08:00
LTbinglingfeng
b01619b441 fix(management): refresh antigravity token for api-call $TOKEN$ 2026-01-07 00:14:02 +08:00
Luis Pater
f861bd6a94 docs: add 9Router to community projects in README 2026-01-06 23:15:28 +08:00
Luis Pater
6dbfdd140d Merge pull request #871 from decolua/patch-1
Update README.md
2026-01-06 22:58:53 +08:00
decolua
386ccffed4 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 20:54:33 +07:00
decolua
ffddd1c90a Update README.md 2026-01-05 20:29:26 +07:00
Luis Pater
8f8dfd081b Merge pull request #850 from can1357/main
feat(translator): add developer role support for Gemini translators
2026-01-05 11:27:24 +08:00
Luis Pater
9f1b445c7c docs: add ProxyPilot to community projects in Chinese README 2026-01-05 11:23:48 +08:00
Luis Pater
ae933dfe14 Merge pull request #858 from Finesssee/add-proxypilot
docs: add ProxyPilot to community projects
2026-01-05 11:20:52 +08:00
Luis Pater
e124db723b Merge pull request #862 from router-for-me/gemini
fix(gemini): abort default injection on existing thinking keys
2026-01-05 10:41:07 +08:00
hkfires
05444cf32d fix(gemini): abort default injection on existing thinking keys 2026-01-05 10:24:30 +08:00
Luis Pater
8edbda57cf feat(translator): add thoughtSignature to node parts for Gemini and Antigravity requests
Enhanced node structure by including `thoughtSignature` for inline data parts in Gemini OpenAI, Gemini CLI, and Antigravity request handlers to improve traceability of thought processes.
2026-01-05 09:25:17 +08:00
Finessse
821249a5ed docs: add ProxyPilot to community projects 2026-01-04 18:19:41 +07:00
Luis Pater
ee33863b47 Merge pull request #857 from router-for-me/management-update
Management update
2026-01-04 18:07:13 +08:00
Supra4E8C
cd22c849e2 feat(management): 更新OAuth模型映射的清理逻辑以增强数据安全性 2026-01-04 17:57:34 +08:00
Supra4E8C
f0e73efda2 feat(management): add vertex api key and oauth model mappings endpoints 2026-01-04 17:32:00 +08:00
Supra4E8C
3156109c71 feat(management): 支持管理接口调整日志大小/强制前缀/路由策略 2026-01-04 12:21:49 +08:00
can1357
6762e081f3 feat(translator): add developer role support for Gemini translators
Treat OpenAI's "developer" role the same as "system" role in request
translation for gemini, gemini-cli, and antigravity backends.
2026-01-03 21:01:01 +01:00
Luis Pater
7815ee338d fix(translator): adjust message_delta emission boundary in Claude-to-OpenAI conversion
Fixed incorrect boundary logic for `message_delta` emission, ensuring proper handling of usage updates and `emitMessageStopIfNeeded` within the response loop.
2026-01-04 01:36:51 +08:00
Luis Pater
44b6c872e2 feat(config): add support for Fork in OAuth model mappings with alias handling
Implemented `Fork` flag in `ModelNameMapping` to allow aliases as additional models while preserving the original model ID. Updated the `applyOAuthModelMappings` logic, added tests for `Fork` behavior, and updated documentation and examples accordingly.
2026-01-04 01:18:29 +08:00
Luis Pater
7a77b23f2d feat(executor): add token refresh timeout and improve context handling during refresh
Introduced `tokenRefreshTimeout` constant for token refresh operations and enhanced context propagation for `refreshToken` by embedding roundtrip information if available. Adjusted `refreshAuth` to ensure default context initialization and handle cancellation errors appropriately.
2026-01-04 00:26:08 +08:00
Luis Pater
672e8549c0 docs: reorganize README to adjust CodMate placement
Moved CodMate entry under ProxyPal in both English and Chinese README files for consistency in structure and better readability.
2026-01-03 21:31:53 +08:00
Luis Pater
66f5269a23 Merge pull request #837 from loocor/main
docs: add CodMate to community projects
2026-01-03 21:30:15 +08:00
Luis Pater
ebec293497 feat(api): integrate TokenStore for improved auth entry management
Replaced file-based auth entry counting with `TokenStore`-backed implementation, enhancing flexibility and context-aware token management. Updated related logic to reflect this change.
2026-01-03 04:53:47 +08:00
Luis Pater
e02ceecd35 feat(registry): introduce ModelRegistryHook for monitoring model registrations and unregistrations
Added support for external hooks to observe model registry events using the `ModelRegistryHook` interface. Implemented thread-safe, non-blocking execution of hooks with panic recovery. Comprehensive tests added to verify hook behavior during registration, unregistration, blocking, and panic scenarios.
2026-01-02 23:18:40 +08:00
Luis Pater
c8b33a8cc3 Merge pull request #824 from router-for-me/script
feat(script): add usage statistics preservation across container rebuilds
2026-01-02 20:42:25 +08:00
Loocor
dca8d5ded8 Add CodMate app information to README
Added CodMate section to README with app details.
2026-01-02 17:15:38 +08:00
Loocor
2a7fd1e897 Add CodMate description to README_CN.md
添加 CodMate 应用的描述,提供 CLI AI 会话管理功能。
2026-01-02 17:15:09 +08:00
Luis Pater
b9d1e70ac2 Merge pull request #830 from router-for-me/gemini
fix(util): disable default thinking for gemini-3 series
2026-01-02 10:59:24 +08:00
hkfires
fdf5720217 fix(gemini): remove default thinking for gemini 3 models 2026-01-02 10:55:59 +08:00
hkfires
f40bd0cd51 feat(script): add usage statistics preservation across container rebuilds 2026-01-02 10:01:20 +08:00
hkfires
e33676bb87 fix(util): disable default thinking for gemini-3 series 2026-01-02 09:43:40 +08:00
Luis Pater
2a663d5cba feat(executor): enhance payload translation with original request context
Refactored `applyPayloadConfig` to `applyPayloadConfigWithRoot`, adding support for default rule validation against the original payload when available. Updated all executors to use `applyPayloadConfigWithRoot` and incorporate an optional original request payload for translations.
2026-01-02 00:03:26 +08:00
Luis Pater
750b930679 Merge pull request #823 from router-for-me/translator
feat(translator): enhance Claude-to-OpenAI conversion with thinking block and tool result handling
2026-01-01 20:16:10 +08:00
hkfires
3902fd7501 fix(iflow): remove thinking field from request body in thinking config handler 2026-01-01 19:40:28 +08:00
hkfires
4fc3d5e935 refactor(iflow): simplify thinking config handling for GLM and MiniMax models 2026-01-01 19:31:08 +08:00
hkfires
2d2f4572a7 fix(translator): remove unnecessary whitespace trimming in reasoning text collection 2026-01-01 12:39:09 +08:00
hkfires
8f4c46f38d fix(translator): emit tool_result messages before user content in Claude-to-OpenAI conversion 2026-01-01 11:11:43 +08:00
hkfires
b6ba51bc2a feat(translator): add thinking block and tool result handling for Claude-to-OpenAI conversion 2026-01-01 09:41:25 +08:00
Luis Pater
6a66d32d37 Merge pull request #803 from HsnSaboor/fix-invalid-function-names-sanitization-v2
feat(translator): resolve invalid function name errors by sanitizing Claude tool names
2026-01-01 01:15:50 +08:00
Luis Pater
8d15723195 feat(registry): add GetAvailableModelsByProvider method for retrieving models by provider 2025-12-31 23:37:46 +08:00
Chén Mù
736e0aae86 Merge pull request #814 from router-for-me/aistudio
Fix model alias thinking suffix
2025-12-31 03:08:05 -08:00
hkfires
8bf3305b2b fix(thinking): fallback to upstream model for thinking support when alias not in registry 2025-12-31 18:07:13 +08:00
hkfires
d00e3ea973 feat(thinking): add numeric budget to thinkingLevel conversion fallback 2025-12-31 17:14:47 +08:00
hkfires
89db4e9481 fix(thinking): use model alias for thinking config resolution in mapped models 2025-12-31 17:09:22 +08:00
hkfires
e332419081 feat(registry): add thinking support for gemini-2.5-computer-use-preview model 2025-12-31 17:09:22 +08:00
Luis Pater
e998b1229a feat(updater): add fallback URL and logic for missing management asset 2025-12-31 11:51:20 +08:00
Luis Pater
bbed134bd1 feat(api): add GetAuthStatus method to ManagementTokenRequester interface 2025-12-31 09:40:48 +08:00
Saboor Hassan
47b9503112 chore: revert changes to internal/translator to comply with path guard
This commit reverts all modifications within internal/translator. A separate issue
will be created for the maintenance team to integrate SanitizeFunctionName into
the translators.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 02:19:26 +05:00
Saboor Hassan
3b9253c2be fix(translator): resolve invalid function name errors by sanitizing Claude tool names
This commit centralizes tool name sanitization in SanitizeFunctionName,
applying character compliance, starting character rules, and length limits.
It also fixes a regression in gemini_schema tests and preserves MCP-specific
shortening logic while ensuring compliance.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 02:14:46 +05:00
Saboor Hassan
d241359153 fix(translator): address PR feedback for tool name sanitization
- Pre-compile sanitization regex for better performance.
- Optimize SanitizeFunctionName for conciseness and correctness.
- Handle 64-char edge cases by truncating before prepending underscore.
- Fix bug in Antigravity translator (incorrect join index).
- Refactor Gemini translators to avoid redundant sanitization calls.
- Add comprehensive unit tests including 64-char edge cases.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 01:54:41 +05:00
Saboor Hassan
f4d4249ba5 feat(translator): sanitize tool/function names for upstream provider compatibility
Implemented SanitizeFunctionName utility to ensure Claude tool names meet
Gemini/Upstream strict naming conventions (alphanumeric, starts with letter/underscore, max 64 chars).
Applied sanitization to tool definitions and usage in all relevant translators.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 01:41:07 +05:00
Chén Mù
cb56cb250e Merge pull request #800 from router-for-me/modelmappings
feat(watcher): add model mappings change detection
2025-12-30 06:50:42 -08:00
hkfires
e0381a6ae0 refactor(watcher): extract model summary functions to dedicated file 2025-12-30 22:39:12 +08:00
hkfires
2c01b2ef64 feat(watcher): add Gemini models and OAuth model mappings change detection 2025-12-30 22:39:12 +08:00
Chén Mù
e947266743 Merge pull request #795 from router-for-me/modelmappings
refactor(executor): resolve upstream model at conductor level before execution
2025-12-30 05:31:19 -08:00
Luis Pater
c6b0e85b54 Fixed: #790
fix(gemini): include full text in response output events
2025-12-30 20:44:13 +08:00
hkfires
26efbed05c refactor(executor): remove redundant upstream model parameter from translateRequest 2025-12-30 20:20:42 +08:00
hkfires
96340bf136 refactor(executor): resolve upstream model at conductor level before execution 2025-12-30 19:31:54 +08:00
hkfires
b055e00c1a fix(executor): use upstream model for thinking config and payload translation 2025-12-30 17:49:44 +08:00
sususu
414db44c00 fix(antigravity): parse retry-after delay from 429 response body
When receiving HTTP 429 (Too Many Requests) responses, parse the retry
delay from the response body using parseRetryDelay and populate the
statusErr.retryAfter field. This allows upstream callers to respect
the server's requested retry timing.

Applied to all error paths in Execute, executeClaudeNonStream,
ExecuteStream, CountTokens, and refreshToken functions.
2025-12-30 16:07:32 +08:00
Chén Mù
857c880f99 Merge pull request #785 from router-for-me/gemini
feat(gemini): add per-key model alias support for Gemini provider
2025-12-29 23:32:40 -08:00
hkfires
ce7474d953 feat(cliproxy): propagate thinking support metadata to aliased models 2025-12-30 15:16:54 +08:00
hkfires
70fdd70b84 refactor(cliproxy): extract generic buildConfigModels function for model info generation 2025-12-30 13:35:22 +08:00
hkfires
08ab6a7d77 feat(gemini): add per-key model alias support for Gemini provider 2025-12-30 13:27:57 +08:00
Luis Pater
9fa2a7e9df Merge pull request #782 from router-for-me/modelmappings
refactor(config): rename model-name-mappings to oauth-model-mappings
2025-12-30 11:40:12 +08:00
hkfires
d443c86620 refactor(config): rename model mapping fields from from/to to name/alias 2025-12-30 11:07:59 +08:00
hkfires
7be3f1c36c refactor(config): rename model-name-mappings to oauth-model-mappings 2025-12-30 11:07:58 +08:00
Luis Pater
f6ab6d97b9 fix(logging): add isDirWritable utility to enhance log dir validation in ConfigureLogOutput 2025-12-30 10:48:25 +08:00
Luis Pater
bc866bac49 fix(logging): refactor ConfigureLogOutput to accept config object and adjust log directory handling 2025-12-30 10:28:25 +08:00
Luis Pater
50e6d845f4 feat(cliproxy): introduce global model name mappings for improved aliasing and routing 2025-12-30 08:13:06 +08:00
Luis Pater
a8cb01819d Merge pull request #772 from soffchen/main
fix: Implement fallback log directory for file logging on read-only system
2025-12-30 02:24:49 +08:00
Luis Pater
530273906b Merge pull request #776 from router-for-me/fix-ag-claude
fix(antigravity): inject required placeholder when properties exist w…
2025-12-30 00:37:01 +08:00
Supra4E8C
06ddf575d9 fix(antigravity): inject required placeholder when properties exist without required 2025-12-29 23:55:59 +08:00
hkfires
3099114cbb refactor(api): simplify codex id token claims extraction 2025-12-29 19:48:02 +08:00
Soff
44b63f0767 fix: Return an error if the user home directory cannot be determined for the fallback log path. 2025-12-29 18:46:15 +08:00
Soff Chen
6705d20194 fix: Implement fallback log directory for file logging on read-only systems. 2025-12-29 18:35:48 +08:00
Chén Mù
a38a9c0b0f Merge pull request #770 from router-for-me/api
feat(api): add id token claims extraction for codex auth entries
2025-12-29 00:44:41 -08:00
hkfires
8286caa366 feat(api): add id token claims extraction for codex auth entries 2025-12-29 16:34:16 +08:00
Chén Mù
bd1ec8424d Merge pull request #767 from router-for-me/amp
feat(amp): add per-client upstream API key mapping support
2025-12-28 22:10:11 -08:00
hkfires
225e2c6797 feat(amp): add per-client upstream API key mapping support 2025-12-29 12:26:25 +08:00
Luis Pater
d8fc485513 fix(translators): correct key path for system_instruction.parts in Claude request logic 2025-12-29 11:54:26 +08:00
hkfires
f137eb0ac4 chore: add codex, agents, and opencode dirs to ignore files 2025-12-29 08:42:29 +08:00
Chén Mù
f39a460487 Merge pull request #761 from router-for-me/log
fix(logging): improve request/response capture
2025-12-28 16:13:10 -08:00
Luis Pater
ee171bc563 feat(api): add ManagementTokenRequester interface for management token request endpoints 2025-12-29 02:42:29 +08:00
hkfires
a95428f204 fix(handlers): preserve upstream response logs before duplicate detection 2025-12-28 22:35:36 +08:00
Michael Velbaum
cb3bdffb43 refactor(logging): streamline auth selection debug messages
Reduce duplicate Debugf calls by appending proxy info via an optional suffix and keep the debug-level guard inside the helper.
2025-12-28 16:10:11 +02:00
Michael Velbaum
48f19aab51 refactor(logging): pass request entry into auth selection log
Avoid re-creating the request-scoped log entry in the helper and use a switch for account type dispatch.
2025-12-28 15:51:11 +02:00
Michael Velbaum
48f6d7abdf refactor(logging): dedupe auth selection debug logs
Extract repeated debug logging for selected auth credentials into a helper so execute, count, and stream paths stay consistent.
2025-12-28 15:42:35 +02:00
Michael Velbaum
79fbcb3ec4 fix(logging): quote OAuth account field
Use strconv.Quote when embedding the OAuth account in debug logs so unexpected characters (e.g. quotes) can't break key=value parsing.
2025-12-28 15:32:54 +02:00
Michael Velbaum
0e4148b229 feat(logging): disambiguate OAuth credential selection in debug logs
When multiple OAuth providers share an account email, the existing "Use OAuth" debug lines are ambiguous and hard to correlate with management usage stats. Include provider, auth file, and auth index in the selection log, and only compute these fields when debug logging is enabled to avoid impacting normal request performance.

Before:
[debug] Use OAuth user@example.com for model gemini-3-flash-preview
[debug] Use OAuth user@example.com (project-1234) for model gemini-3-flash-preview

After:
[debug] Use OAuth provider=antigravity auth_file=antigravity-user_example_com.json auth_index=1a2b3c4d5e6f7788 account="user@example.com" for model gemini-3-flash-preview
[debug] Use OAuth provider=gemini-cli auth_file=gemini-user@example.com-project-1234.json auth_index=99aabbccddeeff00 account="user@example.com (project-1234)" for model gemini-3-flash-preview
2025-12-28 15:22:36 +02:00
hkfires
3ca5fb1046 fix(handlers): match raw error text before JSON body for duplicate detection 2025-12-28 19:35:36 +08:00
hkfires
a091d12f4e fix(logging): improve request/response capture 2025-12-28 19:04:31 +08:00
Luis Pater
457924828a Merge pull request #757 from ben-vargas/fix-thinking-toolchoice-conflict
Fix: disable thinking when tool_choice forces tool use
2025-12-28 14:04:30 +08:00
Ben Vargas
aca2ef6359 Fix: disable thinking when tool_choice forces tool use
Anthropic API does not allow extended thinking when tool_choice is set
to "any" or a specific tool. This was causing 400 errors when using
features like Amp's /handoff command which forces tool_choice.

Added disableThinkingIfToolChoiceForced() that removes thinking config
when incompatible tool_choice is detected, applied to both streaming
and non-streaming paths.

Fixes router-for-me/CLIProxyAPI#630
2025-12-27 16:31:37 -07:00
Luis Pater
ade7194792 feat(management): add generic API call handler to management endpoints 2025-12-28 04:40:32 +08:00
Luis Pater
3a436e116a feat(cliproxy): implement model aliasing and hashing for Codex configurations, enhance request routing logic, and normalize Codex model entries 2025-12-28 03:06:51 +08:00
Luis Pater
336867853b Merge pull request #756 from leaph/check-ai-thinking-settings
feat(iflow): add model-specific thinking configs for GLM-4.7 and Mini…
2025-12-28 02:08:27 +08:00
leaph
6403ff4ec4 feat(iflow): add model-specific thinking configs for GLM-4.7 and MiniMax-M2.1
- GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
- MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
- Added preserveReasoningContentInMessages() to support re-injection of reasoning
  content in assistant message history for multi-turn conversations
- Added ThinkingSupport to MiniMax-M2.1 model definition
2025-12-27 18:39:15 +01:00
Luis Pater
d222469b44 Update issue templates 2025-12-28 01:22:42 +08:00
이대희
31bd90c748 feature: Improves Amp client compatibility
Ensures compatibility with the Amp client by suppressing
"thinking" blocks when "tool_use" blocks are also present in
the response.

The Amp client has issues rendering both types of blocks
simultaneously. This change filters out "thinking" blocks in
such cases, preventing rendering problems.
2025-12-19 08:18:27 +09:00
Muzhen Gaming
0b834fcb54 fix(translator): preserve built-in tools across openai<->responses
- Pass through non-function tool definitions like web_search

- Translate tool_choice for built-in tools and function tools

- Add regression tests for built-in tool passthrough
2025-12-15 21:18:54 +08:00
87 changed files with 6410 additions and 665 deletions

View File

@@ -23,11 +23,14 @@ config.yaml
# Development/editor
bin/*
.claude/*
.vscode/*
.claude/*
.codex/*
.gemini/*
.serena/*
.agent/*
.agents/*
.opencode/*
.bmad/*
_bmad/*
_bmad-output/*

View File

@@ -7,6 +7,13 @@ assignees: ''
---
**Is it a request payload issue?**
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
[ ] No, it's another issue.
**If it's a request payload issue, you MUST know**
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
**Describe the bug**
A clear and concise description of what the bug is.

4
.gitignore vendored
View File

@@ -33,10 +33,14 @@ GEMINI.md
# Tooling metadata
.vscode/*
.codex/*
.claude/*
.gemini/*
.serena/*
.agent/*
.agents/*
.agents/*
.opencode/*
.bmad/*
_bmad/*
_bmad-output/*

View File

@@ -118,9 +118,32 @@ Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings,
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
### [CodMate](https://github.com/loocor/CodMate)
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
## More choices
Those projects are ports of CLIProxyAPI or inspired by it:
### [9Router](https://github.com/decolua/9router)
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

View File

@@ -117,9 +117,32 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
### [CodMate](https://github.com/loocor/CodMate)
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话(Claude Code、Codex、Gemini CLI),提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
## 更多选择
以下项目是 CLIProxyAPI 的移植版或受其启发:
### [9Router](https://github.com/decolua/9router)
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用,自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -405,7 +405,7 @@ func main() {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
if err = logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to configure log output: %v", err)
return
}

View File

@@ -35,6 +35,7 @@ auth-dir: "~/.cli-proxy-api"
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
# Enable debug logging
debug: false
@@ -89,6 +90,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
# excluded-models:
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
@@ -104,6 +108,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
# excluded-models:
# - "gpt-5.1" # exclude specific models (exact match)
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
@@ -121,7 +128,7 @@ ws-auth: false
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# excluded-models:
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
@@ -152,9 +159,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names
# - name: "gemini-2.0-flash" # upstream model name
# - name: "gemini-2.5-flash" # upstream model name
# alias: "vertex-flash" # client-visible alias
# - name: "gemini-1.5-pro"
# - name: "gemini-2.5-pro"
# alias: "vertex-pro"
# Amp Integration
@@ -163,6 +170,18 @@ ws-auth: false
# upstream-url: "https://ampcode.com"
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
# upstream-api-key: ""
# # Per-client upstream API key mapping
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
# # Useful when different clients need to use different Amp accounts/quotas.
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
# upstream-api-keys:
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
# api-keys: # Client keys that use this upstream key
# - "your-api-key-1"
# - "your-api-key-2"
# - upstream-api-key: "amp_key_for_team_b"
# api-keys:
# - "your-api-key-3"
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
# restrict-management-to-localhost: false
# # Force model mappings to run before checking local API keys (default: false)
@@ -172,12 +191,43 @@ ws-auth: false
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
# # but you have a similar model available (e.g., Claude Sonnet 4).
# model-mappings:
# - from: "claude-opus-4.5" # Model requested by Amp CLI
# to: "claude-sonnet-4" # Route to this available model instead
# - from: "gpt-5"
# to: "gemini-2.5-pro"
# - from: "claude-3-opus-20240229"
# to: "claude-3-5-sonnet-20241022"
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
# - from: "claude-sonnet-4-5-20250929"
# to: "gemini-claude-sonnet-4-5-thinking"
# - from: "claude-haiku-4-5-20251001"
# to: "gemini-2.5-flash"
# Global OAuth model name mappings (per channel)
# These mappings rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# oauth-model-mappings:
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
# vertex:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-preview"
# alias: "g3p"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
# codex:
# - name: "gpt-5"
# alias: "g5"
# qwen:
# - name: "qwen3-coder-plus"
# alias: "qwen-plus"
# iflow:
# - name: "glm-4.7"
# alias: "glm-god"
# OAuth provider excluded models
# oauth-excluded-models:

View File

@@ -5,9 +5,115 @@
# This script automates the process of building and running the Docker container
# with version information dynamically injected at build time.
# Exit immediately if a command exits with a non-zero status.
# Hidden feature: Preserve usage statistics across rebuilds
# Usage: ./docker-build.sh --with-usage
# First run prompts for management API key, saved to temp/stats/.api_secret
set -euo pipefail
STATS_DIR="temp/stats"
STATS_FILE="${STATS_DIR}/.usage_backup.json"
SECRET_FILE="${STATS_DIR}/.api_secret"
WITH_USAGE=false
# Print the port the service listens on.
# Uses the first top-level "port:" entry of config.yaml when the file exists;
# prints the default 8317 when the file is absent or has no such entry.
get_port() {
  local port=""
  if [[ -f "config.yaml" ]]; then
    # grep exits non-zero when nothing matches; under `set -euo pipefail` that
    # would abort the caller's `port=$(get_port)` silently, so tolerate it here.
    port=$(grep -m1 -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/' || true)
  fi
  echo "${port:-8317}"
}
# Load the management API key into the global API_SECRET.
# Reads it from SECRET_FILE when that file exists; otherwise prompts for it
# (input hidden) and stores it in SECRET_FILE for later runs.
export_stats_api_secret() {
  if [[ -f "${SECRET_FILE}" ]]; then
    API_SECRET=$(cat "${SECRET_FILE}")
    return
  fi

  mkdir -p "${STATS_DIR}"
  echo "First time using --with-usage. Management API key required."
  read -r -p "Enter management key: " -s API_SECRET
  echo

  # Create the file under a restrictive umask (subshell keeps the change local)
  # so the key is never briefly readable by other users, and use printf so a
  # key that looks like an echo option is written verbatim.
  (
    umask 077
    printf '%s\n' "${API_SECRET}" > "${SECRET_FILE}"
  )
  chmod 600 "${SECRET_FILE}"
}
# Abort the script unless the service answers HTTP 200 on its root URL.
check_container_running() {
  local port
  port=$(get_port)

  # Healthy service: nothing more to do.
  if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
    return 0
  fi

  echo "Error: cli-proxy-api service is not responding at localhost:${port}"
  echo "Please start the container first or use without --with-usage flag."
  exit 1
}
# Download the usage statistics from the running service into STATS_FILE.
# Exits the script when the service is not responding or the export request
# does not return HTTP 200.
export_stats() {
  local port
  port=$(get_port)

  # mkdir -p is a no-op when the directory already exists.
  mkdir -p "${STATS_DIR}"
  check_container_running

  echo "Exporting usage statistics..."
  # curl appends the HTTP status code on its own final line of the output.
  EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
    "http://localhost:${port}/v0/management/usage/export")
  HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
  RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')

  if [[ "${HTTP_CODE}" == "200" ]]; then
    echo "${RESPONSE_BODY}" > "${STATS_FILE}"
    echo "Statistics exported to ${STATS_FILE}"
  else
    echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
    exit 1
  fi
}
# Upload the statistics saved in STATS_FILE back into the running service.
# The backup file is removed only after a successful import, so a failed
# import does not destroy the exported statistics.
import_stats() {
  local port
  port=$(get_port)

  # Without a backup there is nothing to send; curl would fail on the missing
  # file and `set -e` would end the script without any explanation.
  if [[ ! -f "${STATS_FILE}" ]]; then
    echo "No statistics backup found at ${STATS_FILE}; skipping import."
    return 0
  fi

  echo "Importing usage statistics..."
  # curl appends the HTTP status code on its own final line of the output.
  IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
    -H "X-Management-Key: ${API_SECRET}" \
    -H "Content-Type: application/json" \
    -d @"${STATS_FILE}" \
    "http://localhost:${port}/v0/management/usage/import")
  IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
  IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')

  if [[ "${IMPORT_CODE}" == "200" ]]; then
    echo "Statistics imported successfully"
    rm -f "${STATS_FILE}"
  else
    echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
    echo "Statistics backup kept at ${STATS_FILE}"
  fi
}
# Poll the service root URL once per second, for at most 30 attempts, until it
# answers HTTP 200. Never fails the script: on timeout it prints a warning and
# lets the caller continue.
wait_for_service() {
  local port attempt ready=false
  port=$(get_port)

  echo "Waiting for service to be ready..."
  for attempt in {1..30}; do
    if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
      ready=true
      break
    fi
    sleep 1
  done

  if [[ "${ready}" != "true" ]]; then
    echo "Warning: service did not respond at localhost:${port} within 30 attempts; continuing anyway."
  fi

  # Short grace period after the last probe before the caller proceeds.
  sleep 2
}
# Enable statistics preservation when the first argument is --with-usage.
# "${1:-}" expands to an empty string when no argument is given, which keeps
# `set -u` from aborting on an unset $1. The management key is loaded (or
# prompted for) right away so it is available to the export/import steps.
if [[ "${1:-}" == "--with-usage" ]]; then
WITH_USAGE=true
export_stats_api_secret
fi
# --- Step 1: Choose Environment ---
echo "Please select an option:"
echo "1) Run using Pre-built Image (Recommended)"
@@ -18,7 +124,14 @@ read -r -p "Enter choice [1-2]: " choice
case "$choice" in
1)
echo "--- Running with Pre-built Image ---"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
docker compose up -d --remove-orphans --no-build
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Services are starting from remote image."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -38,16 +151,25 @@ case "$choice" in
# Build and start the services with a local-only image tag
export CLI_PROXY_IMAGE="cli-proxy-api:local"
echo "Building the Docker image..."
docker compose build \
--build-arg VERSION="${VERSION}" \
--build-arg COMMIT="${COMMIT}" \
--build-arg BUILD_DATE="${BUILD_DATE}"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
echo "Starting the services..."
docker compose up -d --remove-orphans --pull never
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Build complete. Services are starting."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -55,4 +177,4 @@ case "$choice" in
echo "Invalid choice. Please enter 1 or 2."
exit 1
;;
esac
esac

View File

@@ -0,0 +1,704 @@
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// defaultAPICallTimeout is the timeout set on the HTTP clients created in this
// file (the outbound APICall request and the OAuth token refresh requests).
const defaultAPICallTimeout = 60 * time.Second
// OAuth client credentials placed in the oauth2.Config used for gemini-cli tokens.
const (
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOAuthClientSecret = "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"
)
// geminiOAuthScopes lists the scopes placed in the gemini-cli oauth2.Config.
var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
// OAuth client credentials sent as client_id / client_secret in the
// antigravity refresh_token request.
const (
antigravityOAuthClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityOAuthClientSecret = "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"
)
// antigravityOAuthTokenURL is the token endpoint used for antigravity refreshes.
// It is a variable (not a constant) so tests can point it at a local server.
var antigravityOAuthTokenURL = "https://oauth2.googleapis.com/token"
// apiCallRequest is the JSON body accepted by APICall.
type apiCallRequest struct {
// The credential index may be sent under any of three spellings; APICall uses
// the first non-blank one in the order snake, camel, pascal.
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
// Method is the HTTP method; APICall trims and upper-cases it.
Method string `json:"method"`
// URL is the absolute target URL (scheme and host are required).
URL string `json:"url"`
// Header holds the request headers; values may contain the "$TOKEN$" placeholder.
Header map[string]string `json:"header"`
// Data is the raw request body; an empty string means no body is sent.
Data string `json:"data"`
}
// apiCallResponse is the JSON body APICall returns after the upstream request completes.
type apiCallResponse struct {
// StatusCode is the upstream HTTP status code.
StatusCode int `json:"status_code"`
// Header holds the upstream response headers.
Header map[string][]string `json:"header"`
// Body is the upstream response body converted to a string.
Body string `json:"body"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
// It is protected by the management middleware.
//
// Endpoint:
//
//	POST /v0/management/api-call
//
// Authentication:
//
//	Same as other management APIs (requires a management key and remote-management rules).
//	You can provide the key via:
//	- Authorization: Bearer <key>
//	- X-Management-Key: <key>
//
// Request JSON:
//   - auth_index / authIndex / AuthIndex (optional):
//     The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
//     If omitted or not found, credential-specific proxy/token substitution is skipped.
//   - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
//   - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
//   - header (optional): Request headers map.
//     Supports magic variable "$TOKEN$" which is replaced using the selected credential.
//     For gemini-cli and antigravity credentials the value comes from the OAuth
//     refresh helpers. For every other credential it is the first non-empty of:
//     1) metadata.accessToken / metadata.access_token
//     2) metadata.token (a string, or its access_token / accessToken field)
//     3) metadata.id_token
//     4) metadata.cookie
//     5) attributes.api_key
//     Example: {"Authorization":"Bearer $TOKEN$"}.
//     Note: if you need to override the HTTP Host header, set header["Host"].
//   - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
//
// Proxy selection (highest priority first):
//  1. Selected credential proxy_url
//  2. Global config proxy-url
//  3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
//   - status_code: Upstream HTTP status code.
//   - header: Upstream response headers.
//   - body: Upstream response body as string.
//
// Example:
//
//	curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
//	  -H "Authorization: Bearer <MANAGEMENT_KEY>" \
//	  -H "Content-Type: application/json" \
//	  -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
//
//	curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
//	  -H "Authorization: Bearer <MANAGEMENT_KEY>" \
//	  -H "Content-Type: application/json" \
//	  -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
	var body apiCallRequest
	if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
		return
	}

	method := strings.ToUpper(strings.TrimSpace(body.Method))
	if method == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
		return
	}

	urlStr := strings.TrimSpace(body.URL)
	if urlStr == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
		return
	}
	parsedURL, errParseURL := url.Parse(urlStr)
	if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
		return
	}

	authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
	auth := h.authByIndex(authIndex)

	reqHeaders := body.Header
	if reqHeaders == nil {
		reqHeaders = map[string]string{}
	}

	// Substitute "$TOKEN$" in header values. The token is resolved lazily, and
	// at most once, only when some header actually contains the placeholder.
	var (
		token         string
		tokenResolved bool
		tokenErr      error
	)
	for key, value := range reqHeaders {
		if !strings.Contains(value, "$TOKEN$") {
			continue
		}
		if !tokenResolved {
			token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
			tokenResolved = true
		}
		if auth != nil && token == "" {
			if tokenErr != nil {
				log.WithError(tokenErr).Debug("management APICall token resolution failed")
				c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
				return
			}
			c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
			return
		}
		if token == "" {
			// No credential selected: the placeholder is left untouched.
			continue
		}
		reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
	}

	var requestBody io.Reader
	if body.Data != "" {
		requestBody = strings.NewReader(body.Data)
	}

	req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
	if errNewRequest != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
		return
	}

	// A "Host" header is applied to req.Host instead of the header map.
	var hostOverride string
	for key, value := range reqHeaders {
		if strings.EqualFold(key, "host") {
			hostOverride = strings.TrimSpace(value)
			continue
		}
		req.Header.Set(key, value)
	}
	if hostOverride != "" {
		req.Host = hostOverride
	}

	httpClient := &http.Client{
		Timeout:   defaultAPICallTimeout,
		Transport: h.apiCallTransport(auth),
	}
	// The transport is created for this call only, so nothing will ever reuse
	// its pooled connections; release them instead of leaving them idle.
	defer httpClient.CloseIdleConnections()

	resp, errDo := httpClient.Do(req)
	if errDo != nil {
		log.WithError(errDo).Debug("management APICall request failed")
		c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
		return
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()

	respBody, errReadAll := io.ReadAll(resp.Body)
	if errReadAll != nil {
		log.WithError(errReadAll).Debug("management APICall read response failed")
		c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
		return
	}

	c.JSON(http.StatusOK, apiCallResponse{
		StatusCode: resp.StatusCode,
		Header:     resp.Header,
		Body:       string(respBody),
	})
}
// firstNonEmptyString returns the first pointed-to value that is non-blank
// after trimming whitespace, in trimmed form. Nil pointers are skipped; the
// result is "" when no candidate qualifies.
func firstNonEmptyString(values ...*string) string {
	for i := range values {
		candidate := values[i]
		if candidate != nil {
			if trimmed := strings.TrimSpace(*candidate); len(trimmed) > 0 {
				return trimmed
			}
		}
	}
	return ""
}
// tokenValueForAuth returns the token stored on a credential without
// refreshing it. Lookup order: the credential's own metadata, then the
// "api_key" attribute, then the metadata snapshot of a shared gemini-cli
// credential resolved from auth.Runtime. Returns "" when nothing is found or
// auth is nil.
func tokenValueForAuth(auth *coreauth.Auth) string {
	if auth == nil {
		return ""
	}

	token := tokenValueFromMetadata(auth.Metadata)
	if token == "" && auth.Attributes != nil {
		token = strings.TrimSpace(auth.Attributes["api_key"])
	}
	if token != "" {
		return token
	}

	shared := geminicli.ResolveSharedCredential(auth.Runtime)
	if shared == nil {
		return ""
	}
	return tokenValueFromMetadata(shared.MetadataSnapshot())
}
// resolveTokenForAuth returns the token to substitute for a credential.
// gemini-cli and antigravity credentials go through their OAuth refresh
// helpers; every other provider uses the stored token value. A nil auth
// yields ("", nil).
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
	if auth == nil {
		return "", nil
	}

	switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
	case "gemini-cli":
		return h.refreshGeminiOAuthAccessToken(ctx, auth)
	case "antigravity":
		return h.refreshAntigravityOAuthAccessToken(ctx, auth)
	default:
		return tokenValueForAuth(auth), nil
	}
}
// refreshGeminiOAuthAccessToken obtains an access token for a gemini-cli
// credential through an oauth2 token source seeded with the stored token.
//
// The stored token is rebuilt from metadata["token"] (when it is a map), with
// missing fields filled in from the top-level access_token, refresh_token,
// token_type and expiry entries. The token returned by the token source is
// written back through the metadata updater. Returns ("", nil) for a nil auth
// and an error when the metadata is empty or the token source fails.
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if auth == nil {
		return "", nil
	}

	metadata, updater := geminiOAuthMetadata(auth)
	if len(metadata) == 0 {
		return "", fmt.Errorf("gemini oauth metadata missing")
	}

	base := make(map[string]any)
	if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
		base = cloneMap(tokenRaw)
	}

	var token oauth2.Token
	if len(base) > 0 {
		if raw, errMarshal := json.Marshal(base); errMarshal == nil {
			// Best effort: fields that fail to decode are filled from the
			// top-level metadata entries below.
			_ = json.Unmarshal(raw, &token)
		}
	}

	if token.AccessToken == "" {
		token.AccessToken = stringValue(metadata, "access_token")
	}
	if token.RefreshToken == "" {
		token.RefreshToken = stringValue(metadata, "refresh_token")
	}
	if token.TokenType == "" {
		token.TokenType = stringValue(metadata, "token_type")
	}
	if token.Expiry.IsZero() {
		if expiry := stringValue(metadata, "expiry"); expiry != "" {
			if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
				token.Expiry = ts
			}
		}
	}

	conf := &oauth2.Config{
		ClientID:     geminiOAuthClientID,
		ClientSecret: geminiOAuthClientSecret,
		Scopes:       geminiOAuthScopes,
		Endpoint:     google.Endpoint,
	}

	httpClient := &http.Client{
		Timeout:   defaultAPICallTimeout,
		Transport: h.apiCallTransport(auth),
	}
	// The transport exists only for this refresh; release its pooled
	// connections instead of leaving them idle with nothing to reuse them.
	defer httpClient.CloseIdleConnections()

	// Hand the client to the oauth2 package through the context so the token
	// request goes through the credential's proxy settings.
	ctxToken := context.WithValue(ctx, oauth2.HTTPClient, httpClient)

	currentToken, errToken := conf.TokenSource(ctxToken, &token).Token()
	if errToken != nil {
		return "", errToken
	}

	merged := buildOAuthTokenMap(base, currentToken)
	fields := buildOAuthTokenFields(currentToken, merged)
	if updater != nil {
		updater(fields)
	}
	return strings.TrimSpace(currentToken.AccessToken), nil
}
// refreshAntigravityOAuthAccessToken returns an access token for an
// antigravity credential, refreshing it first when needed.
//
// When the stored token is present and antigravityTokenNeedsRefresh reports it
// as still valid, it is returned unchanged. Otherwise a refresh_token grant is
// posted to antigravityOAuthTokenURL, the new token data is written into
// auth.Metadata, and the auth is handed to the auth manager (when one is set).
// Returns ("", nil) for a nil auth and an error when metadata or the refresh
// token is missing, or when the refresh request fails.
func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if auth == nil {
		return "", nil
	}

	metadata := auth.Metadata
	if len(metadata) == 0 {
		return "", fmt.Errorf("antigravity oauth metadata missing")
	}

	current := strings.TrimSpace(tokenValueFromMetadata(metadata))
	if current != "" && !antigravityTokenNeedsRefresh(metadata) {
		return current, nil
	}

	refreshToken := stringValue(metadata, "refresh_token")
	if refreshToken == "" {
		return "", fmt.Errorf("antigravity refresh token missing")
	}

	tokenURL := strings.TrimSpace(antigravityOAuthTokenURL)
	if tokenURL == "" {
		tokenURL = "https://oauth2.googleapis.com/token"
	}
	form := url.Values{}
	form.Set("client_id", antigravityOAuthClientID)
	form.Set("client_secret", antigravityOAuthClientSecret)
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)

	req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
	if errReq != nil {
		return "", errReq
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	httpClient := &http.Client{
		Timeout:   defaultAPICallTimeout,
		Transport: h.apiCallTransport(auth),
	}
	// The transport exists only for this refresh; release its pooled
	// connections instead of leaving them idle with nothing to reuse them.
	defer httpClient.CloseIdleConnections()

	resp, errDo := httpClient.Do(req)
	if errDo != nil {
		return "", errDo
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()

	bodyBytes, errRead := io.ReadAll(resp.Body)
	if errRead != nil {
		return "", errRead
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", fmt.Errorf("antigravity oauth token refresh failed: status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
	}

	var tokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int64  `json:"expires_in"`
		TokenType    string `json:"token_type"`
	}
	if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
		return "", errUnmarshal
	}

	accessToken := strings.TrimSpace(tokenResp.AccessToken)
	if accessToken == "" {
		return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token")
	}

	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	now := time.Now()
	auth.Metadata["access_token"] = accessToken
	// Keep the previous refresh token when the response does not carry a new one.
	if newRefresh := strings.TrimSpace(tokenResp.RefreshToken); newRefresh != "" {
		auth.Metadata["refresh_token"] = newRefresh
	}
	if tokenResp.ExpiresIn > 0 {
		auth.Metadata["expires_in"] = tokenResp.ExpiresIn
		auth.Metadata["timestamp"] = now.UnixMilli()
		auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
	}
	auth.Metadata["type"] = "antigravity"

	if h != nil && h.authManager != nil {
		auth.LastRefreshedAt = now
		auth.UpdatedAt = now
		// The refreshed token is still returned when the update fails, but the
		// failure is logged instead of being dropped silently.
		if _, errUpdate := h.authManager.Update(ctx, auth); errUpdate != nil {
			log.WithError(errUpdate).Warn("antigravity oauth token refreshed but auth update failed")
		}
	}

	return accessToken, nil
}
// antigravityTokenNeedsRefresh reports whether the token described by metadata
// should be refreshed. The expiry is taken from the RFC 3339 "expired" entry
// when it parses, otherwise from "timestamp" (milliseconds) plus "expires_in"
// (seconds). A token expiring within the next 30 seconds counts as needing a
// refresh, as does metadata with no usable expiry information.
func antigravityTokenNeedsRefresh(metadata map[string]any) bool {
	// Refresh a bit early to avoid requests racing token expiry.
	const skew = 30 * time.Second

	if metadata == nil {
		return true
	}

	expiresSoon := func(expiry time.Time) bool {
		return !expiry.After(time.Now().Add(skew))
	}

	if rawExpiry, isString := metadata["expired"].(string); isString {
		parsed, errParse := time.Parse(time.RFC3339, strings.TrimSpace(rawExpiry))
		if errParse == nil {
			return expiresSoon(parsed)
		}
	}

	lifetimeSeconds := int64Value(metadata["expires_in"])
	issuedAtMs := int64Value(metadata["timestamp"])
	if lifetimeSeconds <= 0 || issuedAtMs <= 0 {
		return true
	}
	issuedAt := time.UnixMilli(issuedAtMs)
	return expiresSoon(issuedAt.Add(time.Duration(lifetimeSeconds) * time.Second))
}
// int64Value converts a loosely typed value to int64.
//
// Supported inputs are the int, int32, int64, uint, uint32, uint64, float32
// and float64 types, json.Number, and strings holding an integer (surrounding
// whitespace is ignored). Floats are truncated toward zero. Anything that
// cannot be represented as an int64 (unsupported types, unparsable text,
// unsigned values above the int64 maximum, NaN, infinities and out-of-range
// floats) yields 0.
func int64Value(raw any) int64 {
	const maxInt64 = 1<<63 - 1

	// fromUnsigned guards every unsigned width the same way, so a large uint
	// cannot wrap around to a negative number.
	fromUnsigned := func(u uint64) int64 {
		if u > maxInt64 {
			return 0
		}
		return int64(u)
	}
	// fromFloat rejects values whose conversion to int64 is not defined by the
	// language. NaN fails both comparisons and therefore also yields 0.
	fromFloat := func(f float64) int64 {
		const limit = 1 << 63
		if f >= -limit && f < limit {
			return int64(f)
		}
		return 0
	}

	switch typed := raw.(type) {
	case int:
		return int64(typed)
	case int32:
		return int64(typed)
	case int64:
		return typed
	case uint:
		return fromUnsigned(uint64(typed))
	case uint32:
		return int64(typed)
	case uint64:
		return fromUnsigned(typed)
	case float32:
		return fromFloat(float64(typed))
	case float64:
		return fromFloat(typed)
	case json.Number:
		if i, errParse := typed.Int64(); errParse == nil {
			return i
		}
	case string:
		if s := strings.TrimSpace(typed); s != "" {
			if i, errParse := json.Number(s).Int64(); errParse == nil {
				return i
			}
		}
	}
	return 0
}
// geminiOAuthMetadata returns the metadata to read OAuth token data from,
// together with a function that writes updated fields back.
//
// When a shared credential can be resolved from auth.Runtime, its metadata
// snapshot is returned and updates are merged into the shared credential.
// Otherwise auth.Metadata is returned and updates are copied into it (the map
// is created on first write). A nil auth yields (nil, nil).
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
	if auth == nil {
		return nil, nil
	}

	shared := geminicli.ResolveSharedCredential(auth.Runtime)
	if shared != nil {
		mergeIntoShared := func(fields map[string]any) { shared.MergeMetadata(fields) }
		return shared.MetadataSnapshot(), mergeIntoShared
	}

	writeToAuth := func(fields map[string]any) {
		if auth.Metadata == nil {
			auth.Metadata = make(map[string]any)
		}
		for key, value := range fields {
			auth.Metadata[key] = value
		}
	}
	return auth.Metadata, writeToAuth
}
// stringValue returns metadata[key] with surrounding whitespace removed when
// the entry is a string. It returns "" for an empty map, an empty key, a
// missing entry, or a non-string value.
func stringValue(metadata map[string]any, key string) string {
	if key == "" || len(metadata) == 0 {
		return ""
	}
	text, isString := metadata[key].(string)
	if !isString {
		return ""
	}
	return strings.TrimSpace(text)
}
// cloneMap returns a shallow copy of in: the keys are copied and the values
// are shared with the original. A nil or empty input yields nil.
func cloneMap(in map[string]any) map[string]any {
	size := len(in)
	if size == 0 {
		return nil
	}
	dup := make(map[string]any, size)
	for key := range in {
		dup[key] = in[key]
	}
	return dup
}
// buildOAuthTokenMap returns a copy of base overlaid with the JSON fields of
// tok. The result is never nil. When tok is nil, or it cannot be converted to
// a JSON object, the copy of base is returned without changes.
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
	merged := cloneMap(base)
	if merged == nil {
		merged = make(map[string]any)
	}
	if tok == nil {
		return merged
	}

	// Round-trip the token through JSON to get its fields as a generic map.
	raw, errMarshal := json.Marshal(tok)
	if errMarshal != nil {
		return merged
	}
	var tokenMap map[string]any
	if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal != nil {
		return merged
	}

	for key, value := range tokenMap {
		merged[key] = value
	}
	return merged
}
// buildOAuthTokenFields builds the metadata entries to store for a token:
// access_token, token_type, refresh_token and expiry (RFC 3339) for each
// non-empty field of tok, plus a copy of merged under "token" when merged is
// not empty. The result is never nil.
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
	fields := make(map[string]any, 5)

	if tok != nil {
		if tok.AccessToken != "" {
			fields["access_token"] = tok.AccessToken
		}
		if tok.TokenType != "" {
			fields["token_type"] = tok.TokenType
		}
		if tok.RefreshToken != "" {
			fields["refresh_token"] = tok.RefreshToken
		}
		if !tok.Expiry.IsZero() {
			fields["expiry"] = tok.Expiry.Format(time.RFC3339)
		}
	}

	if len(merged) > 0 {
		fields["token"] = cloneMap(merged)
	}
	return fields
}
// tokenValueFromMetadata extracts a token from credential metadata.
//
// It returns the first non-blank value (whitespace trimmed) found in this order:
//  1. metadata["accessToken"], then metadata["access_token"]
//  2. metadata["token"]: the value itself when it is a string, or its
//     "access_token" / "accessToken" entry when it is a map
//  3. metadata["id_token"]
//  4. metadata["cookie"]
//
// Returns "" when nothing matches.
func tokenValueFromMetadata(metadata map[string]any) string {
	if len(metadata) == 0 {
		return ""
	}

	// trimmed returns the trimmed text when raw is a string, otherwise "".
	trimmed := func(raw any) string {
		text, isString := raw.(string)
		if !isString {
			return ""
		}
		return strings.TrimSpace(text)
	}

	for _, key := range [...]string{"accessToken", "access_token"} {
		if v := trimmed(metadata[key]); v != "" {
			return v
		}
	}

	// A missing or nil "token" entry matches none of the cases below.
	switch typed := metadata["token"].(type) {
	case string:
		if v := strings.TrimSpace(typed); v != "" {
			return v
		}
	case map[string]any:
		for _, key := range [...]string{"access_token", "accessToken"} {
			if v := trimmed(typed[key]); v != "" {
				return v
			}
		}
	case map[string]string:
		for _, key := range [...]string{"access_token", "accessToken"} {
			if v := strings.TrimSpace(typed[key]); v != "" {
				return v
			}
		}
	}

	for _, key := range [...]string{"id_token", "cookie"} {
		if v := trimmed(metadata[key]); v != "" {
			return v
		}
	}
	return ""
}
// authByIndex returns the credential from the auth manager whose Index equals
// authIndex (compared after trimming whitespace). It returns nil when the
// index is blank, when no auth manager is available, or when nothing matches.
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
	wanted := strings.TrimSpace(authIndex)
	if h == nil || h.authManager == nil || wanted == "" {
		return nil
	}

	for _, candidate := range h.authManager.List() {
		if candidate == nil {
			continue
		}
		// EnsureIndex is called before Index is compared.
		candidate.EnsureIndex()
		if candidate.Index == wanted {
			return candidate
		}
	}
	return nil
}
// apiCallTransport builds the transport for an outbound call.
//
// The credential's proxy URL is tried first, then the global config proxy URL;
// the first one buildProxyTransport accepts is used. When neither is usable,
// a transport with Proxy set to nil is returned: a clone of
// http.DefaultTransport when that is an *http.Transport, otherwise a fresh one.
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
	candidates := make([]string, 0, 2)
	if auth != nil {
		candidates = append(candidates, auth.ProxyURL)
	}
	if h != nil && h.cfg != nil {
		candidates = append(candidates, h.cfg.ProxyURL)
	}

	for _, raw := range candidates {
		proxyStr := strings.TrimSpace(raw)
		if proxyStr == "" {
			continue
		}
		if transport := buildProxyTransport(proxyStr); transport != nil {
			return transport
		}
	}

	base, isTransport := http.DefaultTransport.(*http.Transport)
	if isTransport && base != nil {
		direct := base.Clone()
		direct.Proxy = nil
		return direct
	}
	return &http.Transport{Proxy: nil}
}
// buildProxyTransport returns a transport that sends traffic through the proxy
// described by proxyStr, or nil when the URL is blank, cannot be parsed, lacks
// a scheme or host, or uses an unsupported scheme.
//
// Supported schemes are "socks5" (optional user/password taken from the URL)
// and "http" / "https".
func buildProxyTransport(proxyStr string) *http.Transport {
	proxyStr = strings.TrimSpace(proxyStr)
	if proxyStr == "" {
		return nil
	}

	proxyURL, errParse := url.Parse(proxyStr)
	if errParse != nil {
		log.WithError(errParse).Debug("parse proxy URL failed")
		return nil
	}
	if proxyURL.Scheme == "" || proxyURL.Host == "" {
		log.Debug("proxy URL missing scheme/host")
		return nil
	}

	switch proxyURL.Scheme {
	case "socks5":
		var proxyAuth *proxy.Auth
		if proxyURL.User != nil {
			username := proxyURL.User.Username()
			password, _ := proxyURL.User.Password()
			proxyAuth = &proxy.Auth{User: username, Password: password}
		}
		dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
		if errSOCKS5 != nil {
			log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
			return nil
		}
		// Prefer the context-aware dial so request cancellation and deadlines
		// also apply while the connection is being established.
		contextDialer, hasContextDial := dialer.(proxy.ContextDialer)
		return &http.Transport{
			Proxy: nil,
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				if hasContextDial {
					return contextDialer.DialContext(ctx, network, addr)
				}
				return dialer.Dial(network, addr)
			},
		}
	case "http", "https":
		return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
	}

	log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
	return nil
}

View File

@@ -0,0 +1,173 @@
package management
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// memoryAuthStore is an in-memory auth store used by the tests in this file.
type memoryAuthStore struct {
// mu guards items.
mu sync.Mutex
// items holds the stored auths keyed by auth ID; Save creates the map on first use.
items map[string]*coreauth.Auth
}
// List returns clones of all stored auths, in map iteration order. The error
// is always nil.
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
	_ = ctx

	s.mu.Lock()
	defer s.mu.Unlock()

	clones := make([]*coreauth.Auth, 0, len(s.items))
	for id := range s.items {
		clones = append(clones, s.items[id].Clone())
	}
	return clones, nil
}
// Save stores a clone of auth under its ID and returns that ID. A nil auth is
// ignored and yields ("", nil).
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
	_ = ctx
	if auth == nil {
		return "", nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if s.items == nil {
		s.items = make(map[string]*coreauth.Auth)
	}
	s.items[auth.ID] = auth.Clone()
	return auth.ID, nil
}
// Delete removes the auth stored under id, if any. The error is always nil.
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
	_ = ctx

	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.items, id)
	return nil
}
// TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken verifies that an
// expired antigravity token triggers exactly one refresh_token request with the
// expected form fields, and that the new token is returned and stored in the
// auth manager.
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
	var callCount int
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		callCount++

		// The handler runs on a server goroutine, where t.Fatalf must not be
		// called. Record the failure with t.Errorf and answer with an error
		// status so the refresh fails and the test stops on its own goroutine.
		reject := func(format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "unexpected refresh request", http.StatusBadRequest)
		}

		if r.Method != http.MethodPost {
			reject("expected POST, got %s", r.Method)
			return
		}
		if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
			reject("unexpected content-type: %s", ct)
			return
		}
		bodyBytes, errRead := io.ReadAll(r.Body)
		_ = r.Body.Close()
		if errRead != nil {
			reject("read body: %v", errRead)
			return
		}
		values, err := url.ParseQuery(string(bodyBytes))
		if err != nil {
			reject("parse form: %v", err)
			return
		}
		if values.Get("grant_type") != "refresh_token" {
			reject("unexpected grant_type: %s", values.Get("grant_type"))
			return
		}
		if values.Get("refresh_token") != "rt" {
			reject("unexpected refresh_token: %s", values.Get("refresh_token"))
			return
		}
		if values.Get("client_id") != antigravityOAuthClientID {
			reject("unexpected client_id: %s", values.Get("client_id"))
			return
		}
		if values.Get("client_secret") != antigravityOAuthClientSecret {
			reject("unexpected client_secret")
			return
		}

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"access_token":  "new-token",
			"refresh_token": "rt2",
			"expires_in":    int64(3600),
			"token_type":    "Bearer",
		})
	}))
	t.Cleanup(srv.Close)

	// Point the refresh at the local test server for the duration of the test.
	originalURL := antigravityOAuthTokenURL
	antigravityOAuthTokenURL = srv.URL
	t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })

	store := &memoryAuthStore{}
	manager := coreauth.NewManager(store, nil, nil)

	auth := &coreauth.Auth{
		ID:       "antigravity-test.json",
		FileName: "antigravity-test.json",
		Provider: "antigravity",
		Metadata: map[string]any{
			"type":          "antigravity",
			"access_token":  "old-token",
			"refresh_token": "rt",
			"expires_in":    int64(3600),
			"timestamp":     time.Now().Add(-2 * time.Hour).UnixMilli(),
			"expired":       time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
		},
	}
	if _, err := manager.Register(context.Background(), auth); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	h := &Handler{authManager: manager}
	token, err := h.resolveTokenForAuth(context.Background(), auth)
	if err != nil {
		t.Fatalf("resolveTokenForAuth: %v", err)
	}
	if token != "new-token" {
		t.Fatalf("expected refreshed token, got %q", token)
	}
	if callCount != 1 {
		t.Fatalf("expected 1 refresh call, got %d", callCount)
	}

	updated, ok := manager.GetByID(auth.ID)
	if !ok || updated == nil {
		t.Fatalf("expected auth in manager after update")
	}
	if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
		t.Fatalf("expected manager metadata updated, got %q", got)
	}
}
// TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid verifies that
// a token expiring 30 minutes from now is returned as stored and that the
// token endpoint is never contacted.
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
	var refreshCalls int
	// Any request reaching this server counts as an unwanted refresh attempt.
	countingHandler := func(w http.ResponseWriter, r *http.Request) {
		refreshCalls++
		w.WriteHeader(http.StatusInternalServerError)
	}
	srv := httptest.NewServer(http.HandlerFunc(countingHandler))
	t.Cleanup(srv.Close)

	savedURL := antigravityOAuthTokenURL
	antigravityOAuthTokenURL = srv.URL
	t.Cleanup(func() { antigravityOAuthTokenURL = savedURL })

	validUntil := time.Now().Add(30 * time.Minute).Format(time.RFC3339)
	auth := &coreauth.Auth{
		ID:       "antigravity-valid.json",
		FileName: "antigravity-valid.json",
		Provider: "antigravity",
		Metadata: map[string]any{
			"type":         "antigravity",
			"access_token": "ok-token",
			"expired":      validUntil,
		},
	}

	handler := &Handler{}
	token, err := handler.resolveTokenForAuth(context.Background(), auth)
	if err != nil {
		t.Fatalf("resolveTokenForAuth: %v", err)
	}
	if token != "ok-token" {
		t.Fatalf("expected existing token, got %q", token)
	}
	if refreshCalls != 0 {
		t.Fatalf("expected no refresh calls, got %d", refreshCalls)
	}
}

View File

@@ -427,9 +427,46 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
log.WithError(err).Warnf("failed to stat auth file %s", path)
}
}
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
return entry
}
// extractCodexIDTokenClaims returns selected claims from a codex credential's
// id_token: "chatgpt_account_id" and "plan_type", each included only when
// non-blank. It returns nil when auth is not a codex credential, has no usable
// id_token, the token cannot be parsed, or neither claim is present.
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
	if auth == nil || auth.Metadata == nil {
		return nil
	}
	if provider := strings.TrimSpace(auth.Provider); !strings.EqualFold(provider, "codex") {
		return nil
	}

	rawToken, isString := auth.Metadata["id_token"].(string)
	if !isString {
		return nil
	}
	trimmedToken := strings.TrimSpace(rawToken)
	if trimmedToken == "" {
		return nil
	}

	claims, err := codex.ParseJWTToken(trimmedToken)
	if err != nil || claims == nil {
		return nil
	}

	out := gin.H{}
	accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
	if accountID != "" {
		out["chatgpt_account_id"] = accountID
	}
	planType := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
	if planType != "" {
		out["plan_type"] = planType
	}

	if len(out) == 0 {
		return nil
	}
	return out
}
func authEmail(auth *coreauth.Auth) string {
if auth == nil {
return ""

View File

@@ -202,6 +202,26 @@ func (h *Handler) PutLoggingToFile(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
}
// LogsMaxTotalSizeMB
// GetLogsMaxTotalSizeMB responds with the configured logs-max-total-size-mb value.
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
	payload := gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB}
	c.JSON(http.StatusOK, payload)
}
// PutLogsMaxTotalSizeMB updates logs-max-total-size-mb from a JSON body of the
// form {"value": <int>} and persists the config. Negative values are stored as
// 0. A malformed body or a missing "value" is answered with 400.
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
	var payload struct {
		Value *int `json:"value"`
	}
	errBindJSON := c.ShouldBindJSON(&payload)
	if errBindJSON != nil || payload.Value == nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
		return
	}

	limit := *payload.Value
	if limit < 0 {
		limit = 0
	}
	h.cfg.LogsMaxTotalSizeMB = limit
	h.persist(c)
}
// Request log
// GetRequestLog responds with the configured request-log flag.
func (h *Handler) GetRequestLog(c *gin.Context) {
	payload := gin.H{"request-log": h.cfg.RequestLog}
	c.JSON(http.StatusOK, payload)
}
func (h *Handler) PutRequestLog(c *gin.Context) {
@@ -232,6 +252,52 @@ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
}
// ForceModelPrefix

// GetForceModelPrefix responds with the current ForceModelPrefix config value
// under the "force-model-prefix" key.
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
	payload := gin.H{"force-model-prefix": h.cfg.ForceModelPrefix}
	c.JSON(200, payload)
}
// PutForceModelPrefix updates ForceModelPrefix; request handling is delegated
// to h.updateBoolField, which receives the setter below.
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
	apply := func(enabled bool) { h.cfg.ForceModelPrefix = enabled }
	h.updateBoolField(c, apply)
}
// normalizeRoutingStrategy maps a user-supplied routing strategy name onto its
// canonical spelling. Matching ignores case and surrounding whitespace.
//
// The empty string and the aliases "roundrobin"/"rr" resolve to "round-robin";
// "fillfirst"/"ff" resolve to "fill-first". The boolean is false (and the
// string empty) for any other input.
func normalizeRoutingStrategy(strategy string) (string, bool) {
	const (
		roundRobin = "round-robin"
		fillFirst  = "fill-first"
	)
	name := strings.TrimSpace(strategy)
	name = strings.ToLower(name)

	if name == "" || name == roundRobin || name == "roundrobin" || name == "rr" {
		return roundRobin, true
	}
	if name == fillFirst || name == "fillfirst" || name == "ff" {
		return fillFirst, true
	}
	return "", false
}
// RoutingStrategy

// GetRoutingStrategy responds with the routing strategy under the "strategy"
// key. A recognised value is reported in its canonical form; an unrecognised
// one is echoed back with surrounding whitespace trimmed.
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
	reported, known := normalizeRoutingStrategy(h.cfg.Routing.Strategy)
	if !known {
		reported = strings.TrimSpace(h.cfg.Routing.Strategy)
	}
	c.JSON(200, gin.H{"strategy": reported})
}
// PutRoutingStrategy reads {"value": "<strategy>"} from the request body,
// stores the canonical strategy name in Routing.Strategy and hands off to
// h.persist. A body that fails to bind or lacks "value" yields 400
// "invalid body"; an unrecognised strategy yields 400 "invalid strategy".
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
	var payload struct {
		Value *string `json:"value"`
	}
	errBind := c.ShouldBindJSON(&payload)
	if errBind != nil || payload.Value == nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
		return
	}

	canonical, known := normalizeRoutingStrategy(*payload.Value)
	if !known {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
		return
	}

	h.cfg.Routing.Strategy = canonical
	h.persist(c)
}
// Proxy URL

// GetProxyURL responds with the current ProxyURL config value under the
// "proxy-url" key.
func (h *Handler) GetProxyURL(c *gin.Context) {
	payload := gin.H{"proxy-url": h.cfg.ProxyURL}
	c.JSON(200, payload)
}
func (h *Handler) PutProxyURL(c *gin.Context) {

View File

@@ -487,6 +487,137 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
c.JSON(400, gin.H{"error": "missing name or index"})
}
// vertex-api-key: []VertexCompatKey

// GetVertexCompatKeys responds with the configured VertexCompatAPIKey list
// under the "vertex-api-key" key.
func (h *Handler) GetVertexCompatKeys(c *gin.Context) {
	payload := gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey}
	c.JSON(200, payload)
}
// PutVertexCompatKeys replaces the whole VertexCompatAPIKey list.
//
// The body is either a bare JSON array of entries or an object of the form
// {"items": [...]}. The object form is only tried when the array decode fails,
// and it is rejected with 400 when it fails to decode or carries no items.
// Each entry is normalized before the list is stored, sanitized and passed
// to h.persist.
func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}

	var keys []config.VertexCompatKey
	if errArray := json.Unmarshal(raw, &keys); errArray != nil {
		var wrapped struct {
			Items []config.VertexCompatKey `json:"items"`
		}
		errObject := json.Unmarshal(raw, &wrapped)
		if errObject != nil || len(wrapped.Items) == 0 {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		keys = wrapped.Items
	}

	for idx := 0; idx < len(keys); idx++ {
		normalizeVertexCompatKey(&keys[idx])
	}
	h.cfg.VertexCompatAPIKey = keys
	h.cfg.SanitizeVertexCompatKeys()
	h.persist(c)
}
// PatchVertexCompatKey applies a partial update to one VertexCompatAPIKey entry.
//
// The target is chosen by "index" when it is in range, otherwise by "match"
// compared (after trimming) against each entry's APIKey; 404 is returned when
// neither selects an entry. Only fields present in "value" are changed.
// Patching "api-key" or "base-url" to a blank string removes the entry instead
// of updating it. After an update the entry is normalized, the list sanitized,
// and the request handed to h.persist.
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
	type vertexCompatPatch struct {
		APIKey   *string                     `json:"api-key"`
		Prefix   *string                     `json:"prefix"`
		BaseURL  *string                     `json:"base-url"`
		ProxyURL *string                     `json:"proxy-url"`
		Headers  *map[string]string          `json:"headers"`
		Models   *[]config.VertexCompatModel `json:"models"`
	}
	var req struct {
		Index *int               `json:"index"`
		Match *string            `json:"match"`
		Value *vertexCompatPatch `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil || req.Value == nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}

	// Resolve the target position: a valid index wins, match is the fallback.
	pos := -1
	if req.Index != nil {
		if idx := *req.Index; idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) {
			pos = idx
		}
	}
	if pos < 0 && req.Match != nil {
		if wanted := strings.TrimSpace(*req.Match); wanted != "" {
			for i := range h.cfg.VertexCompatAPIKey {
				if h.cfg.VertexCompatAPIKey[i].APIKey == wanted {
					pos = i
					break
				}
			}
		}
	}
	if pos < 0 {
		c.JSON(404, gin.H{"error": "item not found"})
		return
	}

	// dropTarget removes the selected entry and persists the shortened list.
	dropTarget := func() {
		h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:pos], h.cfg.VertexCompatAPIKey[pos+1:]...)
		h.cfg.SanitizeVertexCompatKeys()
		h.persist(c)
	}

	patch := req.Value
	updated := h.cfg.VertexCompatAPIKey[pos]
	if patch.APIKey != nil {
		key := strings.TrimSpace(*patch.APIKey)
		if key == "" {
			dropTarget()
			return
		}
		updated.APIKey = key
	}
	if patch.Prefix != nil {
		updated.Prefix = strings.TrimSpace(*patch.Prefix)
	}
	if patch.BaseURL != nil {
		baseURL := strings.TrimSpace(*patch.BaseURL)
		if baseURL == "" {
			dropTarget()
			return
		}
		updated.BaseURL = baseURL
	}
	if patch.ProxyURL != nil {
		updated.ProxyURL = strings.TrimSpace(*patch.ProxyURL)
	}
	if patch.Headers != nil {
		updated.Headers = config.NormalizeHeaders(*patch.Headers)
	}
	if patch.Models != nil {
		// Copy so the stored entry does not alias the decoded request slice.
		updated.Models = append([]config.VertexCompatModel(nil), (*patch.Models)...)
	}

	normalizeVertexCompatKey(&updated)
	h.cfg.VertexCompatAPIKey[pos] = updated
	h.cfg.SanitizeVertexCompatKeys()
	h.persist(c)
}
// DeleteVertexCompatKey removes VertexCompatAPIKey entries selected by query
// parameter.
//
// With a non-blank "api-key", every entry whose APIKey equals it is dropped
// (the list is persisted even when nothing matched). Otherwise "index" selects
// a single entry by position. When neither parameter yields a usable selector
// the response is 400.
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
	apiKey := strings.TrimSpace(c.Query("api-key"))
	if apiKey != "" {
		remaining := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
		for i := range h.cfg.VertexCompatAPIKey {
			if candidate := h.cfg.VertexCompatAPIKey[i]; candidate.APIKey != apiKey {
				remaining = append(remaining, candidate)
			}
		}
		h.cfg.VertexCompatAPIKey = remaining
		h.cfg.SanitizeVertexCompatKeys()
		h.persist(c)
		return
	}

	rawIndex := c.Query("index")
	if rawIndex == "" {
		c.JSON(400, gin.H{"error": "missing api-key or index"})
		return
	}

	// NOTE(review): Sscanf with %d accepts a numeric prefix, so trailing
	// garbage after the digits is ignored rather than rejected.
	var pos int
	_, errScan := fmt.Sscanf(rawIndex, "%d", &pos)
	if errScan != nil || pos < 0 || pos >= len(h.cfg.VertexCompatAPIKey) {
		c.JSON(400, gin.H{"error": "missing api-key or index"})
		return
	}

	h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:pos], h.cfg.VertexCompatAPIKey[pos+1:]...)
	h.cfg.SanitizeVertexCompatKeys()
	h.persist(c)
}
// oauth-excluded-models: map[string][]string
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
@@ -572,6 +703,103 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
h.persist(c)
}
// oauth-model-mappings: map[string][]ModelNameMapping

// GetOAuthModelMappings responds with a sanitized copy of the configured
// OAuthModelMappings under the "oauth-model-mappings" key.
func (h *Handler) GetOAuthModelMappings(c *gin.Context) {
	mappings := sanitizedOAuthModelMappings(h.cfg.OAuthModelMappings)
	c.JSON(200, gin.H{"oauth-model-mappings": mappings})
}
// PutOAuthModelMappings replaces the whole OAuthModelMappings table.
//
// The body is either a JSON object mapping channel names to mapping lists, or
// a wrapper of the form {"items": {...}}; the wrapper is only tried when the
// direct decode fails. The decoded table is sanitized before being stored and
// handed to h.persist.
func (h *Handler) PutOAuthModelMappings(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}

	var table map[string][]config.ModelNameMapping
	if errDirect := json.Unmarshal(raw, &table); errDirect != nil {
		var wrapped struct {
			Items map[string][]config.ModelNameMapping `json:"items"`
		}
		if errWrapped := json.Unmarshal(raw, &wrapped); errWrapped != nil {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		table = wrapped.Items
	}

	h.cfg.OAuthModelMappings = sanitizedOAuthModelMappings(table)
	h.persist(c)
}
// PatchOAuthModelMappings sets or removes the mappings of a single channel.
//
// The channel name comes from "channel", falling back to "provider"; it is
// trimmed and lower-cased, and a blank result yields 400. When the sanitized
// "mappings" list is non-empty it replaces the channel's entry. When it is
// empty the channel is deleted, with 404 if the channel does not exist; the
// table is reset to nil once its last channel is removed.
func (h *Handler) PatchOAuthModelMappings(c *gin.Context) {
	var req struct {
		Provider *string                   `json:"provider"`
		Channel  *string                   `json:"channel"`
		Mappings []config.ModelNameMapping `json:"mappings"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}

	var rawChannel string
	switch {
	case req.Channel != nil:
		rawChannel = *req.Channel
	case req.Provider != nil:
		rawChannel = *req.Provider
	}
	channel := strings.ToLower(strings.TrimSpace(rawChannel))
	if channel == "" {
		c.JSON(400, gin.H{"error": "invalid channel"})
		return
	}

	sanitized := sanitizedOAuthModelMappings(map[string][]config.ModelNameMapping{channel: req.Mappings})
	cleaned := sanitized[channel]

	if len(cleaned) > 0 {
		if h.cfg.OAuthModelMappings == nil {
			h.cfg.OAuthModelMappings = make(map[string][]config.ModelNameMapping)
		}
		h.cfg.OAuthModelMappings[channel] = cleaned
		h.persist(c)
		return
	}

	// Nothing left after sanitizing: treat the request as a delete.
	// A lookup on a nil table reports "not found", covering both 404 cases.
	if _, found := h.cfg.OAuthModelMappings[channel]; !found {
		c.JSON(404, gin.H{"error": "channel not found"})
		return
	}
	delete(h.cfg.OAuthModelMappings, channel)
	if len(h.cfg.OAuthModelMappings) == 0 {
		h.cfg.OAuthModelMappings = nil
	}
	h.persist(c)
}
// DeleteOAuthModelMappings removes one channel from OAuthModelMappings.
//
// The channel is read from the "channel" query parameter, falling back to
// "provider"; both are trimmed and lower-cased. A blank name yields 400 and an
// unknown channel yields 404. The table is reset to nil once its last channel
// is removed, and the request is then handed to h.persist.
func (h *Handler) DeleteOAuthModelMappings(c *gin.Context) {
	name := strings.ToLower(strings.TrimSpace(c.Query("channel")))
	if name == "" {
		name = strings.ToLower(strings.TrimSpace(c.Query("provider")))
	}
	if name == "" {
		c.JSON(400, gin.H{"error": "missing channel"})
		return
	}

	// A lookup on a nil table reports "not found", covering both 404 cases.
	if _, found := h.cfg.OAuthModelMappings[name]; !found {
		c.JSON(404, gin.H{"error": "channel not found"})
		return
	}

	delete(h.cfg.OAuthModelMappings, name)
	if len(h.cfg.OAuthModelMappings) == 0 {
		h.cfg.OAuthModelMappings = nil
	}
	h.persist(c)
}
// codex-api-key: []CodexKey
func (h *Handler) GetCodexKeys(c *gin.Context) {
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
@@ -597,11 +825,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
filtered := make([]config.CodexKey, 0, len(arr))
for i := range arr {
entry := arr[i]
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
normalizeCodexKey(&entry)
if entry.BaseURL == "" {
continue
}
@@ -613,12 +837,13 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
}
func (h *Handler) PatchCodexKey(c *gin.Context) {
type codexKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Models *[]config.CodexModel `json:"models"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct {
Index *int `json:"index"`
@@ -667,12 +892,16 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Models != nil {
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
normalizeCodexKey(&entry)
h.cfg.CodexKey[targetIndex] = entry
h.cfg.SanitizeCodexKeys()
h.persist(c)
@@ -762,6 +991,79 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
entry.Models = normalized
}
// normalizeCodexKey tidies a CodexKey in place: string fields are trimmed,
// headers and excluded models go through their config normalizers, and models
// whose trimmed Name and Alias are both empty are dropped.
// A nil entry is ignored; an entry without models keeps its Models untouched.
func normalizeCodexKey(entry *config.CodexKey) {
	if entry == nil {
		return
	}

	for _, field := range []*string{&entry.APIKey, &entry.Prefix, &entry.BaseURL, &entry.ProxyURL} {
		*field = strings.TrimSpace(*field)
	}
	entry.Headers = config.NormalizeHeaders(entry.Headers)
	entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)

	if len(entry.Models) == 0 {
		return
	}
	kept := make([]config.CodexModel, 0, len(entry.Models))
	for _, candidate := range entry.Models {
		candidate.Name = strings.TrimSpace(candidate.Name)
		candidate.Alias = strings.TrimSpace(candidate.Alias)
		// A model survives as long as at least one of the two is set.
		if candidate.Name != "" || candidate.Alias != "" {
			kept = append(kept, candidate)
		}
	}
	entry.Models = kept
}
// normalizeVertexCompatKey tidies a VertexCompatKey in place: string fields are
// trimmed, headers go through config.NormalizeHeaders, and models are kept
// only when both their trimmed Name and Alias are non-empty.
// A nil entry is ignored; an entry without models keeps its Models untouched.
func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
	if entry == nil {
		return
	}

	for _, field := range []*string{&entry.APIKey, &entry.Prefix, &entry.BaseURL, &entry.ProxyURL} {
		*field = strings.TrimSpace(*field)
	}
	entry.Headers = config.NormalizeHeaders(entry.Headers)

	if len(entry.Models) == 0 {
		return
	}
	kept := make([]config.VertexCompatModel, 0, len(entry.Models))
	for _, candidate := range entry.Models {
		candidate.Name = strings.TrimSpace(candidate.Name)
		candidate.Alias = strings.TrimSpace(candidate.Alias)
		// Both name and alias are required for a model to survive.
		if candidate.Name != "" && candidate.Alias != "" {
			kept = append(kept, candidate)
		}
	}
	entry.Models = kept
}
// sanitizedOAuthModelMappings returns a sanitized copy of entries without
// modifying the input map or its slices.
//
// Channels with no mappings are dropped, the remaining lists are copied, and
// the copy is run through Config.SanitizeOAuthModelMappings on a scratch
// config. The result is nil whenever nothing is left at any stage.
func sanitizedOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping {
	if len(entries) == 0 {
		return nil
	}

	clone := make(map[string][]config.ModelNameMapping, len(entries))
	for name, list := range entries {
		if len(list) > 0 {
			clone[name] = append([]config.ModelNameMapping(nil), list...)
		}
	}
	if len(clone) == 0 {
		return nil
	}

	// Reuse the config's own sanitizer by wrapping the clone in a throwaway Config.
	scratch := config.Config{OAuthModelMappings: clone}
	scratch.SanitizeOAuthModelMappings()
	if sanitized := scratch.OAuthModelMappings; len(sanitized) > 0 {
		return sanitized
	}
	return nil
}
// GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil {
@@ -913,3 +1215,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
// PutAmpForceModelMappings updates AmpCode.ForceModelMappings; request handling
// is delegated to h.updateBoolField, which receives the setter below.
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
	apply := func(enabled bool) { h.cfg.AmpCode.ForceModelMappings = enabled }
	h.updateBoolField(c, apply)
}
// GetAmpUpstreamAPIKeys responds with the ampcode upstream API key mappings
// under the "upstream-api-keys" key. When the handler or its config is nil an
// empty list is reported instead.
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
	entries := []config.AmpUpstreamAPIKeyEntry{}
	if h != nil && h.cfg != nil {
		entries = h.cfg.AmpCode.UpstreamAPIKeys
	}
	c.JSON(200, gin.H{"upstream-api-keys": entries})
}
// PutAmpUpstreamAPIKeys replaces every ampcode upstream API key mapping with
// the normalized contents of {"value": [...]} and hands off to h.persist.
// A body that fails to bind is answered with 400.
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
	var req struct {
		Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}

	// Whitespace is trimmed and entries without an upstream key are discarded.
	h.cfg.AmpCode.UpstreamAPIKeys = normalizeAmpUpstreamAPIKeyEntries(req.Value)
	h.persist(c)
}
// PatchAmpUpstreamAPIKeys merges the entries of {"value": [...]} into the
// existing upstream API key mappings and hands off to h.persist.
//
// Entries are matched on their trimmed upstream-api-key: a match replaces the
// stored entry, anything else is appended. Entries with a blank upstream key
// are skipped. A body that fails to bind is answered with 400.
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
	var req struct {
		Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}

	// Position of each stored entry, keyed by trimmed upstream key.
	indexByKey := make(map[string]int)
	for pos := range h.cfg.AmpCode.UpstreamAPIKeys {
		indexByKey[strings.TrimSpace(h.cfg.AmpCode.UpstreamAPIKeys[pos].UpstreamAPIKey)] = pos
	}

	for _, incoming := range req.Value {
		key := strings.TrimSpace(incoming.UpstreamAPIKey)
		if key == "" {
			continue
		}
		replacement := config.AmpUpstreamAPIKeyEntry{
			UpstreamAPIKey: key,
			APIKeys:        normalizeAPIKeysList(incoming.APIKeys),
		}
		if pos, known := indexByKey[key]; known {
			h.cfg.AmpCode.UpstreamAPIKeys[pos] = replacement
			continue
		}
		// Record the slot first so later duplicates in the same request update it.
		indexByKey[key] = len(h.cfg.AmpCode.UpstreamAPIKeys)
		h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, replacement)
	}

	h.persist(c)
}
// DeleteAmpUpstreamAPIKeys removes the upstream API key entries named in the
// JSON body {"value": ["<upstream-api-key>", ...]}.
//
//   - Invalid JSON, or a missing/null "value", yields 400 and changes nothing.
//   - An empty array clears every entry.
//   - A non-empty array made up only of blank strings yields 400.
//   - Otherwise every entry whose trimmed upstream key is listed is removed.
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
	var req struct {
		Value []string `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}

	switch {
	case req.Value == nil:
		c.JSON(400, gin.H{"error": "missing value"})
		return
	case len(req.Value) == 0:
		// An explicit empty array is the "clear all" request.
		h.cfg.AmpCode.UpstreamAPIKeys = nil
		h.persist(c)
		return
	}

	doomed := make(map[string]bool)
	for _, raw := range req.Value {
		if key := strings.TrimSpace(raw); key != "" {
			doomed[key] = true
		}
	}
	if len(doomed) == 0 {
		c.JSON(400, gin.H{"error": "empty value"})
		return
	}

	survivors := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
	for _, stored := range h.cfg.AmpCode.UpstreamAPIKeys {
		if doomed[strings.TrimSpace(stored.UpstreamAPIKey)] {
			continue
		}
		survivors = append(survivors, stored)
	}
	h.cfg.AmpCode.UpstreamAPIKeys = survivors
	h.persist(c)
}
// normalizeAmpUpstreamAPIKeyEntries returns a cleaned copy of entries: the
// upstream key is trimmed, the client key list goes through
// normalizeAPIKeysList, and entries with a blank upstream key are dropped.
// The result is nil when no entry survives.
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
	var cleaned []config.AmpUpstreamAPIKeyEntry
	for i := range entries {
		key := strings.TrimSpace(entries[i].UpstreamAPIKey)
		if key == "" {
			continue
		}
		if cleaned == nil {
			// Allocate only once something is kept, so "nothing kept" stays nil.
			cleaned = make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
		}
		cleaned = append(cleaned, config.AmpUpstreamAPIKeyEntry{
			UpstreamAPIKey: key,
			APIKeys:        normalizeAPIKeysList(entries[i].APIKeys),
		})
	}
	return cleaned
}
// normalizeAPIKeysList returns keys with surrounding whitespace removed and
// blank entries dropped, preserving order. The result is nil when the input
// is empty or contains only blank strings.
func normalizeAPIKeysList(keys []string) []string {
	var cleaned []string
	for i := range keys {
		key := strings.TrimSpace(keys[i])
		if key == "" {
			continue
		}
		if cleaned == nil {
			// Allocate only once something is kept, so "nothing kept" stays nil.
			cleaned = make([]string, 0, len(keys))
		}
		cleaned = append(cleaned, key)
	}
	return cleaned
}

View File

@@ -24,8 +24,15 @@ import (
// attemptInfo is the per-client-IP bookkeeping record stored in
// Handler.failedAttempts.
type attemptInfo struct {
	// count is incremented on each recorded failure and reset to zero when a ban is applied.
	count int
	// blockedUntil is the end of the current ban; the zero value means no ban has been set.
	blockedUntil time.Time
	// lastActivity is the time of the most recent recorded failure; cleanup uses it to find idle entries.
	lastActivity time.Time
}
// Tuning knobs for the background purge of stale failed-attempt entries.
const (
	// attemptCleanupInterval is the period between two purge passes.
	attemptCleanupInterval = time.Hour

	// attemptMaxIdleTime is how long an IP entry may go without activity
	// before a purge pass is allowed to remove it.
	attemptMaxIdleTime = 2 * time.Hour
)
// Handler aggregates config reference, persistence path and helpers.
type Handler struct {
cfg *config.Config
@@ -47,7 +54,7 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
envSecret = strings.TrimSpace(envSecret)
return &Handler{
h := &Handler{
cfg: cfg,
configFilePath: configFilePath,
failedAttempts: make(map[string]*attemptInfo),
@@ -57,6 +64,43 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
allowRemoteOverride: envSecret != "",
envSecret: envSecret,
}
h.startAttemptCleanup()
return h
}
// startAttemptCleanup spawns a goroutine that calls purgeStaleAttempts once per
// attemptCleanupInterval so that failedAttempts does not grow without bound.
// Nothing in this function stops the goroutine once it has been started.
func (h *Handler) startAttemptCleanup() {
	sweepForever := func() {
		ticker := time.NewTicker(attemptCleanupInterval)
		defer ticker.Stop()

		for range ticker.C {
			h.purgeStaleAttempts()
		}
	}
	go sweepForever()
}
// purgeStaleAttempts deletes every failedAttempts entry that is not currently
// banned and whose last activity is more than attemptMaxIdleTime in the past.
// The whole pass runs while holding attemptsMu.
func (h *Handler) purgeStaleAttempts() {
	now := time.Now()

	h.attemptsMu.Lock()
	defer h.attemptsMu.Unlock()

	for ip, info := range h.failedAttempts {
		// An entry with a ban that has not yet expired must survive the purge.
		banActive := !info.blockedUntil.IsZero() && now.Before(info.blockedUntil)
		idleTooLong := now.Sub(info.lastActivity) > attemptMaxIdleTime
		if !banActive && idleTooLong {
			delete(h.failedAttempts, ip)
		}
	}
}
// NewHandlerWithoutConfigFilePath creates a management handler exactly like
// NewHandler, passing an empty config file path.
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
	const noConfigFilePath = ""
	return NewHandler(cfg, noConfigFilePath, manager)
}
// SetConfig updates the in-memory config reference when the server hot-reloads.
@@ -144,6 +188,7 @@ func (h *Handler) Middleware() gin.HandlerFunc {
h.failedAttempts[clientIP] = aip
}
aip.count++
aip.lastActivity = time.Now()
if aip.count >= maxFailures {
aip.blockedUntil = time.Now().Add(banDuration)
aip.count = 0

View File

@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
}
}
// Check API key change
// Check API key change (both default and per-client mappings)
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
if apiKeyChanged {
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
if apiKeyChanged || upstreamAPIKeysChanged {
if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
if apiKeyChanged {
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
if upstreamAPIKeysChanged {
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
}
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
if m.secretSource == nil {
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
mappedSource := NewMappedSecretSource(defaultSource)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
ms.UpdateMappings(settings.UpstreamAPIKeys)
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
mappedSource := NewMappedSecretSource(ms)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
}
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
@@ -313,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
return oldKey != newKey
}
// hasUpstreamAPIKeysChanged reports whether the per-client upstream API key
// mappings differ between old and new.
//
// A nil config is treated as having no mappings. Entries are compared
// positionally, so reordering counts as a change. Within an entry the upstream
// key is compared after trimming, and the client keys are compared as sets of
// trimmed, non-empty strings: duplicates, blank keys and ordering inside a
// single entry's key list are ignored.
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
	// Guard both sides so a nil config never gets dereferenced.
	if old == nil {
		return new != nil && len(new.UpstreamAPIKeys) > 0
	}
	if new == nil {
		return len(old.UpstreamAPIKeys) > 0
	}
	if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
		return true
	}

	// keySet collects the trimmed, non-empty keys of one entry.
	keySet := func(keys []string) map[string]struct{} {
		set := make(map[string]struct{}, len(keys))
		for _, k := range keys {
			if trimmed := strings.TrimSpace(k); trimmed != "" {
				set[trimmed] = struct{}{}
			}
		}
		return set
	}

	// Lengths are equal here, so index i is valid for both slices.
	for i := range new.UpstreamAPIKeys {
		oldEntry, newEntry := old.UpstreamAPIKeys[i], new.UpstreamAPIKeys[i]
		if strings.TrimSpace(newEntry.UpstreamAPIKey) != strings.TrimSpace(oldEntry.UpstreamAPIKey) {
			return true
		}

		oldKeys, newKeys := keySet(oldEntry.APIKeys), keySet(newEntry.APIKeys)
		if len(newKeys) != len(oldKeys) {
			return true
		}
		// Equal sizes plus "every new key is in old" implies set equality.
		for k := range newKeys {
			if _, ok := oldKeys[k]; !ok {
				return true
			}
		}
	}
	return false
}
// GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper

View File

@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
})
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
},
}
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
},
}
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected no change when only whitespace/empty entries differ")
}
}

View File

@@ -15,6 +15,33 @@ import (
log "github.com/sirupsen/logrus"
)
// removeQueryValuesMatching strips from req's URL every value of the query
// parameter key that is exactly equal to match. Other values of the same key
// and all other parameters are kept; the key is removed entirely once none of
// its values remain.
//
// The call is a no-op when req or req.URL is nil, when match is empty, or when
// no value of key equals match. In the no-op case RawQuery is left untouched,
// so a request that carries no matching credential reaches the upstream with
// its original parameter order and encoding. RawQuery is only re-encoded
// (which sorts parameters by key) when something was actually removed.
func removeQueryValuesMatching(req *http.Request, key string, match string) {
	if req == nil || req.URL == nil || match == "" {
		return
	}

	q := req.URL.Query()
	values := q[key]
	kept := make([]string, 0, len(values))
	for _, v := range values {
		if v != match {
			kept = append(kept, v)
		}
	}
	if len(kept) == len(values) {
		// Nothing matched (or the key is absent): do not rewrite the query.
		return
	}

	if len(kept) == 0 {
		q.Del(key)
	} else {
		q[key] = kept
	}
	req.URL.RawQuery = q.Encode()
}
// readCloser wraps a reader and forwards Close to a separate closer.
// Used to restore peeked bytes while preserving upstream body Close behavior.
type readCloser struct {
@@ -45,6 +72,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// We will set our own Authorization using the configured upstream-api-key
req.Header.Del("Authorization")
req.Header.Del("X-Api-Key")
req.Header.Del("X-Goog-Api-Key")
// Remove query-based credentials if they match the authenticated client API key.
// This prevents leaking client auth material to the Amp upstream while avoiding
// breaking unrelated upstream query parameters.
clientKey := getClientAPIKeyFromContext(req.Context())
removeQueryValuesMatching(req, "key", clientKey)
removeQueryValuesMatching(req, "auth_token", clientKey)
// Preserve correlation headers for debugging
if req.Header.Get("X-Request-ID") == "" {

View File

@@ -3,11 +3,15 @@ package amp
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Helper: compress data with gzip
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
}
}
// TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery sends a request
// carrying the client key in headers and query parameters through the proxy
// and checks what the upstream actually receives: client credentials are gone,
// upstream credentials are injected, unrelated query parameters survive.
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
	type upstreamView struct {
		headers http.Header
		query   string
	}

	seen := make(chan upstreamView, 1)
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen <- upstreamView{headers: r.Header.Clone(), query: r.URL.RawQuery}
		w.WriteHeader(200)
		w.Write([]byte(`ok`))
	}))
	defer backend.Close()

	proxy, errProxy := createReverseProxy(backend.URL, NewStaticSecretSource("upstream"))
	if errProxy != nil {
		t.Fatal(errProxy)
	}

	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Stand in for clientAPIKeyMiddleware: attach the client key per request.
		withKey := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
		proxy.ServeHTTP(w, r.WithContext(withKey))
	}))
	defer frontend.Close()

	target := frontend.URL + "/test?key=client-key&key=keep&auth_token=client-key&foo=bar"
	outgoing, errReq := http.NewRequest(http.MethodGet, target, nil)
	if errReq != nil {
		t.Fatal(errReq)
	}
	for _, hdr := range []struct{ name, value string }{
		{"Authorization", "Bearer client-key"},
		{"X-Api-Key", "client-key"},
		{"X-Goog-Api-Key", "client-key"},
	} {
		outgoing.Header.Set(hdr.name, hdr.value)
	}

	resp, errDo := http.DefaultClient.Do(outgoing)
	if errDo != nil {
		t.Fatal(errDo)
	}
	resp.Body.Close()

	view := <-seen

	// Client-provided credentials must never reach the upstream.
	if got := view.headers.Get("X-Goog-Api-Key"); got != "" {
		t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", got)
	}

	// Authorization and X-Api-Key are replaced by the upstream secret.
	if got := view.headers.Get("Authorization"); got != "Bearer upstream" {
		t.Fatalf("Authorization should be upstream-injected, got: %q", got)
	}
	if got := view.headers.Get("X-Api-Key"); got != "upstream" {
		t.Fatalf("X-Api-Key should be upstream-injected, got: %q", got)
	}

	// Only query values equal to the authenticated client key are removed.
	leaked := strings.Contains(view.query, "auth_token=client-key") || strings.Contains(view.query, "key=client-key")
	if leaked {
		t.Fatalf("query credentials should be stripped, got raw query: %q", view.query)
	}
	preserved := strings.Contains(view.query, "key=keep") && strings.Contains(view.query, "foo=bar")
	if !preserved {
		t.Fatalf("expected query to keep non-credential params, got raw query: %q", view.query)
	}
}
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "u1" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer u1" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "default" {
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer default" {
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_ErrorHandler(t *testing.T) {
// Point proxy to a non-routable address to trigger error
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))

View File

@@ -69,7 +69,30 @@ func (rw *ResponseRewriter) Flush() {
// modelFieldPaths lists dotted JSON paths that may hold a model name.
// NOTE(review): presumably these are the paths rewriteModelInResponse rewrites — confirm against its body.
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
// The Amp client struggles when both thinking and tool_use blocks are present
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
if filtered.Exists() {
originalCount := gjson.GetBytes(data, "content.#").Int()
filteredCount := filtered.Get("#").Int()
if originalCount > filteredCount {
var err error
data, err = sjson.SetBytes(data, "content", filtered.Value())
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
} else {
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
// Log the result for verification
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
}
}
}
}
if rw.originalModel == "" {
return data
}

View File

@@ -1,6 +1,7 @@
package amp
import (
"context"
"errors"
"net"
"net/http"
@@ -16,6 +17,37 @@ import (
log "github.com/sirupsen/logrus"
)
// clientAPIKeyContextKey is the context key under which the authenticated
// client API key is stored in the request context, so that code holding only
// a context.Context (such as a SecretSource lookup) can read it.
// It is an unexported empty struct type, so no other package can construct an equal key.
type clientAPIKeyContextKey struct{}
// clientAPIKeyMiddleware copies the value stored under gin.Context["apiKey"]
// (set by AuthMiddleware, per the surrounding code) into the request context,
// keyed by clientAPIKeyContextKey, so SecretSource.Get(ctx) can select a
// per-client upstream key. Missing, non-string, or empty values leave the
// request untouched. The next handler always runs.
func clientAPIKeyMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// A missing entry yields a nil value, which fails the string assertion below.
		stored, _ := c.Get("apiKey")
		if key, isString := stored.(string); isString && key != "" {
			withKey := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, key)
			c.Request = c.Request.WithContext(withKey)
		}
		c.Next()
	}
}
// getClientAPIKeyFromContext returns the client API key stored in ctx under
// clientAPIKeyContextKey. It returns "" when no value is stored or when the
// stored value is not a string.
func getClientAPIKeyFromContext(ctx context.Context) string {
	// The comma-ok assertion yields the zero string for nil or non-string values.
	key, _ := ctx.Value(clientAPIKeyContextKey{}).(string)
	return key
}
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
}
// Inject client API key into request context for per-client upstream routing
ampAPI.Use(clientAPIKeyMiddleware())
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
if authWithBypass != nil {
rootMiddleware = append(rootMiddleware, authWithBypass)
}
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
if auth != nil {
ampProviders.Use(auth)
}
// Inject client API key into request context for per-client upstream routing
ampProviders.Use(clientAPIKeyMiddleware())
provider := ampProviders.Group("/:provider")

View File

@@ -9,6 +9,9 @@ import (
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// SecretSource provides Amp API keys with configurable precedence and caching
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
return s.key, nil
}
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
// When a request context contains a client API key that matches a configured mapping,
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
type MappedSecretSource struct {
// defaultSource is consulted by Get when no per-client mapping applies.
defaultSource SecretSource
// mu guards lookup: Get reads it under RLock, UpdateMappings swaps it under Lock.
mu sync.RWMutex
lookup map[string]string // clientKey -> upstreamKey
}
// NewMappedSecretSource returns a MappedSecretSource that falls back to
// defaultSource and starts with an empty client-to-upstream mapping.
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
	source := new(MappedSecretSource)
	source.defaultSource = defaultSource
	source.lookup = map[string]string{}
	return source
}
// Get returns the upstream Amp API key for the request described by ctx.
// When ctx carries a client API key that has a non-empty mapping, the mapped
// upstream key is returned with a nil error. In every other case the result of
// the default source's Get is returned unchanged.
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
	if clientKey := getClientAPIKeyFromContext(ctx); clientKey != "" {
		// A missing map entry reads as "", which is treated like "no mapping".
		s.mu.RLock()
		mapped := s.lookup[clientKey]
		s.mu.RUnlock()
		if mapped != "" {
			return mapped, nil
		}
	}
	return s.defaultSource.Get(ctx)
}
// UpdateMappings replaces the client-to-upstream key mapping with one built
// from entries. Keys are whitespace-trimmed. Entries with an empty upstream
// key and empty client keys are skipped. When a client key occurs more than
// once, the first mapping is kept and a warning is logged for each repeat.
// The new map is built outside the lock and swapped in under the write lock.
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
	rebuilt := map[string]string{}
	for i := range entries {
		upstream := strings.TrimSpace(entries[i].UpstreamAPIKey)
		if upstream == "" {
			continue
		}
		for _, rawClient := range entries[i].APIKeys {
			client := strings.TrimSpace(rawClient)
			if client == "" {
				continue
			}
			if _, taken := rebuilt[client]; !taken {
				rebuilt[client] = upstream
				continue
			}
			// Duplicate client key: keep the earlier mapping.
			log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
		}
	}

	s.mu.Lock()
	s.lookup = rebuilt
	s.mu.Unlock()
}
// UpdateDefaultExplicitKey forwards key to UpdateExplicitKey on the default
// source when that source is a *MultiSourceSecret; otherwise it does nothing.
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
	multi, isMulti := s.defaultSource.(*MultiSourceSecret)
	if !isMulti {
		return
	}
	multi.UpdateExplicitKey(key)
}
// InvalidateCache calls InvalidateCache on the default source when that source
// is a *MultiSourceSecret; otherwise it does nothing.
func (s *MappedSecretSource) InvalidateCache() {
	multi, isMulti := s.defaultSource.(*MultiSourceSecret)
	if !isMulti {
		return
	}
	multi.InvalidateCache()
}

View File

@@ -8,6 +8,10 @@ import (
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
)
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
}
}
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1, got %q", got)
}
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
got, err = s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "default" {
t.Fatalf("want default fallback, got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1 (first wins), got %q", got)
}
}
// TestMappedSecretSource_DuplicateClientKey_LogsWarning checks that
// UpdateMappings emits a warn-level log entry with the expected message when a
// client key is listed in more than one entry.
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
	hook := test.NewLocal(log.StandardLogger())
	defer hook.Reset()

	source := NewMappedSecretSource(NewStaticSecretSource("default"))
	source.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
		{UpstreamAPIKey: "u1", APIKeys: []string{"k1"}},
		{UpstreamAPIKey: "u2", APIKeys: []string{"k1"}},
	})

	for _, logged := range hook.AllEntries() {
		if logged.Level != log.WarnLevel {
			continue
		}
		if logged.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
			// Warning found; nothing more to verify.
			return
		}
	}
	t.Fatal("expected warning log for duplicate client key, but none was found")
}

View File

@@ -33,6 +33,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
@@ -491,6 +492,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
@@ -500,6 +505,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
mgmt.POST("/api-call", s.mgmt.APICall)
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
@@ -549,6 +556,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
@@ -557,6 +568,14 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
@@ -572,11 +591,21 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
mgmt.GET("/oauth-model-mappings", s.mgmt.GetOAuthModelMappings)
mgmt.PUT("/oauth-model-mappings", s.mgmt.PutOAuthModelMappings)
mgmt.PATCH("/oauth-model-mappings", s.mgmt.PatchOAuthModelMappings)
mgmt.DELETE("/oauth-model-mappings", s.mgmt.DeleteOAuthModelMappings)
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
@@ -850,7 +879,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
}
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
if err := logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to reconfigure log output: %v", err)
} else {
if oldCfg == nil {
@@ -960,8 +989,12 @@ func (s *Server) UpdateClients(cfg *config.Config) {
log.Warnf("amp module is nil, skipping config update")
}
// Count client sources from configuration and auth directory
authFiles := util.CountAuthFiles(cfg.AuthDir)
// Count client sources from configuration and auth store.
tokenStore := sdkAuth.GetTokenStore()
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
dirSetter.SetBaseDir(cfg.AuthDir)
}
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
geminiAPIKeyCount := len(cfg.GeminiKey)
claudeAPIKeyCount := len(cfg.ClaudeKey)
codexAPIKeyCount := len(cfg.CodexKey)
@@ -972,10 +1005,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
openAICompatCount += len(entry.APIKeyEntries)
}
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total,
authFiles,
authEntries,
geminiAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,

View File

@@ -26,11 +26,17 @@ const (
// MinValidSignatureLen is the minimum length for a signature to be considered valid
MinValidSignatureLen = 50
// SessionCleanupInterval controls how often stale sessions are purged
SessionCleanupInterval = 10 * time.Minute
)
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
var signatureCache sync.Map
// sessionCleanupOnce ensures the background cleanup goroutine starts only once
var sessionCleanupOnce sync.Once
// sessionCache is the inner map type
type sessionCache struct {
mu sync.RWMutex
@@ -45,6 +51,9 @@ func hashText(text string) string {
// getOrCreateSession gets or creates a session cache
func getOrCreateSession(sessionID string) *sessionCache {
// Start background cleanup on first access
sessionCleanupOnce.Do(startSessionCleanup)
if val, ok := signatureCache.Load(sessionID); ok {
return val.(*sessionCache)
}
@@ -53,6 +62,40 @@ func getOrCreateSession(sessionID string) *sessionCache {
return actual.(*sessionCache)
}
// startSessionCleanup starts a goroutine that calls purgeExpiredSessions once
// per SessionCleanupInterval tick. The loop has no exit condition, so the
// goroutine runs until the process ends.
func startSessionCleanup() {
	go func() {
		cleanupTicker := time.NewTicker(SessionCleanupInterval)
		defer cleanupTicker.Stop()
		for {
			<-cleanupTicker.C
			purgeExpiredSessions()
		}
	}()
}
// purgeExpiredSessions walks signatureCache, deletes every entry older than
// SignatureCacheTTL from each session, and then drops sessions that are left
// with no entries. A single timestamp taken at the start is used for all
// age comparisons.
//
// NOTE(review): the emptiness check happens under the session lock but the
// Delete happens after it is released — confirm writers cannot add an entry to
// the session in between.
func purgeExpiredSessions() {
	startedAt := time.Now()
	signatureCache.Range(func(sessionID, cached any) bool {
		session := cached.(*sessionCache)

		session.mu.Lock()
		for entryKey, entry := range session.entries {
			if startedAt.Sub(entry.Timestamp) > SignatureCacheTTL {
				delete(session.entries, entryKey)
			}
		}
		remaining := len(session.entries)
		session.mu.Unlock()

		if remaining == 0 {
			signatureCache.Delete(sessionID)
		}
		// Always continue so every session is visited.
		return true
	})
}
// CacheSignature stores a thinking signature for a given session and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(sessionID, text, signature string) {

View File

@@ -91,6 +91,14 @@ type Config struct {
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
// These mappings affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
@@ -137,6 +145,16 @@ type RoutingConfig struct {
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// ModelNameMapping defines a model ID mapping for a specific channel.
// It maps the upstream model name (Name) to the client-visible alias (Alias).
// When Fork is true, the alias is added as an additional model in listings while
// keeping the original model ID available.
type ModelNameMapping struct {
// Name is the upstream model name.
Name string `yaml:"name" json:"name"`
// Alias is the client-visible name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
// Fork, when true, lists the alias in addition to the original model ID.
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available.
@@ -163,6 +181,11 @@ type AmpCode struct {
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
@@ -178,6 +201,17 @@ type AmpCode struct {
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
}
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
type AmpUpstreamAPIKeyEntry struct {
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads.
type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload.
@@ -237,6 +271,9 @@ type ClaudeModel struct {
Alias string `yaml:"alias" json:"alias"`
}
// GetName returns the model's Name field.
func (m ClaudeModel) GetName() string { return m.Name }
// GetAlias returns the model's Alias field.
func (m ClaudeModel) GetAlias() string { return m.Alias }
// CodexKey represents the configuration for a Codex API key,
// including the API key itself and an optional base URL for the API endpoint.
type CodexKey struct {
@@ -253,6 +290,9 @@ type CodexKey struct {
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// Models defines upstream model names and aliases for request routing.
Models []CodexModel `yaml:"models" json:"models"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -260,6 +300,18 @@ type CodexKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// CodexModel describes a mapping between an alias and the actual upstream model name.
type CodexModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
// GetName returns the upstream model identifier (Name).
func (m CodexModel) GetName() string { return m.Name }
// GetAlias returns the client-facing model name (Alias).
func (m CodexModel) GetAlias() string { return m.Alias }
// GeminiKey represents the configuration for a Gemini API key,
// including optional overrides for upstream base URL, proxy routing, and headers.
type GeminiKey struct {
@@ -275,6 +327,9 @@ type GeminiKey struct {
// ProxyURL optionally overrides the global proxy for this API key.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
// Models defines upstream model names and aliases for request routing.
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -282,6 +337,18 @@ type GeminiKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// GeminiModel describes a mapping between an alias and the actual upstream model name.
type GeminiModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
// GetName returns the upstream model identifier (Name).
func (m GeminiModel) GetName() string { return m.Name }
// GetAlias returns the client-facing model name (Alias).
func (m GeminiModel) GetAlias() string { return m.Alias }
// OpenAICompatibility represents the configuration for OpenAI API compatibility
// with external providers, allowing model aliases to be routed through OpenAI API format.
type OpenAICompatibility struct {
@@ -433,6 +500,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Normalize OAuth provider model exclusion map.
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
// Normalize global OAuth model name mappings.
cfg.SanitizeOAuthModelMappings()
if cfg.legacyMigrationPending {
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
if !optional && configFile != "" {
@@ -449,6 +519,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
return &cfg, nil
}
// SanitizeOAuthModelMappings normalizes cfg.OAuthModelMappings in place.
//
// Channel keys are trimmed and lower-cased; channels with an empty key or no
// mappings are dropped. Within a channel, Name and Alias are trimmed and a
// mapping is discarded when either is empty, when Name and Alias are equal
// ignoring case, or when its Name or its Alias (compared case-insensitively)
// was already used by an earlier kept mapping of that channel. The Fork flag
// and the relative order of kept mappings are preserved. Channels left without
// any mapping are omitted from the result.
//
// A nil receiver or an empty map is left untouched. If two raw keys normalize
// to the same channel, the one visited later by map iteration replaces the
// earlier one.
func (cfg *Config) SanitizeOAuthModelMappings() {
	if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
		return
	}

	sanitized := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
	for rawChannel, mappings := range cfg.OAuthModelMappings {
		channel := strings.ToLower(strings.TrimSpace(rawChannel))
		if channel == "" || len(mappings) == 0 {
			continue
		}

		usedNames := make(map[string]struct{}, len(mappings))
		usedAliases := make(map[string]struct{}, len(mappings))
		kept := make([]ModelNameMapping, 0, len(mappings))
		for i := range mappings {
			name := strings.TrimSpace(mappings[i].Name)
			alias := strings.TrimSpace(mappings[i].Alias)
			switch {
			case name == "", alias == "":
				continue
			case strings.EqualFold(name, alias):
				// Identity mappings carry no information.
				continue
			}

			nameKey, aliasKey := strings.ToLower(name), strings.ToLower(alias)
			_, nameTaken := usedNames[nameKey]
			_, aliasTaken := usedAliases[aliasKey]
			if nameTaken || aliasTaken {
				continue
			}
			usedNames[nameKey] = struct{}{}
			usedAliases[aliasKey] = struct{}{}
			kept = append(kept, ModelNameMapping{Name: name, Alias: alias, Fork: mappings[i].Fork})
		}

		if len(kept) == 0 {
			continue
		}
		sanitized[channel] = kept
	}
	cfg.OAuthModelMappings = sanitized
}
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
// not actionable, specifically those missing a BaseURL. It trims whitespace before
// evaluation and preserves the relative order of remaining entries.

View File

@@ -0,0 +1,27 @@
package config
import "testing"
func TestSanitizeOAuthModelMappings_PreservesForkFlag(t *testing.T) {
cfg := &Config{
OAuthModelMappings: map[string][]ModelNameMapping{
" CoDeX ": {
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
{Name: "gpt-6", Alias: "g6"},
},
},
}
cfg.SanitizeOAuthModelMappings()
mappings := cfg.OAuthModelMappings["codex"]
if len(mappings) != 2 {
t.Fatalf("expected 2 sanitized mappings, got %d", len(mappings))
}
if mappings[0].Name != "gpt-5" || mappings[0].Alias != "g5" || !mappings[0].Fork {
t.Fatalf("expected first mapping to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", mappings[0].Name, mappings[0].Alias, mappings[0].Fork)
}
if mappings[1].Name != "gpt-6" || mappings[1].Alias != "g6" || mappings[1].Fork {
t.Fatalf("expected second mapping to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", mappings[1].Name, mappings[1].Alias, mappings[1].Fork)
}
}

View File

@@ -42,6 +42,9 @@ type VertexCompatModel struct {
Alias string `yaml:"alias" json:"alias"`
}
// GetName returns the model's Name field.
func (m VertexCompatModel) GetName() string { return m.Name }
// GetAlias returns the model's Alias field.
func (m VertexCompatModel) GetAlias() string { return m.Alias }
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
func (cfg *Config) SanitizeVertexCompatKeys() {
if cfg == nil {

View File

@@ -10,6 +10,7 @@ import (
"sync"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
@@ -83,10 +84,30 @@ func SetupBaseLogger() {
})
}
// isDirWritable reports whether dir exists, is a directory, and accepts new
// files.
//
// It probes by creating a uniquely named temporary file inside dir and
// removing it again. Using a unique name (rather than a fixed one) means an
// existing file in dir is never truncated or deleted by the probe, and
// concurrent probes of the same directory cannot interfere with each other.
func isDirWritable(dir string) bool {
	info, err := os.Stat(dir)
	if err != nil || !info.IsDir() {
		return false
	}

	probe, err := os.CreateTemp(dir, ".perm_test_*")
	if err != nil {
		return false
	}
	// Cleanup is best-effort: the directory has already proven writable, so a
	// failure to close or remove the probe does not change the answer.
	_ = probe.Close()
	_ = os.Remove(probe.Name())
	return true
}
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
// until the total size is within the limit.
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
func ConfigureLogOutput(cfg *config.Config) error {
SetupBaseLogger()
writerMu.Lock()
@@ -95,10 +116,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
logDir := "logs"
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
} else if !isDirWritable(logDir) {
logDir = filepath.Join(cfg.AuthDir, "logs")
}
protectedPath := ""
if loggingToFile {
if cfg.LoggingToFile {
if err := os.MkdirAll(logDir, 0o755); err != nil {
return fmt.Errorf("logging: failed to create log directory: %w", err)
}
@@ -122,7 +145,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
log.SetOutput(os.Stdout)
}
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
return nil
}

View File

@@ -24,10 +24,11 @@ import (
)
const (
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
)
// ManagementFileName exposes the control panel asset filename.
@@ -198,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
return
}
localPath := filepath.Join(staticDir, managementAssetName)
localFileMissing := false
if _, errStat := os.Stat(localPath); errStat != nil {
if errors.Is(errStat, os.ErrNotExist) {
localFileMissing = true
} else {
log.WithError(errStat).Debug("failed to stat local management asset")
}
}
// Rate limiting: check only once every 3 hours
lastUpdateCheckMu.Lock()
now := time.Now()
@@ -210,15 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
lastUpdateCheckTime = now
lastUpdateCheckMu.Unlock()
if err := os.MkdirAll(staticDir, 0o755); err != nil {
log.WithError(err).Warn("failed to prepare static directory for management asset")
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
return
}
releaseURL := resolveReleaseURL(panelRepository)
client := newHTTPClient(proxyURL)
localPath := filepath.Join(staticDir, managementAssetName)
localHash, err := fileSHA256(localPath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
@@ -229,6 +239,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to fetch latest management release information")
return
}
@@ -240,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to download management asset, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to download management asset")
return
}
@@ -256,6 +280,22 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
}
// ensureFallbackManagementHTML downloads defaultManagementFallbackURL using
// client and writes the payload to localPath through atomicWriteFile. It
// returns true when the file was written. Download and write failures are
// logged as warnings and reported as false.
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
	page, pageHash, errDownload := downloadAsset(ctx, client, defaultManagementFallbackURL)
	if errDownload != nil {
		log.WithError(errDownload).Warn("failed to download fallback management control panel page")
		return false
	}

	if errWrite := atomicWriteFile(localPath, page); errWrite != nil {
		log.WithError(errWrite).Warn("failed to persist fallback management control panel page")
		return false
	}

	log.Infof("management asset updated from fallback page successfully (hash=%s)", pageHash)
	return true
}
func resolveReleaseURL(repo string) string {
repo = strings.TrimSpace(repo)
if repo == "" {

View File

@@ -740,8 +740,8 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {
@@ -773,7 +773,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
return map[string]*AntigravityModelConfig{
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
"gemini-2.5-computer-use-preview-10-2025": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-2.5-computer-use-preview-10-2025"},
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
@@ -781,3 +781,29 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
}
}
// LookupStaticModelInfo scans every static model catalog, in a fixed order,
// and returns the first entry whose ID equals modelID.
// It returns nil for an empty modelID or when no catalog contains a match.
func LookupStaticModelInfo(modelID string) *ModelInfo {
	if len(modelID) == 0 {
		return nil
	}

	catalogs := [][]*ModelInfo{
		GetClaudeModels(),
		GetGeminiModels(),
		GetGeminiVertexModels(),
		GetGeminiCLIModels(),
		GetAIStudioModels(),
		GetOpenAIModels(),
		GetQwenModels(),
		GetIFlowModels(),
	}
	for _, catalog := range catalogs {
		for idx := range catalog {
			candidate := catalog[idx]
			if candidate == nil || candidate.ID != modelID {
				continue
			}
			return candidate
		}
	}
	return nil
}

View File

@@ -4,6 +4,7 @@
package registry
import (
"context"
"fmt"
"sort"
"strings"
@@ -84,6 +85,13 @@ type ModelRegistration struct {
SuspendedClients map[string]string
}
// ModelRegistryHook provides optional callbacks for external integrations to track model list changes.
// Hook implementations must be non-blocking and resilient; calls are executed asynchronously and panics are recovered.
type ModelRegistryHook interface {
	// OnModelsRegistered is invoked after a client's model list has been
	// registered. The registry calls it on its own goroutine with a context
	// bounded by defaultModelRegistryHookTimeout, and passes a de-duplicated
	// copy of the models rather than the caller's slice.
	OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo)
	// OnModelsUnregistered is invoked after a client has been unregistered.
	// The registry calls it on its own goroutine with a context bounded by
	// defaultModelRegistryHookTimeout.
	OnModelsUnregistered(ctx context.Context, provider, clientID string)
}
// ModelRegistry manages the global registry of available models
type ModelRegistry struct {
// models maps model ID to registration information
@@ -97,6 +105,8 @@ type ModelRegistry struct {
clientProviders map[string]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
// hook is an optional callback sink for model registration changes
hook ModelRegistryHook
}
// Global model registry instance
@@ -117,6 +127,53 @@ func GetGlobalRegistry() *ModelRegistry {
return globalRegistry
}
// SetHook installs hook as the observer for model registration changes,
// replacing any hook set earlier. Calling it on a nil registry is a no-op.
func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
	if r != nil {
		r.mutex.Lock()
		r.hook = hook
		r.mutex.Unlock()
	}
}
const defaultModelRegistryHookTimeout = 5 * time.Second
// triggerModelsRegistered notifies the configured hook, if any, that a
// client's models were registered.
//
// The hook runs on its own goroutine with a context limited to
// defaultModelRegistryHookTimeout, receives a de-duplicated copy of models,
// and any panic it raises is recovered and logged.
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
	sink := r.hook
	if sink == nil {
		return
	}
	// Copy before spawning so the hook never sees the caller's slice.
	snapshot := cloneModelInfosUnique(models)

	notify := func() {
		defer func() {
			if cause := recover(); cause != nil {
				log.Errorf("model registry hook OnModelsRegistered panic: %v", cause)
			}
		}()
		hookCtx, release := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
		defer release()
		sink.OnModelsRegistered(hookCtx, provider, clientID, snapshot)
	}
	go notify()
}
// triggerModelsUnregistered notifies the configured hook, if any, that a
// client was unregistered.
//
// The hook runs on its own goroutine with a context limited to
// defaultModelRegistryHookTimeout, and any panic it raises is recovered and
// logged.
func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
	sink := r.hook
	if sink == nil {
		return
	}

	notify := func() {
		defer func() {
			if cause := recover(); cause != nil {
				log.Errorf("model registry hook OnModelsUnregistered panic: %v", cause)
			}
		}()
		hookCtx, release := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
		defer release()
		sink.OnModelsUnregistered(hookCtx, provider, clientID)
	}
	go notify()
}
// RegisterClient registers a client and its supported models
// Parameters:
// - clientID: Unique identifier for the client
@@ -177,6 +234,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
} else {
delete(r.clientProviders, clientID)
}
r.triggerModelsRegistered(provider, clientID, models)
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
misc.LogCredentialSeparator()
return
@@ -310,6 +368,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientProviders, clientID)
}
r.triggerModelsRegistered(provider, clientID, models)
if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output.
return
@@ -400,6 +459,25 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
return &copyModel
}
// cloneModelInfosUnique returns copies of the given models, keeping only the
// first occurrence of each ID and dropping nil entries and entries with an
// empty ID. It returns nil when models is empty.
func cloneModelInfosUnique(models []*ModelInfo) []*ModelInfo {
	total := len(models)
	if total == 0 {
		return nil
	}

	visited := make(map[string]struct{}, total)
	unique := make([]*ModelInfo, 0, total)
	for idx := range models {
		candidate := models[idx]
		if candidate == nil || candidate.ID == "" {
			continue
		}
		if _, duplicate := visited[candidate.ID]; !duplicate {
			visited[candidate.ID] = struct{}{}
			unique = append(unique, cloneModelInfo(candidate))
		}
	}
	return unique
}
// UnregisterClient removes a client and decrements counts for its models
// Parameters:
// - clientID: Unique identifier for the client to remove
@@ -460,6 +538,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
log.Debugf("Unregistered client %s", clientID)
// Separator line after completing client unregistration (after the summary line)
misc.LogCredentialSeparator()
r.triggerModelsUnregistered(provider, clientID)
}
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
@@ -625,6 +704,131 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
return models
}
// GetAvailableModelsByProvider returns models available for the given provider identifier.
// The provider is trimmed and lower-cased before being compared against the
// registry's per-client provider entries. The registry read lock is held for
// the whole call, and the returned slice follows map iteration order, so its
// ordering is not stable between calls.
//
// Parameters:
//   - provider: Provider identifier (e.g., "codex", "gemini", "antigravity")
//
// Returns:
//   - []*ModelInfo: List of available models for the provider; nil when the
//     provider is empty or no client of that provider has any models
func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo {
	provider = strings.ToLower(strings.TrimSpace(provider))
	if provider == "" {
		return nil
	}
	r.mutex.RLock()
	defer r.mutex.RUnlock()
	// providerModel accumulates, per model ID, how many clients of this
	// provider list the model and the first ModelInfo found for it.
	type providerModel struct {
		count int
		info  *ModelInfo
	}
	providerModels := make(map[string]*providerModel)
	// Pass 1: collect every model ID offered by clients of this provider.
	for clientID, clientProvider := range r.clientProviders {
		if clientProvider != provider {
			continue
		}
		modelIDs := r.clientModels[clientID]
		if len(modelIDs) == 0 {
			continue
		}
		clientInfos := r.clientModelInfos[clientID]
		for _, modelID := range modelIDs {
			modelID = strings.TrimSpace(modelID)
			if modelID == "" {
				continue
			}
			entry := providerModels[modelID]
			if entry == nil {
				entry = &providerModel{}
				providerModels[modelID] = entry
			}
			entry.count++
			if entry.info == nil {
				// Prefer the client-specific info; fall back to the
				// registry-wide registration for the model.
				if clientInfos != nil {
					if info := clientInfos[modelID]; info != nil {
						entry.info = info
					}
				}
				if entry.info == nil {
					if reg, ok := r.models[modelID]; ok && reg != nil && reg.Info != nil {
						entry.info = reg.Info
					}
				}
			}
		}
	}
	if len(providerModels) == 0 {
		return nil
	}
	quotaExpiredDuration := 5 * time.Minute
	now := time.Now()
	result := make([]*ModelInfo, 0, len(providerModels))
	// Pass 2: decide, per model, whether it should still be listed.
	for modelID, entry := range providerModels {
		if entry == nil || entry.count <= 0 {
			continue
		}
		registration, ok := r.models[modelID]
		// NOTE: despite its name, expiredClients counts clients whose quota
		// mark is younger than quotaExpiredDuration, i.e. still in effect.
		expiredClients := 0
		cooldownSuspended := 0
		otherSuspended := 0
		if ok && registration != nil {
			if registration.QuotaExceededClients != nil {
				for clientID, quotaTime := range registration.QuotaExceededClients {
					if clientID == "" {
						continue
					}
					// Only clients belonging to the requested provider count.
					if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
						continue
					}
					if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
						expiredClients++
					}
				}
			}
			if registration.SuspendedClients != nil {
				for clientID, reason := range registration.SuspendedClients {
					if clientID == "" {
						continue
					}
					if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
						continue
					}
					// Suspensions with reason "quota" (case-insensitive) are
					// tracked separately from all other suspension reasons.
					if strings.EqualFold(reason, "quota") {
						cooldownSuspended++
						continue
					}
					otherSuspended++
				}
			}
		}
		availableClients := entry.count
		effectiveClients := availableClients - expiredClients - otherSuspended
		if effectiveClients < 0 {
			effectiveClients = 0
		}
		// List the model when at least one client remains usable, or when
		// the only thing holding clients back is quota (recent quota mark or
		// "quota" suspension) and nothing is suspended for another reason.
		if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
			if entry.info != nil {
				result = append(result, entry.info)
				continue
			}
			if ok && registration != nil && registration.Info != nil {
				result = append(result, registration.Info)
			}
		}
	}
	return result
}
// GetModelCount returns the number of available clients for a specific model
// Parameters:
// - modelID: The model ID to check

View File

@@ -0,0 +1,204 @@
package registry
import (
"context"
"sync"
"testing"
"time"
)
// newTestModelRegistry builds an isolated ModelRegistry with empty maps and
// its own mutex, so tests do not touch the global registry.
func newTestModelRegistry() *ModelRegistry {
	registry := &ModelRegistry{mutex: &sync.RWMutex{}}
	registry.models = make(map[string]*ModelRegistration)
	registry.clientModels = make(map[string][]string)
	registry.clientModelInfos = make(map[string]map[string]*ModelInfo)
	registry.clientProviders = make(map[string]string)
	return registry
}
// registeredCall records the arguments of one OnModelsRegistered invocation
// so a test can inspect them.
type registeredCall struct {
	provider string
	clientID string
	models   []*ModelInfo
}
// unregisteredCall records the arguments of one OnModelsUnregistered
// invocation so a test can inspect them.
type unregisteredCall struct {
	provider string
	clientID string
}
// capturingHook is a test hook that forwards every callback's arguments onto
// a channel so the test goroutine can assert on them.
type capturingHook struct {
	registeredCh   chan registeredCall
	unregisteredCh chan unregisteredCall
}

// OnModelsRegistered publishes the received arguments on registeredCh.
func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
	event := registeredCall{
		provider: provider,
		clientID: clientID,
		models:   models,
	}
	h.registeredCh <- event
}

// OnModelsUnregistered publishes the received arguments on unregisteredCh.
func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
	event := unregisteredCall{
		provider: provider,
		clientID: clientID,
	}
	h.unregisteredCh <- event
}
// TestModelRegistryHook_OnModelsRegisteredCalled verifies that RegisterClient
// invokes the hook with the lower-cased provider, the client ID, and the
// registered models in their original order.
func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) {
	registry := newTestModelRegistry()
	capture := &capturingHook{
		registeredCh:   make(chan registeredCall, 1),
		unregisteredCh: make(chan unregisteredCall, 1),
	}
	registry.SetHook(capture)

	registry.RegisterClient("client-1", "OpenAI", []*ModelInfo{
		{ID: "m1", DisplayName: "Model One"},
		{ID: "m2", DisplayName: "Model Two"},
	})

	// The hook runs asynchronously, so wait for it with a bounded timeout.
	var got registeredCall
	select {
	case got = <-capture.registeredCh:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for OnModelsRegistered hook call")
	}

	if got.provider != "openai" {
		t.Fatalf("provider mismatch: got %q, want %q", got.provider, "openai")
	}
	if got.clientID != "client-1" {
		t.Fatalf("clientID mismatch: got %q, want %q", got.clientID, "client-1")
	}
	if len(got.models) != 2 {
		t.Fatalf("models length mismatch: got %d, want %d", len(got.models), 2)
	}
	if first := got.models[0]; first == nil || first.ID != "m1" {
		t.Fatalf("models[0] mismatch: got %#v, want ID=%q", first, "m1")
	}
	if second := got.models[1]; second == nil || second.ID != "m2" {
		t.Fatalf("models[1] mismatch: got %#v, want ID=%q", second, "m2")
	}
}
// TestModelRegistryHook_OnModelsUnregisteredCalled verifies that
// UnregisterClient invokes the hook with the lower-cased provider and the
// client ID of the removed client.
func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) {
	registry := newTestModelRegistry()
	capture := &capturingHook{
		registeredCh:   make(chan registeredCall, 1),
		unregisteredCh: make(chan unregisteredCall, 1),
	}
	registry.SetHook(capture)

	registry.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
	// Drain the registration callback first so it cannot be confused with
	// the unregistration one.
	select {
	case <-capture.registeredCh:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for OnModelsRegistered hook call")
	}

	registry.UnregisterClient("client-1")

	var got unregisteredCall
	select {
	case got = <-capture.unregisteredCh:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for OnModelsUnregistered hook call")
	}
	if got.provider != "openai" {
		t.Fatalf("provider mismatch: got %q, want %q", got.provider, "openai")
	}
	if got.clientID != "client-1" {
		t.Fatalf("clientID mismatch: got %q, want %q", got.clientID, "client-1")
	}
}
// blockingHook is a test hook whose OnModelsRegistered does not return until
// the test releases it, used to check that hooks cannot stall the registry.
type blockingHook struct {
	// started is closed the first time OnModelsRegistered runs.
	started chan struct{}
	// unblock releases OnModelsRegistered once it is closed or receives a value.
	unblock chan struct{}
}

// OnModelsRegistered signals that it has started and then blocks on unblock.
func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
	// Close started only if it is still open; a receive on an already-closed
	// channel succeeds immediately and skips the second close.
	select {
	case <-h.started:
	default:
		close(h.started)
	}
	<-h.unblock
}

// OnModelsUnregistered is a no-op.
func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {}
// TestModelRegistryHook_DoesNotBlockRegisterClient verifies that
// RegisterClient returns, and the registration takes effect, while the hook
// callback is still blocked.
func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) {
	registry := newTestModelRegistry()
	stalled := &blockingHook{
		started: make(chan struct{}),
		unblock: make(chan struct{}),
	}
	registry.SetHook(stalled)
	// Release the blocked hook goroutine when the test finishes.
	defer close(stalled.unblock)

	finished := make(chan struct{})
	go func() {
		registry.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
		close(finished)
	}()

	// First make sure the hook is actually running (and therefore blocked).
	select {
	case <-stalled.started:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for hook to start")
	}

	// RegisterClient must complete even though the hook has not returned.
	select {
	case <-finished:
	case <-time.After(200 * time.Millisecond):
		t.Fatal("RegisterClient appears to be blocked by hook")
	}

	if supported := registry.ClientSupportsModel("client-1", "m1"); !supported {
		t.Fatal("model registration failed; expected client to support model")
	}
}
// panicHook is a test hook whose callbacks always panic, used to check that
// a misbehaving hook does not break the registry.
type panicHook struct {
	// registeredCalled receives a value each time OnModelsRegistered runs (when non-nil).
	registeredCalled chan struct{}
	// unregisteredCalled receives a value each time OnModelsUnregistered runs (when non-nil).
	unregisteredCalled chan struct{}
}

// OnModelsRegistered reports the call on registeredCalled and then panics.
func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
	if h.registeredCalled != nil {
		h.registeredCalled <- struct{}{}
	}
	panic("boom")
}

// OnModelsUnregistered reports the call on unregisteredCalled and then panics.
func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
	if h.unregisteredCalled != nil {
		h.unregisteredCalled <- struct{}{}
	}
	panic("boom")
}
// TestModelRegistryHook_PanicDoesNotAffectRegistry verifies that a hook which
// panics in both callbacks neither prevents the registration from taking
// effect nor stops the unregistration callback from being invoked.
func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) {
	registry := newTestModelRegistry()
	exploding := &panicHook{
		registeredCalled:   make(chan struct{}, 1),
		unregisteredCalled: make(chan struct{}, 1),
	}
	registry.SetHook(exploding)

	registry.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
	select {
	case <-exploding.registeredCalled:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for OnModelsRegistered hook call")
	}

	// The panic happens after the signal above; the registry state must be intact.
	if supported := registry.ClientSupportsModel("client-1", "m1"); !supported {
		t.Fatal("model registration failed; expected client to support model")
	}

	registry.UnregisterClient("client-1")
	select {
	case <-exploding.unregisteredCalled:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for OnModelsUnregistered hook call")
	}
}

View File

@@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
if err != nil {
return resp, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
@@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
if err != nil {
return nil, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
@@ -321,6 +323,11 @@ type translatedPayload struct {
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, stream)
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
@@ -329,7 +336,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiImageAspectRatio(req.Model, payload)
payload = applyPayloadConfig(e.cfg, req.Model, payload)
payload = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", payload, originalTranslated)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")

View File

@@ -10,6 +10,7 @@ import (
"crypto/sha256"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
@@ -45,6 +46,7 @@ const (
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
)
var (
@@ -76,7 +78,8 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
// Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if strings.Contains(req.Model, "claude") {
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if isClaude || strings.Contains(req.Model, "gemini-3-pro") {
return e.executeClaudeNonStream(ctx, auth, req, opts)
}
@@ -93,13 +96,18 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -118,6 +126,9 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
@@ -150,7 +161,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err
}
@@ -164,7 +181,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
switch {
case lastStatus != 0:
err = statusErr{code: lastStatus, msg: string(lastBody)}
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
@@ -188,13 +211,18 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, true)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -213,6 +241,9 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
@@ -231,6 +262,14 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return resp, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return resp, err
}
lastStatus = 0
lastBody = nil
lastErr = errRead
@@ -249,7 +288,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err
}
@@ -314,7 +359,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
switch {
case lastStatus != 0:
err = statusErr{code: lastStatus, msg: string(lastBody)}
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
@@ -520,15 +571,22 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -547,6 +605,9 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
@@ -565,6 +626,14 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return nil, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return nil, err
}
lastStatus = 0
lastBody = nil
lastErr = errRead
@@ -583,7 +652,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return nil, err
}
@@ -638,7 +713,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
switch {
case lastStatus != 0:
err = statusErr{code: lastStatus, msg: string(lastBody)}
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
@@ -676,6 +757,8 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("antigravity")
respCtx := context.WithValue(ctx, "alt", opts.Alt)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -692,9 +775,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
for idx, baseURL := range baseURLs {
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
payload = normalizeAntigravityThinking(req.Model, payload)
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, payload)
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
@@ -739,6 +822,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return cliproxyexecutor.Response{}, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
@@ -773,12 +859,24 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
return cliproxyexecutor.Response{}, sErr
}
switch {
case lastStatus != 0:
return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
return cliproxyexecutor.Response{}, sErr
case lastErr != nil:
return cliproxyexecutor.Response{}, lastErr
default:
@@ -815,6 +913,9 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
@@ -894,7 +995,13 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr
if accessToken != "" && expiry.After(time.Now().Add(refreshSkew)) {
return accessToken, nil, nil
}
updated, errRefresh := e.refreshToken(ctx, auth.Clone())
refreshCtx := context.Background()
if ctx != nil {
if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil {
refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", rt)
}
}
updated, errRefresh := e.refreshToken(refreshCtx, auth.Clone())
if errRefresh != nil {
return "", nil, errRefresh
}
@@ -941,7 +1048,13 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
return auth, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
return auth, sErr
}
var tokenResp struct {
@@ -1021,6 +1134,19 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
payload = []byte(strJSON)
}
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-preview") {
systemInstructionPartsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", systemInstruction)
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
for _, partResult := range systemInstructionPartsResult.Array() {
payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(partResult.Raw))
}
}
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
if errReq != nil {
return nil, errReq
@@ -1155,8 +1281,8 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
return []string{base}
}
return []string{
antigravityBaseURLDaily,
antigravitySandboxBaseURLDaily,
antigravityBaseURLDaily,
antigravityBaseURLProd,
}
}
@@ -1184,6 +1310,7 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
template, _ := sjson.Set(string(payload), "model", modelName)
template, _ = sjson.Set(template, "userAgent", "antigravity")
template, _ = sjson.Set(template, "requestType", "agent")
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" {
@@ -1308,7 +1435,7 @@ func alias2ModelName(modelName string) string {
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
// For Claude models, it additionally ensures thinking budget < max_tokens.
func normalizeAntigravityThinking(model string, payload []byte) []byte {
func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
payload = util.StripThinkingConfigIfUnsupported(model, payload)
if !util.ModelSupportsThinking(model) {
return payload
@@ -1320,7 +1447,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
raw := int(budget.Int())
normalized := util.NormalizeThinkingBudget(model, raw)
isClaude := strings.Contains(strings.ToLower(model), "claude")
if isClaude {
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
if effectiveMax > 0 && normalized >= effectiveMax {

View File

@@ -1,10 +1,68 @@
package executor
import "time"
import (
"sync"
"time"
)
type codexCache struct {
ID string
Expire time.Time
}
var codexCacheMap = map[string]codexCache{}
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
// Protected by codexCacheMu. Entries expire after 1 hour.
var (
codexCacheMap = make(map[string]codexCache)
codexCacheMu sync.RWMutex
)
// codexCacheCleanupInterval controls how often expired entries are purged.
const codexCacheCleanupInterval = 15 * time.Minute
// codexCacheCleanupOnce ensures the background cleanup goroutine starts only once.
var codexCacheCleanupOnce sync.Once
// startCodexCacheCleanup spawns a goroutine that calls purgeExpiredCodexCache
// once per codexCacheCleanupInterval tick. The goroutine has no stop signal and
// keeps running after this function returns.
func startCodexCacheCleanup() {
	sweep := func() {
		ticker := time.NewTicker(codexCacheCleanupInterval)
		defer ticker.Stop()
		for {
			// Block until the next tick, then sweep expired entries.
			<-ticker.C
			purgeExpiredCodexCache()
		}
	}
	go sweep()
}
// purgeExpiredCodexCache deletes every codexCacheMap entry whose Expire time is
// earlier than the instant the sweep began. The write lock on codexCacheMu is
// held for the whole pass.
func purgeExpiredCodexCache() {
	cutoff := time.Now()

	codexCacheMu.Lock()
	defer codexCacheMu.Unlock()

	for id, entry := range codexCacheMap {
		if !cutoff.After(entry.Expire) {
			// Still valid at sweep time; keep it.
			continue
		}
		delete(codexCacheMap, id)
	}
}
// getCodexCache looks up key in codexCacheMap under the read lock.
// It returns the stored entry and true only when the key exists and its Expire
// time has not passed; otherwise it returns the zero codexCache and false.
// The first call (to this function or setCodexCache) also starts the background
// cleanup goroutine.
func getCodexCache(key string) (codexCache, bool) {
	codexCacheCleanupOnce.Do(startCodexCacheCleanup)

	codexCacheMu.RLock()
	entry, found := codexCacheMap[key]
	codexCacheMu.RUnlock()

	if found && !entry.Expire.Before(time.Now()) {
		return entry, true
	}
	return codexCache{}, false
}
// setCodexCache writes cache into codexCacheMap under key while holding the
// write lock, replacing any previous entry for that key. The first call (to this
// function or getCodexCache) also starts the background cleanup goroutine.
func setCodexCache(key string, cache codexCache) {
	codexCacheCleanupOnce.Do(startCodexCacheCleanup)

	codexCacheMu.Lock()
	defer codexCacheMu.Unlock()
	codexCacheMap[key] = cache
}

View File

@@ -35,6 +35,8 @@ type ClaudeExecutor struct {
cfg *config.Config
}
const claudeToolPrefix = "proxy_"
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
func (e *ClaudeExecutor) Identifier() string { return "claude" }
@@ -49,40 +51,46 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, stream)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
body = e.injectThinkingConfig(model, req.Metadata, body)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body)
body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
if err != nil {
return resp, err
}
@@ -97,7 +105,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Body: bodyForUpstream,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
@@ -151,8 +159,20 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.publish(ctx, parseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
out := sdktranslator.TranslateNonStream(
ctx,
to,
from,
req.Model,
bytes.Clone(opts.OriginalRequest),
bodyForTranslation,
data,
&param,
)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
@@ -167,33 +187,39 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
body = e.injectThinkingConfig(model, req.Metadata, body)
body = checkSystemInstructions(body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body)
body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
if err != nil {
return nil, err
}
@@ -208,7 +234,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Body: bodyForUpstream,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
@@ -261,6 +287,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
// Forward the line as-is to preserve SSE format
cloned := make([]byte, len(line)+1)
copy(cloned, line)
@@ -285,7 +314,19 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
if isClaudeOAuthToken(apiKey) {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
chunks := sdktranslator.TranslateStream(
ctx,
to,
from,
req.Model,
bytes.Clone(opts.OriginalRequest),
bodyForTranslation,
bytes.Clone(line),
&param,
)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
@@ -310,27 +351,23 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", model)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
}
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -461,6 +498,19 @@ func (e *ClaudeExecutor) injectThinkingConfig(modelName string, metadata map[str
return util.ApplyClaudeThinkingConfig(body, budget)
}
// disableThinkingIfToolChoiceForced removes the top-level "thinking" object from
// body when tool_choice.type is "any" or "tool"; any other value (including
// "auto" or a missing tool_choice) leaves body untouched.
// The Anthropic API rejects thinking combined with a forced tool choice.
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
func disableThinkingIfToolChoiceForced(body []byte) []byte {
	switch gjson.GetBytes(body, "tool_choice.type").String() {
	case "any", "tool":
		// Forced tool use: drop the thinking configuration to avoid an API error.
		body, _ = sjson.DeleteBytes(body, "thinking")
	}
	return body
}
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
// Anthropic API requires this constraint; violating it returns a 400 error.
// This function should be called after all thinking configuration is finalized.
@@ -762,3 +812,107 @@ func checkSystemInstructions(payload []byte) []byte {
}
return payload
}
// isClaudeOAuthToken reports whether apiKey contains the substring "sk-ant-oat"
// anywhere in the string (not only as a prefix).
func isClaudeOAuthToken(apiKey string) bool {
	const oauthMarker = "sk-ant-oat"
	return strings.Contains(apiKey, oauthMarker)
}
func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if prefix == "" {
return body
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
name := tool.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("tools.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
}
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
name := gjson.GetBytes(body, "tool_choice.name").String()
if name != "" && !strings.HasPrefix(name, prefix) {
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
}
}
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
messages.ForEach(func(msgIndex, msg gjson.Result) bool {
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
return true
}
content.ForEach(func(contentIndex, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
}
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
return true
})
}
return body
}
// stripClaudeToolPrefixFromResponse returns body with prefix removed from the
// name of every top-level content[i] part whose type is "tool_use" and whose
// name starts with prefix. Other parts are untouched. An empty prefix, or a body
// without a "content" array, returns body unchanged. If an individual JSON write
// fails, that part is skipped and the last successfully written body is kept,
// rather than replacing body with the failed call's result.
func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
	if prefix == "" {
		return body
	}
	content := gjson.GetBytes(body, "content")
	if !content.Exists() || !content.IsArray() {
		return body
	}
	content.ForEach(func(index, part gjson.Result) bool {
		if part.Get("type").String() != "tool_use" {
			return true
		}
		name := part.Get("name").String()
		if !strings.HasPrefix(name, prefix) {
			return true
		}
		path := fmt.Sprintf("content.%d.name", index.Int())
		updated, err := sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
		if err != nil {
			// Leave this part as-is; do not overwrite body with a failed result.
			return true
		}
		body = updated
		return true
	})
	return body
}
// stripClaudeToolPrefixFromStreamLine removes prefix from content_block.name in
// a single stream line when that content block has type "tool_use" and its name
// starts with prefix.
//
// The original line is returned untouched when prefix is empty, when
// jsonPayload(line) yields nothing or invalid JSON, when there is no matching
// tool_use content block, or when rewriting the JSON fails. On success the
// result is "data: " followed by the rewritten JSON if the trimmed input line
// began with "data:", otherwise just the rewritten JSON.
func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
	if prefix == "" {
		return line
	}

	data := jsonPayload(line)
	if len(data) == 0 || !gjson.ValidBytes(data) {
		return line
	}

	block := gjson.GetBytes(data, "content_block")
	if !block.Exists() {
		return line
	}
	if block.Get("type").String() != "tool_use" {
		return line
	}

	toolName := block.Get("name").String()
	if !strings.HasPrefix(toolName, prefix) {
		return line
	}

	rewritten, err := sjson.SetBytes(data, "content_block.name", strings.TrimPrefix(toolName, prefix))
	if err != nil {
		return line
	}

	// Re-attach the SSE field name only if the input line carried one.
	if !bytes.HasPrefix(bytes.TrimSpace(line), []byte("data:")) {
		return rewritten
	}
	return append([]byte("data: "), rewritten...)
}

View File

@@ -0,0 +1,51 @@
package executor
import (
"bytes"
"testing"
"github.com/tidwall/gjson"
)
// TestApplyClaudeToolPrefix verifies that applyClaudeToolPrefix prefixes tool
// names in tools, tool_choice and assistant tool_use parts, and that a name
// which already carries the prefix is not prefixed a second time.
func TestApplyClaudeToolPrefix(t *testing.T) {
	input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
	out := applyClaudeToolPrefix(input, "proxy_")

	expectations := []struct {
		path string
		want string
	}{
		{"tools.0.name", "proxy_alpha"},
		{"tools.1.name", "proxy_bravo"},
		{"tool_choice.name", "proxy_charlie"},
		{"messages.0.content.0.name", "proxy_delta"},
	}
	for _, exp := range expectations {
		if got := gjson.GetBytes(out, exp.path).String(); got != exp.want {
			t.Fatalf("%s = %q, want %q", exp.path, got, exp.want)
		}
	}
}
// TestStripClaudeToolPrefixFromResponse verifies that a prefixed tool_use name
// in a response is stripped and that an unprefixed name is left unchanged.
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
	input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
	out := stripClaudeToolPrefixFromResponse(input, "proxy_")

	expectations := []struct {
		path string
		want string
	}{
		{"content.0.name", "alpha"},
		{"content.1.name", "bravo"},
	}
	for _, exp := range expectations {
		if got := gjson.GetBytes(out, exp.path).String(); got != exp.want {
			t.Fatalf("%s = %q, want %q", exp.path, got, exp.want)
		}
	}
}
// TestStripClaudeToolPrefixFromStreamLine verifies that the prefix is removed
// from content_block.name in a "data:" stream line carrying a tool_use block.
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
	line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
	out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")

	// Drop surrounding whitespace and an optional SSE "data:" field name so the
	// remainder can be inspected as JSON.
	payload := bytes.TrimSpace(out)
	payload = bytes.TrimSpace(bytes.TrimPrefix(payload, []byte("data:")))

	const want = "alpha"
	got := gjson.GetBytes(payload, "content_block.name").String()
	if got != want {
		t.Fatalf("content_block.name = %q, want %q", got, want)
	}
}

View File

@@ -49,18 +49,26 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return resp, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
@@ -146,20 +154,28 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return nil, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body, _ = sjson.SetBytes(body, "model", model)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -246,20 +262,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
modelForCounting := req.Model
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "stream", false)
enc, err := tokenizerForCodexModel(modelForCounting)
enc, err := tokenizerForCodexModel(model)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
}
@@ -440,14 +457,14 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() {
var hasKey bool
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) {
var ok bool
if cache, ok = getCodexCache(key); !ok {
cache = codexCache{
ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour),
}
codexCacheMap[key] = cache
setCodexCache(key, cache)
}
}
} else if from == "openai-response" {
@@ -520,3 +537,87 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
}
return
}
// resolveUpstreamModel maps a requested model name to the model name configured
// for the Codex config entry that matches auth.
//
// It builds a short list of candidate names: the name returned by
// util.NormalizeThinkingModel, the trimmed input if it differs, and the result
// of util.ResolveOriginalModel if non-empty and different. It then walks the
// entry's Models in order, comparing case-insensitively. A candidate matching a
// model's Alias yields that model's Name (or the candidate itself when Name is
// blank); a candidate matching a model's Name yields that Name. It returns ""
// when alias is blank, no config entry matches auth, or nothing matches.
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
	requested := strings.TrimSpace(alias)
	if requested == "" {
		return ""
	}

	cfgEntry := e.resolveCodexConfig(auth)
	if cfgEntry == nil {
		return ""
	}

	base, meta := util.NormalizeThinkingModel(requested)

	// Names to try against each configured alias/name, in priority order.
	lookups := []string{strings.TrimSpace(base)}
	if !strings.EqualFold(base, requested) {
		lookups = append(lookups, requested)
	}
	if resolved := util.ResolveOriginalModel(base, meta); resolved != "" && !strings.EqualFold(resolved, base) {
		lookups = append(lookups, resolved)
	}

	for i := range cfgEntry.Models {
		configuredName := strings.TrimSpace(cfgEntry.Models[i].Name)
		configuredAlias := strings.TrimSpace(cfgEntry.Models[i].Alias)
		for _, lookup := range lookups {
			switch {
			case lookup == "":
				continue
			case configuredAlias != "" && strings.EqualFold(configuredAlias, lookup):
				if configuredName != "" {
					return configuredName
				}
				return lookup
			case configuredName != "" && strings.EqualFold(configuredName, lookup):
				return configuredName
			}
		}
	}
	return ""
}
// resolveCodexConfig finds the entry in e.cfg.CodexKey that corresponds to auth,
// using the trimmed "api_key" and "base_url" values from auth.Attributes and
// case-insensitive comparison.
//
// First pass, in config order:
//   - both attributes set: the entry must match both key and base URL;
//   - only api_key set: the key must match and the entry's base URL must be blank;
//   - only base_url set: the entry's base URL must match.
//
// If nothing matched and api_key is set, a second pass returns the first entry
// whose key matches regardless of base URL. It returns nil when auth or e.cfg is
// nil or no entry matches. The result points into e.cfg.CodexKey.
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
	if auth == nil || e.cfg == nil {
		return nil
	}

	var attrKey, attrBase string
	if auth.Attributes != nil {
		attrKey = strings.TrimSpace(auth.Attributes["api_key"])
		attrBase = strings.TrimSpace(auth.Attributes["base_url"])
	}
	haveKey, haveBase := attrKey != "", attrBase != ""

	for i := range e.cfg.CodexKey {
		candidate := &e.cfg.CodexKey[i]
		cfgBase := strings.TrimSpace(candidate.BaseURL)
		keyMatches := strings.EqualFold(strings.TrimSpace(candidate.APIKey), attrKey)
		baseMatches := strings.EqualFold(cfgBase, attrBase)

		switch {
		case haveKey && haveBase:
			if keyMatches && baseMatches {
				return candidate
			}
		case haveKey:
			// attrBase is empty here, so baseMatches only holds for a blank cfgBase.
			if keyMatches && (cfgBase == "" || baseMatches) {
				return candidate
			}
		case haveBase:
			if baseMatches {
				return candidate
			}
		}
	}

	if !haveKey {
		return nil
	}
	// Fallback: accept a key-only match even when the base URL differs.
	for i := range e.cfg.CodexKey {
		candidate := &e.cfg.CodexKey[i]
		if strings.EqualFold(strings.TrimSpace(candidate.APIKey), attrKey) {
			return candidate
		}
	}
	return nil
}

View File

@@ -77,14 +77,19 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
action := "generateContent"
if req.Metadata != nil {
@@ -216,14 +221,19 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
projectID := resolveGeminiProjectID(auth)
@@ -318,7 +328,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response, reqBody []byte, attempt string) {
go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
@@ -336,14 +346,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reporter.publish(ctx, detail)
}
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
@@ -365,12 +375,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
@@ -417,15 +427,17 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
var lastStatus int
var lastBody []byte
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
// Gemini CLI endpoint when iterating fallback variants.
for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload)
payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
tok, errTok := tokenSource.Token()
if errTok != nil {

View File

@@ -77,19 +77,27 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
// Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent"
if req.Metadata != nil {
@@ -98,7 +106,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
}
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -173,21 +181,29 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -287,19 +303,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, bearer := geminiCreds(auth)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens")
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
requestBody := bytes.NewReader(translatedReq)
@@ -398,6 +420,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
return base
}
// resolveUpstreamModel maps a requested model identifier onto the model name
// configured for the Gemini key entry that matches auth.
//
// The identifier is trimmed and passed through util.NormalizeThinkingModel.
// Up to three candidates are then tried, in this order: the normalized name,
// the trimmed input (when it differs from the normalized name ignoring case),
// and the name returned by util.ResolveOriginalModel (when non-empty and
// different from the normalized name ignoring case). Configured models are
// scanned in configuration order and every comparison is case-insensitive.
//
// An alias hit returns the entry's name, or the matching candidate itself when
// the entry has no name; a name hit returns the entry's name. An empty string
// is returned when the alias is blank, no configuration entry matches auth, or
// no configured model matches any candidate.
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
	requested := strings.TrimSpace(alias)
	if requested == "" {
		return ""
	}

	cfgEntry := e.resolveGeminiConfig(auth)
	if cfgEntry == nil {
		return ""
	}

	base, thinkingMeta := util.NormalizeThinkingModel(requested)

	// Build the ordered list of names to look up.
	lookups := make([]string, 0, 3)
	lookups = append(lookups, strings.TrimSpace(base))
	if !strings.EqualFold(base, requested) {
		lookups = append(lookups, requested)
	}
	resolved := util.ResolveOriginalModel(base, thinkingMeta)
	if resolved != "" && !strings.EqualFold(resolved, base) {
		lookups = append(lookups, resolved)
	}

	for idx := range cfgEntry.Models {
		configuredName := strings.TrimSpace(cfgEntry.Models[idx].Name)
		configuredAlias := strings.TrimSpace(cfgEntry.Models[idx].Alias)
		for _, want := range lookups {
			switch {
			case want == "":
				continue
			case configuredAlias != "" && strings.EqualFold(configuredAlias, want):
				// Alias match: prefer the configured upstream name when present.
				if configuredName == "" {
					return want
				}
				return configuredName
			case configuredName != "" && strings.EqualFold(configuredName, want):
				return configuredName
			}
		}
	}
	return ""
}
// resolveGeminiConfig returns the Gemini key entry from the executor's
// configuration that corresponds to auth. It returns nil when auth or the
// configuration is nil, or when nothing matches.
//
// The "api_key" and "base_url" attributes of auth are trimmed and compared
// case-insensitively against each entry's trimmed APIKey and BaseURL:
//   - both attributes set: the entry must match both;
//   - only api_key set: the key must match and the entry's BaseURL must be empty;
//   - only base_url set: the entry's BaseURL must match.
//
// When no entry qualifies and an api_key is present, the first entry whose key
// matches is returned regardless of its base URL.
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
	if auth == nil || e.cfg == nil {
		return nil
	}

	wantKey, wantBase := "", ""
	if auth.Attributes != nil {
		wantKey = strings.TrimSpace(auth.Attributes["api_key"])
		wantBase = strings.TrimSpace(auth.Attributes["base_url"])
	}

	// First pass: strict matching according to which attributes are available.
	for idx := range e.cfg.GeminiKey {
		candidate := &e.cfg.GeminiKey[idx]
		entryBase := strings.TrimSpace(candidate.BaseURL)
		keyMatches := strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey)
		baseMatches := strings.EqualFold(entryBase, wantBase)

		switch {
		case wantKey != "" && wantBase != "":
			if keyMatches && baseMatches {
				return candidate
			}
		case wantKey != "":
			// wantBase is empty here, so baseMatches is true only for an empty entryBase.
			if keyMatches && (entryBase == "" || baseMatches) {
				return candidate
			}
		case wantBase != "":
			if baseMatches {
				return candidate
			}
		}
	}

	if wantKey == "" {
		return nil
	}

	// Second pass: fall back to the first entry with a matching key alone.
	for idx := range e.cfg.GeminiKey {
		candidate := &e.cfg.GeminiKey[idx]
		if strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey) {
			return candidate
		}
	}
	return nil
}
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
var attrs map[string]string
if auth != nil {

View File

@@ -120,10 +120,13 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
@@ -136,8 +139,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", req.Model)
action := "generateContent"
if req.Metadata != nil {
@@ -146,7 +149,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
}
}
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -220,24 +223,32 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent"
if req.Metadata != nil {
@@ -250,7 +261,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action)
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -321,10 +332,13 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
@@ -337,11 +351,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", req.Model)
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -438,30 +452,38 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -552,8 +574,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
// countTokensWithServiceAccount counts tokens using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
@@ -566,14 +586,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
@@ -641,21 +661,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
// countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -665,7 +688,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
@@ -808,3 +831,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
}
return tok.AccessToken, nil
}
// resolveUpstreamModel translates a requested model alias into the upstream
// model name declared on the Vertex-compatible API key entry matching auth.
//
// The alias is trimmed and normalized with util.NormalizeThinkingModel. The
// lookup candidates are, in order: the normalized name, the trimmed alias (when
// it differs from the normalized name ignoring case), and the result of
// util.ResolveOriginalModel (when non-empty and different from the normalized
// name ignoring case). Configured models are visited in configuration order and
// compared case-insensitively.
//
// A match on a model's alias yields its name, or the candidate when the name is
// blank; a match on a model's name yields that name. The empty string signals
// that no override applies (blank alias, no matching entry, or no match).
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
	requested := strings.TrimSpace(alias)
	if requested == "" {
		return ""
	}

	vertexEntry := e.resolveVertexConfig(auth)
	if vertexEntry == nil {
		return ""
	}

	base, thinkingMeta := util.NormalizeThinkingModel(requested)

	// Assemble the ordered candidate names.
	names := make([]string, 0, 3)
	names = append(names, strings.TrimSpace(base))
	if !strings.EqualFold(base, requested) {
		names = append(names, requested)
	}
	resolved := util.ResolveOriginalModel(base, thinkingMeta)
	if resolved != "" && !strings.EqualFold(resolved, base) {
		names = append(names, resolved)
	}

	for idx := range vertexEntry.Models {
		upstream := strings.TrimSpace(vertexEntry.Models[idx].Name)
		short := strings.TrimSpace(vertexEntry.Models[idx].Alias)
		for _, want := range names {
			switch {
			case want == "":
				continue
			case short != "" && strings.EqualFold(short, want):
				// Alias match: fall back to the candidate when no name is configured.
				if upstream == "" {
					return want
				}
				return upstream
			case upstream != "" && strings.EqualFold(upstream, want):
				return upstream
			}
		}
	}
	return ""
}
// resolveVertexConfig locates the Vertex-compatible API key entry in the
// executor's configuration that belongs to auth. It returns nil when auth or
// the configuration is nil, or when no entry matches.
//
// The trimmed "api_key" and "base_url" attributes of auth are compared
// case-insensitively with each entry's trimmed APIKey and BaseURL:
//   - both attributes set: key and base URL must both match;
//   - only api_key set: the key must match and the entry's BaseURL must be empty;
//   - only base_url set: the entry's BaseURL must match.
//
// If that pass finds nothing and an api_key is present, the first entry with a
// matching key is returned irrespective of its base URL.
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
	if auth == nil || e.cfg == nil {
		return nil
	}

	wantKey, wantBase := "", ""
	if auth.Attributes != nil {
		wantKey = strings.TrimSpace(auth.Attributes["api_key"])
		wantBase = strings.TrimSpace(auth.Attributes["base_url"])
	}

	// Strict pass, driven by which attributes auth provides.
	for idx := range e.cfg.VertexCompatAPIKey {
		candidate := &e.cfg.VertexCompatAPIKey[idx]
		entryBase := strings.TrimSpace(candidate.BaseURL)
		keyMatches := strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey)
		baseMatches := strings.EqualFold(entryBase, wantBase)

		switch {
		case wantKey != "" && wantBase != "":
			if keyMatches && baseMatches {
				return candidate
			}
		case wantKey != "":
			// wantBase is empty here, so baseMatches holds only for an empty entryBase.
			if keyMatches && (entryBase == "" || baseMatches) {
				return candidate
			}
		case wantBase != "":
			if baseMatches {
				return candidate
			}
		}
	}

	if wantKey == "" {
		return nil
	}

	// Relaxed pass: accept the first entry whose key matches.
	for idx := range e.cfg.VertexCompatAPIKey {
		candidate := &e.cfg.VertexCompatAPIKey[idx]
		if strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey) {
			return candidate
		}
	}
	return nil
}

View File

@@ -56,18 +56,21 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return resp, errValidate
}
body = applyIFlowThinkingConfig(body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = preserveReasoningContentInMessages(body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -147,24 +150,27 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return nil, errValidate
}
body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
toolsResult := gjson.GetBytes(body, "tools")
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body)
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -445,20 +451,85 @@ func ensureToolsArray(body []byte) []byte {
return updated
}
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
// preserveReasoningContentInMessages inspects the message history of an iFlow
// request for assistant turns that carry reasoning_content.
//
// The payload is never modified. The function only emits a debug log when the
// lower-cased "model" field starts with "glm-4" or "minimax-m2" and at least one
// assistant message has a non-empty reasoning_content, i.e. the client already
// kept prior reasoning in the history. body is returned as received in every
// case.
func preserveReasoningContentInMessages(body []byte) []byte {
	modelName := strings.ToLower(gjson.GetBytes(body, "model").String())

	// Only these model families are inspected.
	if !strings.HasPrefix(modelName, "glm-4") && !strings.HasPrefix(modelName, "minimax-m2") {
		return body
	}

	history := gjson.GetBytes(body, "messages")
	if !history.Exists() || !history.IsArray() {
		return body
	}

	// Stop at the first assistant message that carries reasoning text.
	for _, entry := range history.Array() {
		if entry.Get("role").String() != "assistant" {
			continue
		}
		reasoning := entry.Get("reasoning_content")
		if !reasoning.Exists() || reasoning.String() == "" {
			continue
		}
		log.Debugf("iflow executor: reasoning_content found in message history for %s", modelName)
		break
	}
	return body
}
// applyIFlowThinkingConfig converts a normalized reasoning_effort field into the
// thinking-related fields written below, then removes reasoning_effort.
// This should be called after NormalizeThinkingConfig has processed the payload.
//
// If the payload has no reasoning_effort field it is returned unchanged.
// Otherwise thinking counts as enabled when the trimmed, lower-cased effort is
// neither "none" nor empty. The "reasoning_effort" and "thinking" fields are
// deleted and chat_template_kwargs.enable_thinking is set to that boolean.
//
// Model-specific handling (prefix match on the lower-cased "model" field):
// - "glm-4": sets chat_template_kwargs.enable_thinking, and additionally
//   chat_template_kwargs.clear_thinking=false when thinking is enabled.
// - "minimax-m2": sets reasoning_split to the enable-thinking boolean.
//
// Errors returned by sjson are discarded; each edit is best effort.
func applyIFlowThinkingConfig(body []byte) []byte {
effort := gjson.GetBytes(body, "reasoning_effort")
if !effort.Exists() {
return body
}
// The model name is read from the payload itself, not from the request.
model := strings.ToLower(gjson.GetBytes(body, "model").String())
val := strings.ToLower(strings.TrimSpace(effort.String()))
enableThinking := val != "none" && val != ""
// Remove reasoning_effort as we'll convert to model-specific format
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
// NOTE(review): this write applies to every model, and the glm-4 branch below
// sets the same key again — confirm the unconditional write is intended.
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
body, _ = sjson.DeleteBytes(body, "thinking")
// GLM-4.6/4.7: Use chat_template_kwargs
if strings.HasPrefix(model, "glm-4") {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
if enableThinking {
// clear_thinking is only written when thinking is enabled.
body, _ = sjson.SetBytes(body, "chat_template_kwargs.clear_thinking", false)
}
return body
}
// MiniMax M2/M2.1: Use reasoning_split
if strings.HasPrefix(model, "minimax-m2") {
body, _ = sjson.SetBytes(body, "reasoning_split", enableThinking)
return body
}
return body
}

View File

@@ -304,11 +304,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
parts = append(parts, "type=api_key")
}
case "oauth":
if authValue != "" {
parts = append(parts, fmt.Sprintf("type=oauth account=%s", authValue))
} else {
parts = append(parts, "type=oauth")
}
parts = append(parts, "type=oauth")
default:
if authType != "" {
if authValue != "" {

View File

@@ -53,20 +53,21 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate inbound request to OpenAI format
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
modelOverride := e.resolveUpstreamModel(req.Model, auth)
if modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" && modelOverride == "" {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
return resp, errValidate
}
@@ -149,20 +150,21 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
modelOverride := e.resolveUpstreamModel(req.Model, auth)
if modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" && modelOverride == "" {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
return nil, errValidate
}
@@ -239,6 +241,11 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
// Pass through translator; it yields one or more chunks for the target schema.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), &param)

View File

@@ -14,32 +14,54 @@ import (
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
// Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata)
// Determine which model to use for thinking support check.
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload
}
if !util.ModelSupportsThinking(model) {
if !util.ModelSupportsThinking(thinkingModel) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
}
// applyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
// ApplyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
func ApplyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
// Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata)
// Determine which model to use for thinking support check.
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload
}
if !util.ModelSupportsThinking(model) {
if !util.ModelSupportsThinking(thinkingModel) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
@@ -82,17 +104,11 @@ func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model
return payload
}
// applyPayloadConfig applies payload default and override rules from configuration
// to the given JSON payload for the specified model.
// Defaults only fill missing fields, while overrides always overwrite existing values.
//
// It delegates to applyPayloadConfigWithRoot with no protocol filter, no root path,
// and no separate original payload, so defaults are checked against payload itself.
func applyPayloadConfig(cfg *config.Config, model string, payload []byte) []byte {
	return applyPayloadConfigWithRoot(cfg, model, "", "", payload, nil)
}
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
// and restricts matches to the given protocol when supplied.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload []byte) []byte {
// and restricts matches to the given protocol when supplied. Defaults are checked
// against the original payload when provided.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte) []byte {
if cfg == nil || len(payload) == 0 {
return payload
}
@@ -105,6 +121,11 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
return payload
}
out := payload
source := original
if len(source) == 0 {
source = payload
}
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
@@ -116,7 +137,10 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
if fullPath == "" {
continue
}
if gjson.GetBytes(out, fullPath).Exists() {
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
@@ -124,6 +148,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply override rules: last write wins per field across all matching rules.

View File

@@ -12,7 +12,6 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -50,17 +49,19 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return resp, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -129,15 +130,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return nil, errValidate
}
toolsResult := gjson.GetBytes(body, "tools")
@@ -147,7 +150,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
}
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))

View File

@@ -14,7 +14,6 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -136,14 +135,14 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if sessionID != "" && thinkingText != "" {
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
signature = cachedSig
log.Debugf("Using cached signature for thinking block")
// log.Debugf("Using cached signature for thinking block")
}
}
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" && cache.HasValidSignature(clientSignature) {
signature = clientSignature
log.Debugf("Using client-provided signature for thinking block")
// log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
@@ -158,8 +157,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
// Converting to text would break this requirement
if isUnsigned {
// TypeScript plugin approach: drop unsigned thinking blocks entirely
log.Debugf("Dropping unsigned thinking block (no valid signature)")
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
continue
}
@@ -183,7 +181,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
// NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected.
// The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies.
functionName := contentResult.Get("name").String()
argsResult := contentResult.Get("input")

View File

@@ -136,11 +136,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
log.Debug("Branch: signature_delta")
// log.Debug("Branch: signature_delta")
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset()
}

View File

@@ -184,7 +184,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> request.systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -223,6 +223,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}
@@ -266,6 +267,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}

View File

@@ -40,6 +40,16 @@ type claudeToResponsesState struct {
var dataTag = []byte("data:")
// pickRequestJSON selects the request body to read fields from.
// Candidates are tried in order, originalRequestRawJSON first and requestRawJSON
// second; the first one that is non-empty and valid JSON is returned.
// It returns nil when neither candidate qualifies.
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
	for _, candidate := range [][]byte{originalRequestRawJSON, requestRawJSON} {
		if len(candidate) == 0 || !gjson.ValidBytes(candidate) {
			continue
		}
		return candidate
	}
	return nil
}
// emitEvent formats one event frame: an "event: " line carrying the event name,
// followed by a "data: " line carrying the payload. No trailing newline is appended.
func emitEvent(event string, payload string) string {
	// Every operand is a string, so fmt.Sprint joins them without inserting spaces.
	return fmt.Sprint("event: ", event, "\ndata: ", payload)
}
@@ -279,8 +289,9 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
// Inject original request fields into response as per docs/response.completed.json
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.Set(completed, "response.instructions", v.String())
}
@@ -549,8 +560,9 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
out, _ = sjson.Set(out, "created_at", createdAt)
// Inject request echo fields as top-level (similar to streaming variant)
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() {
out, _ = sjson.Set(out, "instructions", v.String())
}

View File

@@ -20,6 +20,12 @@ var (
dataTag = []byte("data:")
)
// ConvertCodexResponseToClaudeParams holds the state that ConvertCodexResponseToClaude
// stores behind its *param argument and reuses across successive calls.
type ConvertCodexResponseToClaudeParams struct {
// HasToolCall is set to true once a function_call item has been seen. When the
// response.completed event arrives it selects stop_reason "tool_use" instead of
// "end_turn".
HasToolCall bool
// BlockIndex is the value written into the "index" field of emitted
// content_block_start/delta/stop events. It starts at 0 and is incremented
// after each content_block_stop is produced.
BlockIndex int
}
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates Codex API responses
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
@@ -38,8 +44,10 @@ var (
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
hasToolCall := false
*param = &hasToolCall
*param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false,
BlockIndex: 0,
}
}
// log.Debugf("rawJSON: %s", string(rawJSON))
@@ -62,46 +70,49 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.content_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.output_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.content_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.completed" {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
p := (*param).(*bool)
if *p {
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
if p {
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
} else {
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
@@ -118,10 +129,9 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
p := true
*param = &p
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
{
// Restore original tool name if shortened
@@ -137,7 +147,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output += fmt.Sprintf("data: %s\n\n", template)
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
@@ -147,14 +157,15 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
}
} else if typeStr == "response.function_call_arguments.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n"

View File

@@ -275,7 +275,15 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
arr := tools.Array()
for i := 0; i < len(arr); i++ {
t := arr[i]
if t.Get("type").String() == "function" {
toolType := t.Get("type").String()
// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
// Only "function" needs structural conversion because Chat Completions nests details under "function".
if toolType != "" && toolType != "function" && t.IsObject() {
out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
continue
}
if toolType == "function" {
item := `{}`
item, _ = sjson.Set(item, "type", "function")
fn := t.Get("function")
@@ -304,6 +312,37 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
}
}
// Map tool_choice when present.
// Chat Completions: "tool_choice" can be a string ("auto"/"none") or an object (e.g. {"type":"function","function":{"name":"..."}}).
// Responses API: keep built-in tool choices as-is; flatten function choice to {"type":"function","name":"..."}.
if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
switch {
case tc.Type == gjson.String:
out, _ = sjson.Set(out, "tool_choice", tc.String())
case tc.IsObject():
tcType := tc.Get("type").String()
if tcType == "function" {
name := tc.Get("function.name").String()
if name != "" {
if short, ok := originalToolNameMap[name]; ok {
name = short
} else {
name = shortenNameIfNeeded(name)
}
}
choice := `{}`
choice, _ = sjson.Set(choice, "type", "function")
if name != "" {
choice, _ = sjson.Set(choice, "name", name)
}
out, _ = sjson.SetRaw(out, "tool_choice", choice)
} else if tcType != "" {
// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
}
}
}
out, _ = sjson.Set(out, "store", false)
return []byte(out)
}

View File

@@ -152,7 +152,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> request.systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
@@ -169,7 +169,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -191,6 +191,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}
@@ -236,6 +237,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}

View File

@@ -56,7 +56,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
}
} else if systemResult.Type == gjson.String {
out, _ = sjson.Set(out, "request.system_instruction.parts.-1.text", systemResult.String())
out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
}
// contents

View File

@@ -170,7 +170,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> system_instruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -209,6 +209,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
}
}
@@ -253,6 +254,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
}
}

View File

@@ -23,6 +23,7 @@ type geminiToResponsesState struct {
MsgIndex int
CurrentMsgID string
TextBuf strings.Builder
ItemTextBuf strings.Builder
// reasoning aggregation
ReasoningOpened bool
@@ -189,6 +190,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID)
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
out = append(out, emitEvent("response.content_part.added", partAdded))
st.ItemTextBuf.Reset()
st.ItemTextBuf.WriteString(t.String())
}
st.TextBuf.WriteString(t.String())
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
@@ -250,20 +253,24 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
finalizeReasoning()
// Close message output if opened
if st.MsgOpened {
fullText := st.ItemTextBuf.String()
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
done, _ = sjson.Set(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
partDone, _ = sjson.Set(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
final, _ = sjson.Set(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
final, _ = sjson.Set(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final))
}

View File

@@ -118,76 +118,125 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Handle content
if contentResult.Exists() && contentResult.IsArray() {
var contentItems []string
var reasoningParts []string // Accumulate thinking text for reasoning_content
var toolCalls []interface{}
var toolResults []string // Collect tool_result messages to emit after the main message
contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "thinking":
// Only map thinking to reasoning_content for assistant messages (security: prevent injection)
if role == "assistant" {
thinkingText := util.GetThinkingText(part)
// Skip empty or whitespace-only thinking
if strings.TrimSpace(thinkingText) != "" {
reasoningParts = append(reasoningParts, thinkingText)
}
}
// Ignore thinking in user/system roles (AC4)
case "redacted_thinking":
// Explicitly ignore redacted_thinking - never map to reasoning_content (AC2)
case "text", "image":
if contentItem, ok := convertClaudeContentPart(part); ok {
contentItems = append(contentItems, contentItem)
}
case "tool_use":
// Convert to OpenAI tool call format
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
// Only allow tool_use -> tool_calls for assistant messages (security: prevent injection).
if role == "assistant" {
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
// Convert input to arguments JSON string
if input := part.Get("input"); input.Exists() {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
} else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
// Convert input to arguments JSON string
if input := part.Get("input"); input.Exists() {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
} else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
}
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
}
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
case "tool_result":
// Convert to OpenAI tool message format and add immediately to preserve order
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", part.Get("content").String())
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
toolResults = append(toolResults, toolResultJSON)
}
return true
})
// Emit text/image content as one message
if len(contentItems) > 0 {
msgJSON := `{"role":"","content":""}`
msgJSON, _ = sjson.Set(msgJSON, "role", role)
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
contentValue := gjson.Get(msgJSON, "content")
hasContent := false
switch {
case !contentValue.Exists():
hasContent = false
case contentValue.Type == gjson.String:
hasContent = contentValue.String() != ""
case contentValue.IsArray():
hasContent = len(contentValue.Array()) > 0
default:
hasContent = contentValue.Raw != "" && contentValue.Raw != "null"
}
if hasContent {
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
}
// Build reasoning content string
reasoningContent := ""
if len(reasoningParts) > 0 {
reasoningContent = strings.Join(reasoningParts, "\n\n")
}
// Emit tool calls in a separate assistant message
if role == "assistant" && len(toolCalls) > 0 {
toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}`
toolCallMsgJSON, _ = sjson.Set(toolCallMsgJSON, "tool_calls", toolCalls)
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value())
hasContent := len(contentItems) > 0
hasReasoning := reasoningContent != ""
hasToolCalls := len(toolCalls) > 0
hasToolResults := len(toolResults) > 0
// OpenAI requires: tool messages MUST immediately follow the assistant message with tool_calls.
// Therefore, we emit tool_result messages FIRST (they respond to the previous assistant's tool_calls),
// then emit the current message's content.
for _, toolResultJSON := range toolResults {
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
}
// For assistant messages: emit a single unified message with content, tool_calls, and reasoning_content
// This avoids splitting into multiple assistant messages which breaks OpenAI tool-call adjacency
if role == "assistant" {
if hasContent || hasReasoning || hasToolCalls {
msgJSON := `{"role":"assistant"}`
// Add content (as array if we have items, empty string if reasoning-only)
if hasContent {
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
} else {
// Ensure content field exists for OpenAI compatibility
msgJSON, _ = sjson.Set(msgJSON, "content", "")
}
// Add reasoning_content if present
if hasReasoning {
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
}
// Add tool_calls if present (in same message as content)
if hasToolCalls {
msgJSON, _ = sjson.Set(msgJSON, "tool_calls", toolCalls)
}
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
}
} else {
// For non-assistant roles: emit content message if we have content
// If the message only contains tool_results (no text/image), we still processed them above
if hasContent {
msgJSON := `{"role":""}`
msgJSON, _ = sjson.Set(msgJSON, "role", role)
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
} else if hasToolResults && !hasContent {
// tool_results already emitted above, no additional user message needed
}
}
} else if contentResult.Exists() && contentResult.Type == gjson.String {
@@ -307,3 +356,43 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
return "", false
}
}
// convertClaudeToolResultContentToString flattens the content value of a Claude
// tool_result part into a single string.
//
// Rules, in order:
//   - missing content yields "".
//   - a JSON string is returned unchanged.
//   - an array is reduced element by element: string elements and the string
//     "text" field of object elements are taken verbatim, every other element
//     contributes its raw JSON. The pieces are joined with a blank line; if the
//     joined text is empty or whitespace-only, the array's raw JSON is returned.
//   - an object yields its string "text" field when it has one, otherwise its
//     raw JSON.
//   - any other value (number, bool, null literal) yields its raw JSON.
func convertClaudeToolResultContentToString(content gjson.Result) string {
	switch {
	case !content.Exists():
		return ""
	case content.Type == gjson.String:
		return content.String()
	case content.IsArray():
		var segments []string
		content.ForEach(func(_, elem gjson.Result) bool {
			// Default to the raw JSON; override for plain strings and text-bearing objects.
			segment := elem.Raw
			if elem.Type == gjson.String {
				segment = elem.String()
			} else if elem.IsObject() {
				if txt := elem.Get("text"); txt.Type == gjson.String {
					segment = txt.String()
				}
			}
			segments = append(segments, segment)
			return true
		})
		if text := strings.Join(segments, "\n\n"); strings.TrimSpace(text) != "" {
			return text
		}
		// Nothing readable was extracted; fall back to the array as JSON.
		return content.Raw
	case content.IsObject():
		if txt := content.Get("text"); txt.Type == gjson.String {
			return txt.String()
		}
	}
	return content.Raw
}

View File

@@ -0,0 +1,500 @@
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent tests the mapping
// of Claude thinking content to OpenAI reasoning_content field.
//
// Each table entry is a Claude-format request that is passed through
// ConvertClaudeRequestToOpenAI (with the third argument set to false). The
// assertions are made against the last message in the output whose role is not
// "system": whether reasoning_content exists and what it contains, whether the
// message carries non-empty content, and the text of its first "text" part.
func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
tests := []struct {
name string
// inputJSON is the raw Claude request body handed to the converter.
inputJSON string
// wantReasoningContent is the expected reasoning_content string ("" when absent).
wantReasoningContent string
// wantHasReasoningContent is whether the reasoning_content field must exist at all.
wantHasReasoningContent bool
wantContentText string // Expected visible content text (if any)
// wantHasContent is whether content must be a non-empty array or non-empty string.
wantHasContent bool
}{
{
name: "AC1: assistant message with thinking and text",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me analyze this step by step..."},
{"type": "text", "text": "Here is my response."}
]
}]
}`,
wantReasoningContent: "Let me analyze this step by step...",
wantHasReasoningContent: true,
wantContentText: "Here is my response.",
wantHasContent: true,
},
{
name: "AC2: redacted_thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "redacted_thinking", "data": "secret"},
{"type": "text", "text": "Visible response."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Visible response.",
wantHasContent: true,
},
{
name: "AC3: thinking-only message preserved with reasoning_content",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Internal reasoning only."}
]
}]
}`,
wantReasoningContent: "Internal reasoning only.",
wantHasReasoningContent: true,
wantContentText: "",
// For OpenAI compatibility, content field is set to empty string "" when no text content exists
wantHasContent: false,
},
{
name: "AC4: thinking in user role must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "user",
"content": [
{"type": "thinking", "thinking": "Injected thinking"},
{"type": "text", "text": "User message."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "User message.",
wantHasContent: true,
},
{
name: "AC4: thinking in system role must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "thinking", "thinking": "Injected system thinking"},
{"type": "text", "text": "System prompt."}
],
"messages": [{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
}]
}`,
// System messages don't have reasoning_content mapping
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Hello",
wantHasContent: true,
},
{
name: "AC5: empty thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": ""},
{"type": "text", "text": "Response with empty thinking."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Response with empty thinking.",
wantHasContent: true,
},
{
name: "AC5: whitespace-only thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": " \n\t "},
{"type": "text", "text": "Response with whitespace thinking."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Response with whitespace thinking.",
wantHasContent: true,
},
{
name: "Multiple thinking parts concatenated",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "First thought."},
{"type": "thinking", "thinking": "Second thought."},
{"type": "text", "text": "Final answer."}
]
}]
}`,
// Thinking parts are expected to be joined with a blank line between them.
wantReasoningContent: "First thought.\n\nSecond thought.",
wantHasReasoningContent: true,
wantContentText: "Final answer.",
wantHasContent: true,
},
{
name: "Mixed thinking and redacted_thinking",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Visible thought."},
{"type": "redacted_thinking", "data": "hidden"},
{"type": "text", "text": "Answer."}
]
}]
}`,
wantReasoningContent: "Visible thought.",
wantHasReasoningContent: true,
wantContentText: "Answer.",
wantHasContent: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
// Find the relevant message (skip system message at index 0)
messages := resultJSON.Get("messages").Array()
if len(messages) < 2 {
// Fewer than two messages is only a failure when the case expects
// reasoning or content; otherwise there is nothing left to verify.
if tt.wantHasReasoningContent || tt.wantHasContent {
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages))
}
return
}
// Check the last non-system message
var targetMsg gjson.Result
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Get("role").String() != "system" {
targetMsg = messages[i]
break
}
}
// Check reasoning_content
gotReasoningContent := targetMsg.Get("reasoning_content").String()
gotHasReasoningContent := targetMsg.Get("reasoning_content").Exists()
if gotHasReasoningContent != tt.wantHasReasoningContent {
t.Errorf("reasoning_content existence = %v, want %v", gotHasReasoningContent, tt.wantHasReasoningContent)
}
if gotReasoningContent != tt.wantReasoningContent {
t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
}
// Check content
content := targetMsg.Get("content")
// content has meaningful content if it's a non-empty array, or a non-empty string
var gotHasContent bool
switch {
case content.IsArray():
gotHasContent = len(content.Array()) > 0
case content.Type == gjson.String:
gotHasContent = content.String() != ""
default:
gotHasContent = false
}
if gotHasContent != tt.wantHasContent {
t.Errorf("content existence = %v, want %v", gotHasContent, tt.wantHasContent)
}
if tt.wantHasContent && tt.wantContentText != "" {
// Find text content
// Only the first part of type "text" is compared; iteration stops there.
var foundText string
content.ForEach(func(_, v gjson.Result) bool {
if v.Get("type").String() == "text" {
foundText = v.Get("text").String()
return false
}
return true
})
if foundText != tt.wantContentText {
t.Errorf("content text = %q, want %q", foundText, tt.wantContentText)
}
}
})
}
}
// TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved tests AC3:
// that a message with only thinking content is preserved (not dropped).
func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}]
},
{
"role": "assistant",
"content": [{"type": "thinking", "thinking": "Let me calculate: 2+2=4"}]
},
{
"role": "user",
"content": [{"type": "text", "text": "Thanks"}]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages
if len(messages) != 4 {
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
}
// Check the assistant message (index 2) has reasoning_content
assistantMsg := messages[2]
if assistantMsg.Get("role").String() != "assistant" {
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String())
}
if !assistantMsg.Get("reasoning_content").Exists() {
t.Error("Expected assistant message to have reasoning_content")
}
if assistantMsg.Get("reasoning_content").String() != "Let me calculate: 2+2=4" {
t.Errorf("Unexpected reasoning_content: %s", assistantMsg.Get("reasoning_content").String())
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "before"},
{"type": "tool_result", "tool_use_id": "call_1", "content": [{"type":"text","text":"tool ok"}]},
{"type": "text", "text": "after"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after)
if len(messages) != 4 {
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[0].Get("role").String() != "system" {
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
}
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
}
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
if messages[2].Get("role").String() != "tool" {
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String())
}
if got := messages[2].Get("tool_call_id").String(); got != "call_1" {
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
}
if got := messages[2].Get("content").String(); got != "tool ok" {
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
}
// User message comes after tool message
if messages[3].Get("role").String() != "user" {
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String())
}
// User message should contain both "before" and "after" text
if got := messages[3].Get("content.0.text").String(); got != "before" {
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
}
if got := messages[3].Get("content.1.text").String(); got != "after" {
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "call_1", "content": {"foo": "bar"}}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// system + assistant(tool_calls) + tool(result)
if len(messages) != 3 {
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[2].Get("role").String() != "tool" {
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String())
}
toolContent := messages[2].Get("content").String()
parsed := gjson.Parse(toolContent)
if parsed.Get("foo").String() != "bar" {
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "pre"},
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
{"type": "text", "text": "post"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// New behavior: content + tool_calls unified in single assistant message
// Expect: system + assistant(content[pre,post] + tool_calls)
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[0].Get("role").String() != "system" {
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
}
assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
}
// Should have both content and tool_calls in same message
if !assistantMsg.Get("tool_calls").Exists() {
t.Fatalf("Expected assistant message to have tool_calls")
}
if got := assistantMsg.Get("tool_calls.0.id").String(); got != "call_1" {
t.Fatalf("Expected tool_call id %q, got %q", "call_1", got)
}
if got := assistantMsg.Get("tool_calls.0.function.name").String(); got != "do_work" {
t.Fatalf("Expected tool_call name %q, got %q", "do_work", got)
}
// Content should have both pre and post text
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
}
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "t1"},
{"type": "text", "text": "pre"},
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
{"type": "thinking", "thinking": "t2"},
{"type": "text", "text": "post"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// New behavior: all content, thinking, and tool_calls unified in single assistant message
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
}
// Should have content with both pre and post
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
}
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
}
// Should have tool_calls
if !assistantMsg.Get("tool_calls").Exists() {
t.Fatalf("Expected assistant message to have tool_calls")
}
// Should have combined reasoning_content from both thinking blocks
if got := assistantMsg.Get("reasoning_content").String(); got != "t1\n\nt2" {
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
}
}

View File

@@ -299,17 +299,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
inputTokens = promptTokens.Int()
outputTokens = completionTokens.Int()
}
// Send message_delta with usage
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
param.MessageDeltaSent = true
emitMessageStopIfNeeded(param, &results)
}
// Send message_delta with usage
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
param.MessageDeltaSent = true
emitMessageStopIfNeeded(param, &results)
}
return results
@@ -480,15 +479,15 @@ func collectOpenAIReasoningTexts(node gjson.Result) []string {
switch node.Type {
case gjson.String:
if text := strings.TrimSpace(node.String()); text != "" {
if text := node.String(); text != "" {
texts = append(texts, text)
}
case gjson.JSON:
if text := node.Get("text"); text.Exists() {
if trimmed := strings.TrimSpace(text.String()); trimmed != "" {
texts = append(texts, trimmed)
if textStr := text.String(); textStr != "" {
texts = append(texts, textStr)
}
} else if raw := strings.TrimSpace(node.Raw); raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
} else if raw := node.Raw; raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
texts = append(texts, raw)
}
}

View File

@@ -163,6 +163,14 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
var chatCompletionsTools []interface{}
tools.ForEach(func(_, tool gjson.Result) bool {
// Built-in tools (e.g. {"type":"web_search"}) are already compatible with the Chat Completions schema.
// Only function tools need structural conversion because Chat Completions nests details under "function".
toolType := tool.Get("type").String()
if toolType != "" && toolType != "function" && tool.IsObject() {
chatCompletionsTools = append(chatCompletionsTools, tool.Value())
return true
}
chatTool := `{"type":"function","function":{}}`
// Convert tool structure from responses format to chat completions format

View File

@@ -344,7 +344,7 @@ func cleanupRequiredFields(jsonStr string) string {
}
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
// Claude VALIDATED mode requires at least one property in tool schemas.
// Claude VALIDATED mode requires at least one required property in tool schemas.
func addEmptySchemaPlaceholder(jsonStr string) string {
// Find all "type" fields
paths := findPaths(jsonStr, "type")
@@ -364,6 +364,9 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
// Check if properties exists and is empty or missing
propsPath := joinPath(parentPath, "properties")
propsVal := gjson.Get(jsonStr, propsPath)
reqPath := joinPath(parentPath, "required")
reqVal := gjson.Get(jsonStr, reqPath)
hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
needsPlaceholder := false
if !propsVal.Exists() {
@@ -381,8 +384,22 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
// Add to required array
reqPath := joinPath(parentPath, "required")
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
continue
}
// If schema has properties but none are required, add a minimal placeholder.
if propsVal.IsObject() && !hasRequiredProperties {
// DO NOT add placeholder if it's a top-level schema (parentPath is empty)
// or if we've already added a placeholder reason above.
if parentPath == "" {
continue
}
placeholderPath := joinPath(propsPath, "_")
if !gjson.Get(jsonStr, placeholderPath).Exists() {
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
}
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
}
}

View File

@@ -127,8 +127,10 @@ func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing
"type": "object",
"description": "Accepts: null | object",
"properties": {
"_": { "type": "boolean" },
"kind": { "type": "string" }
}
},
"required": ["_"]
}
}
}`
@@ -614,71 +616,6 @@ func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
}
}
// TestCleanJSONSchemaForGemini_PropertyNamesRemoval feeds CleanJSONSchemaForGemini
// an object schema that uses propertyNames and additionalProperties and expects
// only the bare {"type": "object"} property schema to remain.
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
	// propertyNames constrains the names of an object's properties (for example
	// by pattern). Per the original test's rationale, Gemini rejects requests
	// that contain this keyword, so it has to be stripped.
	const schema = `{
"type": "object",
"properties": {
"metadata": {
"type": "object",
"propertyNames": {
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
},
"additionalProperties": {
"type": "string"
}
}
}
}`
	const want = `{
"type": "object",
"properties": {
"metadata": {
"type": "object"
}
}
}`

	cleaned := CleanJSONSchemaForGemini(schema)
	compareJSON(t, want, cleaned)

	// Independently of the structural comparison, the keyword must not appear anywhere.
	if strings.Contains(cleaned, "propertyNames") {
		t.Errorf("propertyNames keyword should be removed, got: %s", cleaned)
	}
}
// TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested verifies that a
// propertyNames keyword buried inside array items and nested object properties
// is absent from the output of CleanJSONSchemaForGemini.
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
	// Deeply nested propertyNames, of the kind the original test attributes to
	// real Claude tool schemas.
	const schema = `{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"config": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
}
}
}
}`

	cleaned := CleanJSONSchemaForGemini(schema)
	if strings.Contains(cleaned, "propertyNames") {
		t.Errorf("Nested propertyNames should be removed, got: %s", cleaned)
	}
}
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
var expMap, actMap map[string]interface{}
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)

View File

@@ -71,10 +71,13 @@ func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool)
incl = &defaultInclude
}
if incl != nil {
valuePath := "generationConfig.thinkingConfig.include_thoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
if !gjson.GetBytes(updated, "generationConfig.thinkingConfig.includeThoughts").Exists() &&
!gjson.GetBytes(updated, "generationConfig.thinkingConfig.include_thoughts").Exists() {
valuePath := "generationConfig.thinkingConfig.include_thoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
}
}
}
return updated
@@ -99,10 +102,13 @@ func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *boo
incl = &defaultInclude
}
if incl != nil {
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
if !gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.includeThoughts").Exists() &&
!gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts").Exists() {
valuePath := "request.generationConfig.thinkingConfig.include_thoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
}
}
}
return updated
@@ -130,15 +136,15 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool)
incl = &defaultInclude
}
if incl != nil {
valuePath := "generationConfig.thinkingConfig.includeThoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
if !gjson.GetBytes(updated, "generationConfig.thinkingConfig.includeThoughts").Exists() &&
!gjson.GetBytes(updated, "generationConfig.thinkingConfig.include_thoughts").Exists() {
valuePath := "generationConfig.thinkingConfig.includeThoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
}
}
}
if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts")
}
if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget")
}
@@ -167,15 +173,15 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo
incl = &defaultInclude
}
if incl != nil {
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
if !gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.includeThoughts").Exists() &&
!gjson.GetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts").Exists() {
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
if err == nil {
updated = rewritten
}
}
}
if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts")
}
if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget")
}
@@ -251,9 +257,14 @@ func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) {
// modelsWithDefaultThinking lists models that should have thinking enabled by default
// when no explicit thinkingConfig is provided.
// Note: Gemini 3 models are NOT included here because per Google's official documentation:
// - thinkingLevel defaults to "high" (dynamic thinking)
// - includeThoughts defaults to false
//
// We should not override these API defaults; let users explicitly configure if needed.
var modelsWithDefaultThinking = map[string]bool{
"gemini-3-pro-preview": true,
"gemini-3-pro-image-preview": true,
// "gemini-3-pro-preview": true,
// "gemini-3-pro-image-preview": true,
// "gemini-3-flash-preview": true,
}
@@ -288,37 +299,73 @@ func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
// ApplyGemini3ThinkingLevelFromMetadata applies thinkingLevel from metadata for Gemini 3 models.
// For standard Gemini API format (generationConfig.thinkingConfig path).
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal))
// or numeric budget suffix (e.g., model(1000)) which gets converted to a thinkingLevel.
func ApplyGemini3ThinkingLevelFromMetadata(model string, metadata map[string]any, body []byte) []byte {
if !IsGemini3Model(model) {
// Use the alias from metadata if available for model type detection
lookupModel := ResolveOriginalModel(model, metadata)
if !IsGemini3Model(lookupModel) && !IsGemini3Model(model) {
return body
}
// Determine which model to use for validation
checkModel := model
if IsGemini3Model(lookupModel) {
checkModel = lookupModel
}
// First try to get effort string from metadata
effort, ok := ReasoningEffortFromMetadata(metadata)
if !ok || effort == "" {
return body
if ok && effort != "" {
if level, valid := ValidateGemini3ThinkingLevel(checkModel, effort); valid {
return ApplyGeminiThinkingLevel(body, level, nil)
}
}
// Validate and apply the thinkingLevel
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
return ApplyGeminiThinkingLevel(body, level, nil)
// Fallback: check for numeric budget and convert to thinkingLevel
budget, _, _, matched := ThinkingFromMetadata(metadata)
if matched && budget != nil {
if level, valid := ThinkingBudgetToGemini3Level(checkModel, *budget); valid {
return ApplyGeminiThinkingLevel(body, level, nil)
}
}
return body
}
// ApplyGemini3ThinkingLevelFromMetadataCLI applies thinkingLevel from metadata for Gemini 3 models.
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal))
// or numeric budget suffix (e.g., model(1000)) which gets converted to a thinkingLevel.
func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]any, body []byte) []byte {
if !IsGemini3Model(model) {
// Use the alias from metadata if available for model type detection
lookupModel := ResolveOriginalModel(model, metadata)
if !IsGemini3Model(lookupModel) && !IsGemini3Model(model) {
return body
}
// Determine which model to use for validation
checkModel := model
if IsGemini3Model(lookupModel) {
checkModel = lookupModel
}
// First try to get effort string from metadata
effort, ok := ReasoningEffortFromMetadata(metadata)
if !ok || effort == "" {
return body
if ok && effort != "" {
if level, valid := ValidateGemini3ThinkingLevel(checkModel, effort); valid {
return ApplyGeminiCLIThinkingLevel(body, level, nil)
}
}
// Validate and apply the thinkingLevel
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
return ApplyGeminiCLIThinkingLevel(body, level, nil)
// Fallback: check for numeric budget and convert to thinkingLevel
budget, _, _, matched := ThinkingFromMetadata(metadata)
if matched && budget != nil {
if level, valid := ThinkingBudgetToGemini3Level(checkModel, *budget); valid {
return ApplyGeminiCLIThinkingLevel(body, level, nil)
}
}
return body
}
@@ -326,15 +373,17 @@ func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
if !ModelHasDefaultThinking(model) {
func ApplyDefaultThinkingIfNeededCLI(model string, metadata map[string]any, body []byte) []byte {
// Use the alias from metadata if available for model property lookup
lookupModel := ResolveOriginalModel(model, metadata)
if !ModelHasDefaultThinking(lookupModel) && !ModelHasDefaultThinking(model) {
return body
}
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
return body
}
// Gemini 3 models use thinkingLevel instead of thinkingBudget
if IsGemini3Model(model) {
if IsGemini3Model(lookupModel) || IsGemini3Model(model) {
// Don't set a default - let the API use its dynamic default ("high")
// Only set includeThoughts
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts", true)

View File

@@ -0,0 +1,56 @@
package util
import (
"testing"
)
// TestSanitizeFunctionName runs SanitizeFunctionName over a table of inputs and,
// besides the exact expected output, checks two invariants on every result:
// at most 64 bytes long, and starting with a letter or underscore when non-empty.
func TestSanitizeFunctionName(t *testing.T) {
	cases := []struct {
		label string
		raw   string
		want  string
	}{
		{"Normal", "valid_name", "valid_name"},
		{"With Dots", "name.with.dots", "name.with.dots"},
		{"With Colons", "name:with:colons", "name:with:colons"},
		{"With Dashes", "name-with-dashes", "name-with-dashes"},
		{"Mixed Allowed", "name.with_dots:colons-dashes", "name.with_dots:colons-dashes"},
		{"Invalid Characters", "name!with@invalid#chars", "name_with_invalid_chars"},
		{"Spaces", "name with spaces", "name_with_spaces"},
		{"Non-ASCII", "name_with_你好_chars", "name_with____chars"},
		{"Starts with digit", "123name", "_123name"},
		{"Starts with dot", ".name", "_.name"},
		{"Starts with colon", ":name", "_:name"},
		{"Starts with dash", "-name", "_-name"},
		{"Starts with invalid char", "!name", "_name"},
		{"Exactly 64 chars", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
		{"Too long (65 chars)", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charactX", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
		{"Very long", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_limit_for_function_names", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_l"},
		{"Starts with digit (64 chars total)", "1234567890123456789012345678901234567890123456789012345678901234", "_123456789012345678901234567890123456789012345678901234567890123"},
		{"Starts with invalid char (64 chars total)", "!234567890123456789012345678901234567890123456789012345678901234", "_234567890123456789012345678901234567890123456789012345678901234"},
		{"Empty", "", ""},
		{"Single character invalid", "@", "_"},
		{"Single character valid", "a", "a"},
		{"Single character digit", "1", "_1"},
		{"Single character underscore", "_", "_"},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			actual := SanitizeFunctionName(tc.raw)
			if actual != tc.want {
				t.Errorf("SanitizeFunctionName(%q) = %v, want %v", tc.raw, actual, tc.want)
			}
			// Verify Gemini compliance: length cap first, then the leading character.
			if len(actual) > 64 {
				t.Errorf("SanitizeFunctionName(%q) result too long: %d", tc.raw, len(actual))
			}
			if actual == "" {
				return
			}
			switch lead := actual[0]; {
			case lead >= 'a' && lead <= 'z', lead >= 'A' && lead <= 'Z', lead == '_':
				// acceptable leading character
			default:
				t.Errorf("SanitizeFunctionName(%q) result starts with invalid char: %c", tc.raw, lead)
			}
		})
	}
}

View File

@@ -12,9 +12,18 @@ func ModelSupportsThinking(model string) bool {
if model == "" {
return false
}
// First check the global dynamic registry
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil {
return info.Thinking != nil
}
// Fallback: check static model definitions
if info := registry.LookupStaticModelInfo(model); info != nil {
return info.Thinking != nil
}
// Fallback: check Antigravity static config
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil {
return cfg.Thinking != nil
}
return false
}
@@ -63,11 +72,19 @@ func thinkingRangeFromRegistry(model string) (found bool, min int, max int, zero
if model == "" {
return false, 0, 0, false, false
}
info := registry.GetGlobalRegistry().GetModelInfo(model)
if info == nil || info.Thinking == nil {
return false, 0, 0, false, false
// First check global dynamic registry
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil && info.Thinking != nil {
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
}
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
// Fallback: check static model definitions
if info := registry.LookupStaticModelInfo(model); info != nil && info.Thinking != nil {
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
}
// Fallback: check Antigravity static config
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil && cfg.Thinking != nil {
return true, cfg.Thinking.Min, cfg.Thinking.Max, cfg.Thinking.ZeroAllowed, cfg.Thinking.DynamicAllowed
}
return false, 0, 0, false, false
}
// GetModelThinkingLevels returns the discrete reasoning effort levels for the model.

View File

@@ -7,10 +7,11 @@ import (
)
const (
ThinkingBudgetMetadataKey = "thinking_budget"
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
ReasoningEffortMetadataKey = "reasoning_effort"
ThinkingOriginalModelMetadataKey = "thinking_original_model"
ThinkingBudgetMetadataKey = "thinking_budget"
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
ReasoningEffortMetadataKey = "reasoning_effort"
ThinkingOriginalModelMetadataKey = "thinking_original_model"
ModelMappingOriginalModelMetadataKey = "model_mapping_original_model"
)
// NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns
@@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string {
}
if metadata != nil {
if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
if base := normalize(s); base != "" {
return base
}
}
}
if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
if base := normalize(s); base != "" {

View File

@@ -4,16 +4,56 @@
package util
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// functionNameSanitizer matches any single character that is not a letter,
// digit, underscore, dot, colon, or dash.
var functionNameSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)

// SanitizeFunctionName ensures a function name matches the requirements for Gemini/Vertex AI.
//
// The rules applied, in order:
//  1. An empty name is returned unchanged.
//  2. Every character matching [^a-zA-Z0-9_.:-] is replaced with "_".
//  3. If the result does not start with a letter or underscore (i.e. it starts
//     with a digit, dot, colon, or dash), an underscore is prepended.
//  4. The result is truncated to at most 64 bytes.
func SanitizeFunctionName(name string) string {
	const maxLen = 64

	if name == "" {
		return ""
	}

	// Each offending code point becomes exactly one "_", so the result is
	// non-empty and contains only ASCII; indexing and slicing by byte is safe.
	sanitized := functionNameSanitizer.ReplaceAllString(name, "_")

	first := sanitized[0]
	startsValid := (first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_'
	if !startsValid {
		// Make room for the prefix first so the prepended name still fits the limit.
		if len(sanitized) >= maxLen {
			sanitized = sanitized[:maxLen-1]
		}
		sanitized = "_" + sanitized
	}

	if len(sanitized) > maxLen {
		sanitized = sanitized[:maxLen]
	}
	return sanitized
}
// SetLogLevel configures the logrus log level based on the configuration.
// It sets the log level to DebugLevel if debug mode is enabled, otherwise to InfoLevel.
func SetLogLevel(cfg *config.Config) {
@@ -53,36 +93,23 @@ func ResolveAuthDir(authDir string) (string, error) {
return filepath.Clean(authDir), nil
}
// CountAuthFiles returns the number of JSON auth files located under the provided directory.
// The function resolves leading tildes to the user's home directory and performs a case-insensitive
// match on the ".json" suffix so that files saved with uppercase extensions are also counted.
func CountAuthFiles(authDir string) int {
dir, err := ResolveAuthDir(authDir)
// CountAuthFiles returns the number of auth records available through the provided Store.
// For filesystem-backed stores, this reflects the number of JSON auth files under the configured directory.
func CountAuthFiles[T any](ctx context.Context, store interface {
List(context.Context) ([]T, error)
}) int {
if store == nil {
return 0
}
if ctx == nil {
ctx = context.Background()
}
entries, err := store.List(ctx)
if err != nil {
log.Debugf("countAuthFiles: failed to resolve auth directory: %v", err)
log.Debugf("countAuthFiles: failed to list auth records: %v", err)
return 0
}
if dir == "" {
return 0
}
count := 0
walkErr := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
log.Debugf("countAuthFiles: error accessing %s: %v", path, err)
return nil
}
if d.IsDir() {
return nil
}
if strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
count++
}
return nil
})
if walkErr != nil {
log.Debugf("countAuthFiles: walk error: %v", walkErr)
}
return count
return len(entries)
}
// WritablePath returns the cleaned WRITABLE_PATH environment variable when it is set.

View File

@@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/hex"
"os"
"reflect"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool {
}
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelMappings, newConfig.OAuthModelMappings))
log.Infof("config successfully reloaded, triggering client reload")
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)

View File

@@ -90,6 +90,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
}
oldModels := SummarizeGeminiModels(o.Models)
newModels := SummarizeGeminiModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("gemini[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
@@ -120,6 +125,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
}
oldModels := SummarizeClaudeModels(o.Models)
newModels := SummarizeClaudeModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("claude[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
@@ -150,6 +160,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
}
oldModels := SummarizeCodexModels(o.Models)
newModels := SummarizeCodexModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("codex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
@@ -185,10 +200,18 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
}
oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys)
newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys)
if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) {
changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount))
}
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
changes = append(changes, entries...)
}
if entries, _ := DiffOAuthModelMappingChanges(oldCfg.OAuthModelMappings, newCfg.OAuthModelMappings); len(entries) > 0 {
changes = append(changes, entries...)
}
// Remote management (never print the key)
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
@@ -301,3 +324,43 @@ func formatProxyURL(raw string) string {
}
return scheme + "://" + host
}
// equalStringSet reports whether a and b contain the same set of strings after
// trimming surrounding whitespace from every entry. Order and duplicates are
// ignored; two empty (or nil) slices are equal.
func equalStringSet(a, b []string) bool {
	if len(a)+len(b) == 0 {
		return true
	}
	// toSet collapses a slice into a set keyed by the trimmed entry.
	toSet := func(items []string) map[string]struct{} {
		set := make(map[string]struct{}, len(items))
		for _, item := range items {
			set[strings.TrimSpace(item)] = struct{}{}
		}
		return set
	}
	left, right := toSet(a), toSet(b)
	if len(left) != len(right) {
		return false
	}
	// Equal sizes, so it is enough to check that left is a subset of right.
	for member := range left {
		if _, found := right[member]; !found {
			return false
		}
	}
	return true
}
// equalUpstreamAPIKeys reports whether two slices of AmpUpstreamAPIKeyEntry are equal.
// The slices must have the same length, and entries are compared position by
// position: the whitespace-trimmed upstream keys must match and the client
// API keys must form the same set (see equalStringSet).
func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, left := range a {
		right := b[idx]
		sameUpstream := strings.TrimSpace(left.UpstreamAPIKey) == strings.TrimSpace(right.UpstreamAPIKey)
		if !sameUpstream || !equalStringSet(left.APIKeys, right.APIKeys) {
			return false
		}
	}
	return true
}

View File

@@ -56,6 +56,36 @@ func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
return hashJoined(keys)
}
// ComputeCodexModelsHash returns a stable hash for Codex model aliases.
// Each model contributes a lowercase "name|alias" key built from its trimmed
// fields; models whose name and alias are both blank contribute nothing. The
// keys go through normalizeModelPairs and are then hashed with hashJoined.
func ComputeCodexModelsHash(models []config.CodexModel) string {
	collect := func(emit func(key string)) {
		for idx := range models {
			name := strings.TrimSpace(models[idx].Name)
			alias := strings.TrimSpace(models[idx].Alias)
			if name != "" || alias != "" {
				emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
			}
		}
	}
	return hashJoined(normalizeModelPairs(collect))
}
// ComputeGeminiModelsHash returns a stable hash for Gemini model aliases.
// Each model contributes a lowercase "name|alias" key built from its trimmed
// fields; models whose name and alias are both blank contribute nothing. The
// keys go through normalizeModelPairs and are then hashed with hashJoined.
func ComputeGeminiModelsHash(models []config.GeminiModel) string {
	collect := func(emit func(key string)) {
		for idx := range models {
			name := strings.TrimSpace(models[idx].Name)
			alias := strings.TrimSpace(models[idx].Alias)
			if name != "" || alias != "" {
				emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
			}
		}
	}
	return hashJoined(normalizeModelPairs(collect))
}
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
func ComputeExcludedModelsHash(excluded []string) string {
if len(excluded) == 0 {

View File

@@ -81,6 +81,15 @@ func TestComputeClaudeModelsHash_Empty(t *testing.T) {
}
}
// TestComputeCodexModelsHash_Empty checks that both a nil slice and an empty
// slice of Codex models hash to the empty string.
func TestComputeCodexModelsHash_Empty(t *testing.T) {
	fromNil := ComputeCodexModelsHash(nil)
	if fromNil != "" {
		t.Fatalf("expected empty hash for nil models, got %q", fromNil)
	}
	fromEmpty := ComputeCodexModelsHash(make([]config.CodexModel, 0))
	if fromEmpty != "" {
		t.Fatalf("expected empty hash for empty slice, got %q", fromEmpty)
	}
}
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
a := []config.ClaudeModel{
{Name: "m1", Alias: "a1"},
@@ -95,6 +104,20 @@ func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
}
}
// TestComputeCodexModelsHash_IgnoresBlankAndDedup checks that blank entries and
// case-only duplicates do not change the resulting hash.
func TestComputeCodexModelsHash_IgnoresBlankAndDedup(t *testing.T) {
	noisy := []config.CodexModel{
		{Name: "m1", Alias: "a1"},
		{Name: " "},
		{Name: "M1", Alias: "A1"},
	}
	clean := []config.CodexModel{{Name: "m1", Alias: "a1"}}

	noisyHash := ComputeCodexModelsHash(noisy)
	cleanHash := ComputeCodexModelsHash(clean)
	if noisyHash != "" && noisyHash == cleanHash {
		return
	}
	t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", noisyHash, cleanHash)
}
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
@@ -157,3 +180,15 @@ func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
t.Fatalf("expected different hash when models change, got %s", h3)
}
}
// TestComputeCodexModelsHash_Deterministic checks that hashing the same models
// twice yields the same non-empty value, and that a different model list
// yields a different value.
func TestComputeCodexModelsHash_Deterministic(t *testing.T) {
	input := []config.CodexModel{{Name: "a", Alias: "A"}, {Name: "b"}}
	first := ComputeCodexModelsHash(input)
	second := ComputeCodexModelsHash(input)
	if first == "" || first != second {
		t.Fatalf("expected deterministic hash, got %s / %s", first, second)
	}
	changed := ComputeCodexModelsHash([]config.CodexModel{{Name: "a"}})
	if changed == first {
		t.Fatalf("expected different hash when models change, got %s", changed)
	}
}

View File

@@ -0,0 +1,121 @@
package diff
import (
"crypto/sha256"
"encoding/hex"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Per-provider model summaries. Each one pairs a hash of the provider's
// configured models with the number of entries that went into that hash; the
// zero value stands for "no models".
type (
	// GeminiModelsSummary is the result of SummarizeGeminiModels.
	GeminiModelsSummary struct {
		hash  string
		count int
	}

	// ClaudeModelsSummary is the result of SummarizeClaudeModels.
	ClaudeModelsSummary struct {
		hash  string
		count int
	}

	// CodexModelsSummary is the result of SummarizeCodexModels.
	CodexModelsSummary struct {
		hash  string
		count int
	}

	// VertexModelsSummary is the result of SummarizeVertexModels.
	VertexModelsSummary struct {
		hash  string
		count int
	}
)
// SummarizeGeminiModels hashes Gemini model aliases for change detection.
// It returns the zero summary for an empty list. Otherwise each model with a
// non-blank name or alias contributes a lowercase "name|alias" key; the keys
// returned by normalizeModelPairs are hashed and counted.
func SummarizeGeminiModels(models []config.GeminiModel) GeminiModelsSummary {
	var summary GeminiModelsSummary
	if len(models) == 0 {
		return summary
	}
	pairs := normalizeModelPairs(func(emit func(key string)) {
		for idx := range models {
			name := strings.TrimSpace(models[idx].Name)
			alias := strings.TrimSpace(models[idx].Alias)
			if name != "" || alias != "" {
				emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
			}
		}
	})
	summary.hash = hashJoined(pairs)
	summary.count = len(pairs)
	return summary
}
// SummarizeClaudeModels hashes Claude model aliases for change detection.
// It returns the zero summary for an empty list. Otherwise each model with a
// non-blank name or alias contributes a lowercase "name|alias" key; the keys
// returned by normalizeModelPairs are hashed and counted.
func SummarizeClaudeModels(models []config.ClaudeModel) ClaudeModelsSummary {
	var summary ClaudeModelsSummary
	if len(models) == 0 {
		return summary
	}
	pairs := normalizeModelPairs(func(emit func(key string)) {
		for idx := range models {
			name := strings.TrimSpace(models[idx].Name)
			alias := strings.TrimSpace(models[idx].Alias)
			if name != "" || alias != "" {
				emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
			}
		}
	})
	summary.hash = hashJoined(pairs)
	summary.count = len(pairs)
	return summary
}
// SummarizeCodexModels hashes Codex model aliases for change detection.
// It returns the zero summary for an empty list. Otherwise each model with a
// non-blank name or alias contributes a lowercase "name|alias" key; the keys
// returned by normalizeModelPairs are hashed and counted.
func SummarizeCodexModels(models []config.CodexModel) CodexModelsSummary {
	var summary CodexModelsSummary
	if len(models) == 0 {
		return summary
	}
	pairs := normalizeModelPairs(func(emit func(key string)) {
		for idx := range models {
			name := strings.TrimSpace(models[idx].Name)
			alias := strings.TrimSpace(models[idx].Alias)
			if name != "" || alias != "" {
				emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
			}
		}
	})
	summary.hash = hashJoined(pairs)
	summary.count = len(pairs)
	return summary
}
// SummarizeVertexModels hashes Vertex-compatible model aliases for change detection.
// Each model is represented by its trimmed alias, or by its trimmed name when
// the alias is blank; models with neither are skipped. The labels are sorted,
// joined with "|" and hashed with SHA-256. The zero summary is returned when
// no label remains.
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
	labels := make([]string, 0, len(models))
	for _, entry := range models {
		label := strings.TrimSpace(entry.Alias)
		if label == "" {
			label = strings.TrimSpace(entry.Name)
		}
		if label == "" {
			continue
		}
		labels = append(labels, label)
	}
	if len(labels) == 0 {
		return VertexModelsSummary{}
	}
	// Sorting makes the hash independent of the order models are listed in.
	sort.Strings(labels)
	digest := sha256.Sum256([]byte(strings.Join(labels, "|")))
	return VertexModelsSummary{
		hash:  hex.EncodeToString(digest[:]),
		count: len(labels),
	}
}

View File

@@ -116,36 +116,3 @@ func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappin
count: len(entries),
}
}
type VertexModelsSummary struct {
hash string
count int
}
// SummarizeVertexModels hashes vertex-compatible models for change detection.
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
if len(models) == 0 {
return VertexModelsSummary{}
}
names := make([]string, 0, len(models))
for _, m := range models {
name := strings.TrimSpace(m.Name)
alias := strings.TrimSpace(m.Alias)
if name == "" && alias == "" {
continue
}
if alias != "" {
name = alias
}
names = append(names, name)
}
if len(names) == 0 {
return VertexModelsSummary{}
}
sort.Strings(names)
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
return VertexModelsSummary{
hash: hex.EncodeToString(sum[:]),
count: len(names),
}
}

View File

@@ -0,0 +1,101 @@
package diff
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// OAuthModelMappingsSummary is the change-detection summary of one channel's
// OAuth model mappings: hash is the hex SHA-256 of the normalized mapping
// list and count is the number of distinct normalized mappings. The zero
// value means the channel has no usable mappings.
type OAuthModelMappingsSummary struct {
hash string
count int
}
// SummarizeOAuthModelMappings summarizes OAuth model mappings per channel.
// Channel names are trimmed and lowercased; blank channel names are dropped.
// It returns nil when no channel remains.
func SummarizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string]OAuthModelMappingsSummary {
	summaries := make(map[string]OAuthModelMappingsSummary, len(entries))
	for channel, mappings := range entries {
		if normalized := strings.ToLower(strings.TrimSpace(channel)); normalized != "" {
			summaries[normalized] = summarizeOAuthModelMappingList(mappings)
		}
	}
	if len(summaries) == 0 {
		return nil
	}
	return summaries
}
// DiffOAuthModelMappingChanges compares OAuth model mappings maps.
// It returns two sorted slices: human-readable change descriptions (one per
// channel that was removed, added, or whose mapping hash changed) and the
// names of the affected channels. Both slices are empty, not nil, when
// nothing changed.
func DiffOAuthModelMappingChanges(oldMap, newMap map[string][]config.ModelNameMapping) ([]string, []string) {
	before := SummarizeOAuthModelMappings(oldMap)
	after := SummarizeOAuthModelMappings(newMap)

	// Union of the channel names seen on either side.
	channels := make(map[string]struct{}, len(before)+len(after))
	for _, side := range []map[string]OAuthModelMappingsSummary{before, after} {
		for channel := range side {
			channels[channel] = struct{}{}
		}
	}

	changes := make([]string, 0, len(channels))
	affected := make([]string, 0, len(channels))
	for channel := range channels {
		prev, hadPrev := before[channel]
		next, hasNext := after[channel]

		var message string
		if hadPrev && !hasNext {
			message = fmt.Sprintf("oauth-model-mappings[%s]: removed", channel)
		} else if !hadPrev && hasNext {
			message = fmt.Sprintf("oauth-model-mappings[%s]: added (%d entries)", channel, next.count)
		} else if hadPrev && hasNext && prev.hash != next.hash {
			message = fmt.Sprintf("oauth-model-mappings[%s]: updated (%d -> %d entries)", channel, prev.count, next.count)
		} else {
			continue
		}
		changes = append(changes, message)
		affected = append(affected, channel)
	}

	// Map iteration order is random; sort for deterministic output.
	sort.Strings(changes)
	sort.Strings(affected)
	return changes, affected
}
// summarizeOAuthModelMappingList builds the summary for one channel's mappings.
// Each mapping with a non-blank name and alias becomes a lowercase
// "name->alias" signature, suffixed with "|fork" when Fork is set. Distinct
// signatures are sorted, joined with "|" and hashed with SHA-256. The zero
// summary is returned when no signature remains.
func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMappingsSummary {
	unique := make(map[string]struct{}, len(list))
	for _, entry := range list {
		from := strings.ToLower(strings.TrimSpace(entry.Name))
		to := strings.ToLower(strings.TrimSpace(entry.Alias))
		if from == "" || to == "" {
			continue
		}
		signature := from + "->" + to
		if entry.Fork {
			signature += "|fork"
		}
		unique[signature] = struct{}{}
	}
	if len(unique) == 0 {
		return OAuthModelMappingsSummary{}
	}

	signatures := make([]string, 0, len(unique))
	for signature := range unique {
		signatures = append(signatures, signature)
	}
	// Sorting makes the hash independent of the order mappings are listed in.
	sort.Strings(signatures)
	digest := sha256.Sum256([]byte(strings.Join(signatures, "|")))
	return OAuthModelMappingsSummary{
		hash:  hex.EncodeToString(digest[:]),
		count: len(signatures),
	}
}

View File

@@ -62,6 +62,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
if base != "" {
attrs["base_url"] = base
}
if hash := diff.ComputeGeminiModelsHash(entry.Models); hash != "" {
attrs["models_hash"] = hash
}
addConfigHeadersToAttrs(entry.Headers, attrs)
a := &coreauth.Auth{
ID: id,
@@ -147,6 +150,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL
}
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
attrs["models_hash"] = hash
}
addConfigHeadersToAttrs(ck.Headers, attrs)
proxyURL := strings.TrimSpace(ck.ProxyURL)
a := &coreauth.Auth{

View File

@@ -618,7 +618,22 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
}
body := BuildErrorResponseBody(status, errText)
c.Set("API_RESPONSE", bytes.Clone(body))
// Append first to preserve upstream response logs, then drop duplicate payloads if already recorded.
var previous []byte
if existing, exists := c.Get("API_RESPONSE"); exists {
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
previous = bytes.Clone(existingBytes)
}
}
appendAPIResponse(c, body)
trimmedErrText := strings.TrimSpace(errText)
trimmedBody := bytes.TrimSpace(body)
if len(previous) > 0 {
if (trimmedErrText != "" && bytes.Contains(previous, []byte(trimmedErrText))) ||
(len(trimmedBody) > 0 && bytes.Contains(previous, trimmedBody)) {
c.Set("API_RESPONSE", previous)
}
}
if !c.Writer.Written() {
c.Writer.Header().Set("Content-Type", "application/json")

72
sdk/api/management.go Normal file
View File

@@ -0,0 +1,72 @@
// Package api exposes helpers for embedding CLIProxyAPI.
//
// It wraps internal management handler types so external projects can integrate
// management endpoints without importing internal packages.
package api
import (
"github.com/gin-gonic/gin"
internalmanagement "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// ManagementTokenRequester exposes a limited subset of management endpoints for requesting tokens.
//
// Every method has the gin handler shape: it takes the request's *gin.Context
// and returns nothing.
type ManagementTokenRequester interface {
	// Token request endpoints, one per provider.
	RequestAnthropicToken(c *gin.Context)
	RequestGeminiCLIToken(c *gin.Context)
	RequestCodexToken(c *gin.Context)
	RequestAntigravityToken(c *gin.Context)
	RequestQwenToken(c *gin.Context)
	RequestIFlowToken(c *gin.Context)
	RequestIFlowCookieToken(c *gin.Context)

	// Auth status and OAuth callback endpoints.
	GetAuthStatus(c *gin.Context)
	PostOAuthCallback(c *gin.Context)
}
// managementTokenRequester is the unexported ManagementTokenRequester
// implementation returned by NewManagementTokenRequester.
type managementTokenRequester struct {
// handler is the wrapped internal management handler that every method forwards to.
handler *internalmanagement.Handler
}
// NewManagementTokenRequester creates a limited management handler exposing only token request endpoints.
//
// The returned value wraps an internal management handler built with
// NewHandlerWithoutConfigFilePath from cfg and manager.
func NewManagementTokenRequester(cfg *config.Config, manager *coreauth.Manager) ManagementTokenRequester {
	inner := internalmanagement.NewHandlerWithoutConfigFilePath(cfg, manager)
	return &managementTokenRequester{handler: inner}
}
// RequestAnthropicToken forwards c to the wrapped handler's RequestAnthropicToken.
func (m *managementTokenRequester) RequestAnthropicToken(c *gin.Context) {
m.handler.RequestAnthropicToken(c)
}
// RequestGeminiCLIToken forwards c to the wrapped handler's RequestGeminiCLIToken.
func (m *managementTokenRequester) RequestGeminiCLIToken(c *gin.Context) {
m.handler.RequestGeminiCLIToken(c)
}
// RequestCodexToken forwards c to the wrapped handler's RequestCodexToken.
func (m *managementTokenRequester) RequestCodexToken(c *gin.Context) {
m.handler.RequestCodexToken(c)
}
// RequestAntigravityToken forwards c to the wrapped handler's RequestAntigravityToken.
func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
m.handler.RequestAntigravityToken(c)
}
// RequestQwenToken forwards c to the wrapped handler's RequestQwenToken.
func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
m.handler.RequestQwenToken(c)
}
// RequestIFlowToken forwards c to the wrapped handler's RequestIFlowToken.
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
m.handler.RequestIFlowToken(c)
}
// RequestIFlowCookieToken forwards c to the wrapped handler's RequestIFlowCookieToken.
func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
m.handler.RequestIFlowCookieToken(c)
}
// GetAuthStatus forwards c to the wrapped handler's GetAuthStatus.
func (m *managementTokenRequester) GetAuthStatus(c *gin.Context) {
m.handler.GetAuthStatus(c)
}
// PostOAuthCallback forwards c to the wrapped handler's PostOAuthCallback.
func (m *managementTokenRequester) PostOAuthCallback(c *gin.Context) {
m.handler.PostOAuthCallback(c)
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"net/http"
"path/filepath"
"strconv"
"strings"
"sync"
@@ -111,6 +112,9 @@ type Manager struct {
requestRetry atomic.Int32
maxRetryInterval atomic.Int64
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
modelNameMappings atomic.Value
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider
@@ -385,22 +389,8 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
return cliproxyexecutor.Response{}, errPick
}
accountType, accountInfo := auth.AccountInfo()
proxyInfo := auth.ProxyInfo()
entry := logEntryWithRequestID(ctx)
if accountType == "api_key" {
if proxyInfo != "" {
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
} else {
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
}
} else if accountType == "oauth" {
if proxyInfo != "" {
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
} else {
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
}
}
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -410,6 +400,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -446,22 +437,8 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
return cliproxyexecutor.Response{}, errPick
}
accountType, accountInfo := auth.AccountInfo()
proxyInfo := auth.ProxyInfo()
entry := logEntryWithRequestID(ctx)
if accountType == "api_key" {
if proxyInfo != "" {
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
} else {
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
}
} else if accountType == "oauth" {
if proxyInfo != "" {
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
} else {
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
}
}
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -471,6 +448,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -507,22 +485,8 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
return nil, errPick
}
accountType, accountInfo := auth.AccountInfo()
proxyInfo := auth.ProxyInfo()
entry := logEntryWithRequestID(ctx)
if accountType == "api_key" {
if proxyInfo != "" {
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
} else {
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
}
} else if accountType == "oauth" {
if proxyInfo != "" {
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
} else {
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
}
}
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -532,6 +496,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
@@ -592,6 +557,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
keys := []string{
util.ThinkingOriginalModelMetadataKey,
util.GeminiOriginalModelMetadataKey,
util.ModelMappingOriginalModelMetadataKey,
}
var out map[string]any
for _, key := range keys {
@@ -1529,6 +1495,9 @@ func (m *Manager) markRefreshPending(id string, now time.Time) bool {
}
func (m *Manager) refreshAuth(ctx context.Context, id string) {
if ctx == nil {
ctx = context.Background()
}
m.mu.RLock()
auth := m.auths[id]
var exec ProviderExecutor
@@ -1541,6 +1510,10 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
}
cloned := auth.Clone()
updated, err := exec.Refresh(ctx, cloned)
if err != nil && errors.Is(err, context.Canceled) {
log.Debugf("refresh canceled for %s, %s", auth.Provider, auth.ID)
return
}
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
now := time.Now()
if err != nil {
@@ -1610,6 +1583,59 @@ func logEntryWithRequestID(ctx context.Context) *log.Entry {
return log.NewEntry(log.StandardLogger())
}
// debugLogAuthSelection emits a debug line describing which credential was
// picked for model.
//
// Nothing is logged when debug logging is disabled, when entry or auth is nil,
// or when the auth's account type is neither "api_key" nor "oauth". API keys
// are passed through util.HideAPIKey before logging; OAuth credentials are
// described by formatOauthIdentity. A non-empty auth.ProxyInfo() is appended
// after the model, separated by a space.
func debugLogAuthSelection(entry *log.Entry, auth *Auth, provider string, model string) {
	if entry == nil || auth == nil || !log.IsLevelEnabled(log.DebugLevel) {
		return
	}

	kind, info := auth.AccountInfo()

	var proxySuffix string
	if proxy := auth.ProxyInfo(); proxy != "" {
		proxySuffix = " " + proxy
	}

	if kind == "api_key" {
		entry.Debugf("Use API key %s for model %s%s", util.HideAPIKey(info), model, proxySuffix)
		return
	}
	if kind == "oauth" {
		entry.Debugf("Use OAuth %s for model %s%s", formatOauthIdentity(auth, provider, info), model, proxySuffix)
	}
}
// formatOauthIdentity builds the identity string logged for an OAuth credential.
//
// The result is made of up to two space-separated parts:
//   - "provider=<name>", using auth.Provider and falling back to the provider argument;
//   - "auth_file=<base>", using the base name of auth.FileName and falling back to auth.ID.
//
// When neither part is available accountInfo is returned as-is. A nil auth
// yields the empty string.
func formatOauthIdentity(auth *Auth, provider string, accountInfo string) string {
	if auth == nil {
		return ""
	}

	// Prefer the provider recorded on the auth itself.
	name := strings.TrimSpace(auth.Provider)
	if name == "" {
		name = strings.TrimSpace(provider)
	}

	// FileName may be unset for some auth backends; fall back to ID.
	file := strings.TrimSpace(auth.FileName)
	if file == "" {
		file = strings.TrimSpace(auth.ID)
	}
	if file != "" {
		// Only the base name is logged to avoid leaking host paths.
		file = filepath.Base(file)
	}

	switch {
	case name != "" && file != "":
		return "provider=" + name + " auth_file=" + file
	case name != "":
		return "provider=" + name
	case file != "":
		return "auth_file=" + file
	default:
		return accountInfo
	}
}
// InjectCredentials delegates per-provider HTTP request preparation when supported.
// If the registered executor for the auth provider implements RequestPreparer,
// it will be invoked to modify the request (e.g., add headers).

View File

@@ -0,0 +1,171 @@
package auth
import (
"strings"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
// modelNameMappingTable is the compiled lookup form of the configured model
// name mappings, as produced by compileModelNameMappingTable.
type modelNameMappingTable struct {
// reverse maps channel -> alias (lower) -> original upstream model name.
// It is left nil when no usable mapping exists.
reverse map[string]map[string]string
}
// compileModelNameMappingTable turns the configured per-channel mappings into a
// lookup table keyed by lower-cased channel and lower-cased alias.
//
// Entries are skipped when the name or alias is blank, or when they are equal
// ignoring case. If an alias occurs more than once within a channel, the first
// occurrence wins. The result is never nil; its reverse map is nil when no
// usable entry was found.
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
	table := &modelNameMappingTable{}
	if len(mappings) == 0 {
		return table
	}

	byChannel := make(map[string]map[string]string, len(mappings))
	for rawChannel, entries := range mappings {
		channel := strings.ToLower(strings.TrimSpace(rawChannel))
		if channel == "" || len(entries) == 0 {
			continue
		}

		aliases := make(map[string]string, len(entries))
		for i := range entries {
			upstream := strings.TrimSpace(entries[i].Name)
			alias := strings.TrimSpace(entries[i].Alias)
			if upstream == "" || alias == "" || strings.EqualFold(upstream, alias) {
				continue
			}
			key := strings.ToLower(alias)
			if _, taken := aliases[key]; !taken {
				aliases[key] = upstream
			}
		}

		if len(aliases) != 0 {
			byChannel[channel] = aliases
		}
	}

	if len(byChannel) != 0 {
		table.reverse = byChannel
	}
	return table
}
// SetOAuthModelMappings updates the OAuth model name mapping table used during execution.
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
// client-visible model name unchanged for translation/response formatting.
//
// The mappings are compiled once here and stored in m.modelNameMappings.
// Calling it on a nil Manager is a no-op.
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
	if m == nil {
		return
	}
	compiled := compileModelNameMappingTable(mappings)
	if compiled == nil {
		// atomic.Value requires non-nil store values.
		compiled = new(modelNameMappingTable)
	}
	m.modelNameMappings.Store(compiled)
}
// applyOAuthModelMapping resolves the upstream model from OAuth model mappings
// and returns the resolved model along with updated metadata.
//
// When no mapping applies, requestedModel and metadata are returned untouched.
// Otherwise the upstream model is returned together with a copy of metadata
// (the caller's map is not modified) that additionally records requestedModel
// under util.ModelMappingOriginalModelMetadataKey.
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
	upstream := m.resolveOAuthUpstreamModel(auth, requestedModel)
	if upstream == "" {
		return requestedModel, metadata
	}

	enriched := make(map[string]any, len(metadata)+1)
	for key, value := range metadata {
		enriched[key] = value
	}
	// Store the requested alias (e.g., "gp") so downstream can use it to look up
	// model metadata from the global registry where it was registered under this alias.
	enriched[util.ModelMappingOriginalModelMetadataKey] = requestedModel
	return upstream, enriched
}
// resolveOAuthUpstreamModel looks up the upstream model name configured for
// requestedModel on the auth's mapping channel.
//
// It returns "" when there is nothing to rewrite: a nil manager or auth, no
// mapping channel for the auth, a blank model, no stored table, no entry for
// the alias, or an entry equal to requestedModel ignoring case. The alias is
// matched after trimming and lower-casing.
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
	if m == nil || auth == nil {
		return ""
	}

	channel := modelMappingChannel(auth)
	aliasKey := strings.ToLower(strings.TrimSpace(requestedModel))
	if channel == "" || aliasKey == "" {
		return ""
	}

	table, ok := m.modelNameMappings.Load().(*modelNameMappingTable)
	if !ok || table == nil {
		return ""
	}

	// Indexing a nil map yields the zero value, so a missing table.reverse or a
	// missing channel both end up as "" here.
	upstream := strings.TrimSpace(table.reverse[channel][aliasKey])
	if upstream == "" || strings.EqualFold(upstream, requestedModel) {
		return ""
	}
	return upstream
}
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
//
// The auth kind is read from the "auth_kind" attribute; when that is blank and
// AccountInfo reports an "api_key" account, the kind "apikey" is used instead.
// The normalized provider and kind are then handed to OAuthModelMappingChannel.
// A nil auth yields "".
func modelMappingChannel(auth *Auth) string {
	if auth == nil {
		return ""
	}

	// Reading from a nil Attributes map is safe and simply yields "".
	kind := strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
	if kind == "" {
		if accountType, _ := auth.AccountInfo(); strings.EqualFold(accountType, "api_key") {
			kind = "apikey"
		}
	}

	return OAuthModelMappingChannel(strings.ToLower(strings.TrimSpace(auth.Provider)), kind)
}
// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
// OAuth model mappings (e.g., API key authentication).
//
// Both arguments are trimmed and lower-cased before matching.
//
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
func OAuthModelMappingChannel(provider, authKind string) string {
	channel := strings.ToLower(strings.TrimSpace(provider))
	kind := strings.ToLower(strings.TrimSpace(authKind))

	switch channel {
	case "vertex", "claude", "codex":
		// These providers have a channel unless the auth kind is "apikey".
		if kind == "apikey" {
			return ""
		}
		return channel
	case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
		return channel
	}

	// Everything else has no channel. That includes "gemini": the gemini provider
	// uses gemini-api-key config, not oauth-model-mappings, and OAuth-based gemini
	// auth is converted to "gemini-cli" by the synthesizer.
	return ""
}

View File

@@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) {
}
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings)
service := &Service{
cfg: b.cfg,

View File

@@ -5,6 +5,9 @@ import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
// ModelInfo re-exports the registry model info structure.
type ModelInfo = registry.ModelInfo
// ModelRegistryHook re-exports the registry hook interface for external integrations.
type ModelRegistryHook = registry.ModelRegistryHook
// ModelRegistry describes registry operations consumed by external callers.
type ModelRegistry interface {
RegisterClient(clientID, clientProvider string, models []*ModelInfo)
@@ -13,9 +16,15 @@ type ModelRegistry interface {
ClearModelQuotaExceeded(clientID, modelID string)
ClientSupportsModel(clientID, modelID string) bool
GetAvailableModels(handlerType string) []map[string]any
GetAvailableModelsByProvider(provider string) []*ModelInfo
}
// GlobalModelRegistry returns the shared registry instance.
// The value comes from registry.GetGlobalRegistry and is exposed through the
// ModelRegistry interface.
func GlobalModelRegistry() ModelRegistry {
return registry.GetGlobalRegistry()
}
// SetGlobalModelRegistryHook registers an optional hook on the shared global registry instance.
// It passes hook to SetHook on the registry returned by registry.GetGlobalRegistry.
func SetGlobalModelRegistryHook(hook ModelRegistryHook) {
registry.GetGlobalRegistry().SetHook(hook)
}

View File

@@ -552,6 +552,9 @@ func (s *Service) Run(ctx context.Context) error {
s.cfgMu.Lock()
s.cfg = newCfg
s.cfgMu.Unlock()
if s.coreManager != nil {
s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings)
}
s.rebindExecutors()
}
@@ -677,6 +680,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
return
}
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
if authKind == "" {
if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
if a.Attributes != nil {
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
GlobalModelRegistry().UnregisterClient(a.ID)
@@ -702,6 +710,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "gemini":
models = registry.GetGeminiModels()
if entry := s.resolveConfigGeminiKey(a); entry != nil {
if len(entry.Models) > 0 {
models = buildGeminiConfigModels(entry)
}
if authKind == "apikey" {
excluded = entry.ExcludedModels
}
@@ -741,6 +752,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "codex":
models = registry.GetOpenAIModels()
if entry := s.resolveConfigCodexKey(a); entry != nil {
if len(entry.Models) > 0 {
models = buildCodexConfigModels(entry)
}
if authKind == "apikey" {
excluded = entry.ExcludedModels
}
@@ -833,6 +847,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
}
}
models = applyOAuthModelMappings(s.cfg, provider, authKind, models)
if len(models) > 0 {
key := provider
if key == "" {
@@ -1104,17 +1119,22 @@ func matchWildcard(pattern, value string) bool {
return true
}
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 {
type modelEntry interface {
GetName() string
GetAlias() string
}
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
if len(models) == 0 {
return nil
}
now := time.Now().Unix()
out := make([]*ModelInfo, 0, len(entry.Models))
seen := make(map[string]struct{}, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
out := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for i := range models {
model := models[i]
name := strings.TrimSpace(model.GetName())
alias := strings.TrimSpace(model.GetAlias())
if alias == "" {
alias = name
}
@@ -1130,52 +1150,176 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
if display == "" {
display = alias
}
out = append(out, &ModelInfo{
info := &ModelInfo{
ID: alias,
Object: "model",
Created: now,
OwnedBy: "vertex",
Type: "vertex",
OwnedBy: ownedBy,
Type: modelType,
DisplayName: display,
})
}
if name != "" {
if upstream := registry.LookupStaticModelInfo(name); upstream != nil && upstream.Thinking != nil {
info.Thinking = upstream.Thinking
}
}
out = append(out, info)
}
return out
}
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 {
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
if entry == nil {
return nil
}
now := time.Now().Unix()
out := make([]*ModelInfo, 0, len(entry.Models))
seen := make(map[string]struct{}, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if alias == "" {
alias = name
}
if alias == "" {
return buildConfigModels(entry.Models, "google", "vertex")
}
// buildGeminiConfigModels builds the model list for a Gemini key entry by
// passing entry.Models to buildConfigModels with ownedBy "google" and model
// type "gemini". It returns nil when entry is nil.
func buildGeminiConfigModels(entry *config.GeminiKey) []*ModelInfo {
if entry == nil {
return nil
}
return buildConfigModels(entry.Models, "google", "gemini")
}
// buildClaudeConfigModels builds the model list for a Claude key entry by
// passing entry.Models to buildConfigModels with ownedBy "anthropic" and model
// type "claude". It returns nil when entry is nil.
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
if entry == nil {
return nil
}
return buildConfigModels(entry.Models, "anthropic", "claude")
}
// buildCodexConfigModels builds the model list for a Codex key entry by
// passing entry.Models to buildConfigModels with ownedBy "openai" and model
// type "openai". It returns nil when entry is nil.
func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo {
if entry == nil {
return nil
}
return buildConfigModels(entry.Models, "openai", "openai")
}
// rewriteModelInfoName rewrites a path-style model name after its ID changes.
//
// If the trimmed name ends in "/"+oldID (for example "models/gpt-5"), the
// trailing oldID is replaced by newID and the leading path is kept, so
// ("models/gpt-5", "gpt-5", "g5") yields "models/g5". oldID and newID are
// trimmed before use.
//
// name is returned unchanged (including its original whitespace) when it is
// blank, when either ID is blank, when the IDs are equal ignoring case, or when
// the name does not end in "/"+oldID.
func rewriteModelInfoName(name, oldID, newID string) string {
	trimmed := strings.TrimSpace(name)
	oldID = strings.TrimSpace(oldID)
	newID = strings.TrimSpace(newID)
	if trimmed == "" || oldID == "" || newID == "" || strings.EqualFold(oldID, newID) {
		return name
	}

	// The suffix check also covers the exact "models/"+oldID form, so no
	// separate comparison against it is needed.
	if strings.HasSuffix(trimmed, "/"+oldID) {
		return strings.TrimSuffix(trimmed, oldID) + newID
	}
	return name
}
func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
if cfg == nil || len(models) == 0 {
return models
}
channel := coreauth.OAuthModelMappingChannel(provider, authKind)
if channel == "" || len(cfg.OAuthModelMappings) == 0 {
return models
}
mappings := cfg.OAuthModelMappings[channel]
if len(mappings) == 0 {
return models
}
type mappingEntry struct {
alias string
fork bool
}
forward := make(map[string]mappingEntry, len(mappings))
for i := range mappings {
name := strings.TrimSpace(mappings[i].Name)
alias := strings.TrimSpace(mappings[i].Alias)
if name == "" || alias == "" {
continue
}
key := strings.ToLower(alias)
if _, exists := seen[key]; exists {
if strings.EqualFold(name, alias) {
continue
}
seen[key] = struct{}{}
display := name
if display == "" {
display = alias
key := strings.ToLower(name)
if _, exists := forward[key]; exists {
continue
}
out = append(out, &ModelInfo{
ID: alias,
Object: "model",
Created: now,
OwnedBy: "claude",
Type: "claude",
DisplayName: display,
})
forward[key] = mappingEntry{alias: alias, fork: mappings[i].Fork}
}
if len(forward) == 0 {
return models
}
out := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil {
continue
}
id := strings.TrimSpace(model.ID)
if id == "" {
continue
}
key := strings.ToLower(id)
entry, ok := forward[key]
if !ok {
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
out = append(out, model)
continue
}
mappedID := strings.TrimSpace(entry.alias)
if mappedID == "" {
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
out = append(out, model)
continue
}
if entry.fork {
if _, exists := seen[key]; !exists {
seen[key] = struct{}{}
out = append(out, model)
}
aliasKey := strings.ToLower(mappedID)
if _, exists := seen[aliasKey]; exists {
continue
}
seen[aliasKey] = struct{}{}
clone := *model
clone.ID = mappedID
if clone.Name != "" {
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
}
out = append(out, &clone)
continue
}
uniqueKey := strings.ToLower(mappedID)
if _, exists := seen[uniqueKey]; exists {
continue
}
seen[uniqueKey] = struct{}{}
if mappedID == id {
out = append(out, model)
continue
}
clone := *model
clone.ID = mappedID
if clone.Name != "" {
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
}
out = append(out, &clone)
}
return out
}

View File

@@ -0,0 +1,58 @@
package cliproxy
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestApplyOAuthModelMappings_Rename checks that a non-fork mapping replaces
// the upstream model with its alias: the single output model carries the alias
// as its ID and has the alias substituted into its Name.
func TestApplyOAuthModelMappings_Rename(t *testing.T) {
	rename := config.ModelNameMapping{Name: "gpt-5", Alias: "g5"}
	cfg := &config.Config{
		OAuthModelMappings: map[string][]config.ModelNameMapping{"codex": {rename}},
	}
	input := []*ModelInfo{{ID: "gpt-5", Name: "models/gpt-5"}}

	got := applyOAuthModelMappings(cfg, "codex", "oauth", input)
	if n := len(got); n != 1 {
		t.Fatalf("expected 1 model, got %d", n)
	}

	renamed := got[0]
	if renamed.ID != "g5" {
		t.Fatalf("expected model id %q, got %q", "g5", renamed.ID)
	}
	if renamed.Name != "models/g5" {
		t.Fatalf("expected model name %q, got %q", "models/g5", renamed.Name)
	}
}
// TestApplyOAuthModelMappings_ForkAddsAlias checks that a mapping with Fork set
// keeps the upstream model and appends an aliased copy after it, with the alias
// substituted into the copy's Name.
func TestApplyOAuthModelMappings_ForkAddsAlias(t *testing.T) {
	forked := config.ModelNameMapping{Name: "gpt-5", Alias: "g5", Fork: true}
	cfg := &config.Config{
		OAuthModelMappings: map[string][]config.ModelNameMapping{"codex": {forked}},
	}
	input := []*ModelInfo{{ID: "gpt-5", Name: "models/gpt-5"}}

	got := applyOAuthModelMappings(cfg, "codex", "oauth", input)
	if n := len(got); n != 2 {
		t.Fatalf("expected 2 models, got %d", n)
	}

	original, alias := got[0], got[1]
	if original.ID != "gpt-5" {
		t.Fatalf("expected first model id %q, got %q", "gpt-5", original.ID)
	}
	if alias.ID != "g5" {
		t.Fatalf("expected second model id %q, got %q", "g5", alias.ID)
	}
	if alias.Name != "models/g5" {
		t.Fatalf("expected forked model name %q, got %q", "models/g5", alias.Name)
	}
}

View File

@@ -16,6 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig
type TLSConfig = internalconfig.TLSConfig
type RemoteManagement = internalconfig.RemoteManagement
type AmpCode = internalconfig.AmpCode
type ModelNameMapping = internalconfig.ModelNameMapping
type PayloadConfig = internalconfig.PayloadConfig
type PayloadRule = internalconfig.PayloadRule
type PayloadModelRule = internalconfig.PayloadModelRule

View File

@@ -56,6 +56,10 @@ func setupAmpRouter(h *management.Handler) *gin.Engine {
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
mgmt.GET("/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
@@ -188,6 +192,90 @@ func TestPutAmpUpstreamAPIKey(t *testing.T) {
}
}
// TestPutAmpUpstreamAPIKeys_PersistsAndReturns sends a PUT to
// /v0/management/ampcode/upstream-api-keys and checks that the entry is written
// to the config file on disk and is included in the GET /ampcode response.
func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) {
h, configPath := newAmpTestHandler(t)
r := setupAmpRouter(h)
// The payload uses padded values and an empty api-key; the assertions below
// expect the stored entry to hold "u1" and exactly ["k1", "k2"].
body := `{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
// Verify it was persisted to disk
loaded, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("failed to load config from disk: %v", err)
}
if len(loaded.AmpCode.UpstreamAPIKeys) != 1 {
t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(loaded.AmpCode.UpstreamAPIKeys))
}
entry := loaded.AmpCode.UpstreamAPIKeys[0]
if entry.UpstreamAPIKey != "u1" {
t.Fatalf("expected upstream-api-key u1, got %q", entry.UpstreamAPIKey)
}
if len(entry.APIKeys) != 2 || entry.APIKeys[0] != "k1" || entry.APIKeys[1] != "k2" {
t.Fatalf("expected api-keys [k1 k2], got %#v", entry.APIKeys)
}
// Verify it is returned by GET /ampcode
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
// The response is decoded as an object whose "ampcode" key holds the AmpCode config.
var resp map[string]config.AmpCode
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if got := resp["ampcode"].UpstreamAPIKeys; len(got) != 1 || got[0].UpstreamAPIKey != "u1" {
t.Fatalf("expected upstream-api-keys to be present after update, got %#v", got)
}
}
// TestDeleteAmpUpstreamAPIKeys_ClearsAll seeds one upstream-api-keys entry via
// PUT, issues a DELETE with an empty value list, and verifies that a subsequent
// GET on the same endpoint reports no entries.
func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)

	// Seed with one entry
	putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
	}

	// Delete with an empty value list.
	deleteBody := `{"value":[]}`
	req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody))
	req.Header.Set("Content-Type", "application/json")
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Read the list back and make sure nothing is left.
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
	}
	var resp map[string][]config.AmpUpstreamAPIKeyEntry
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	// len of a nil slice is 0, so one length check covers a missing key, a null
	// value, and an empty list alike.
	if got := resp["upstream-api-keys"]; len(got) != 0 {
		t.Fatalf("expected cleared list, got %#v", got)
	}
}
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t)

View File

@@ -0,0 +1,54 @@
package test
import (
"testing"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
// TestOpenAIToCodex_PreservesBuiltinTools translates an OpenAI-format request
// to the Codex format and checks that the built-in web_search tool, its
// search_context_size option, and the web_search tool_choice are all present
// in the translated output.
func TestOpenAIToCodex_PreservesBuiltinTools(t *testing.T) {
in := []byte(`{
"model":"gpt-5",
"messages":[{"role":"user","content":"hi"}],
"tools":[{"type":"web_search","search_context_size":"high"}],
"tool_choice":{"type":"web_search"}
}`)
out := sdktranslator.TranslateRequest(sdktranslator.FormatOpenAI, sdktranslator.FormatCodex, "gpt-5", in, false)
// Exactly one tool must be present in the output, with its fields intact.
if got := gjson.GetBytes(out, "tools.#").Int(); got != 1 {
t.Fatalf("expected 1 tool, got %d: %s", got, string(out))
}
if got := gjson.GetBytes(out, "tools.0.type").String(); got != "web_search" {
t.Fatalf("expected tools[0].type=web_search, got %q: %s", got, string(out))
}
if got := gjson.GetBytes(out, "tools.0.search_context_size").String(); got != "high" {
t.Fatalf("expected tools[0].search_context_size=high, got %q: %s", got, string(out))
}
if got := gjson.GetBytes(out, "tool_choice.type").String(); got != "web_search" {
t.Fatalf("expected tool_choice.type=web_search, got %q: %s", got, string(out))
}
}
// TestOpenAIResponsesToOpenAI_PreservesBuiltinTools translates an OpenAI
// Responses-format request to the OpenAI format and checks that the built-in
// web_search tool and its search_context_size option are present in the
// translated output.
func TestOpenAIResponsesToOpenAI_PreservesBuiltinTools(t *testing.T) {
in := []byte(`{
"model":"gpt-5",
"input":[{"role":"user","content":[{"type":"input_text","text":"hi"}]}],
"tools":[{"type":"web_search","search_context_size":"low"}]
}`)
out := sdktranslator.TranslateRequest(sdktranslator.FormatOpenAIResponse, sdktranslator.FormatOpenAI, "gpt-5", in, false)
// Exactly one tool must be present in the output, with its fields intact.
if got := gjson.GetBytes(out, "tools.#").Int(); got != 1 {
t.Fatalf("expected 1 tool, got %d: %s", got, string(out))
}
if got := gjson.GetBytes(out, "tools.0.type").String(); got != "web_search" {
t.Fatalf("expected tools[0].type=web_search, got %q: %s", got, string(out))
}
if got := gjson.GetBytes(out, "tools.0.search_context_size").String(); got != "low" {
t.Fatalf("expected tools[0].search_context_size=low, got %q: %s", got, string(out))
}
}

View File

@@ -0,0 +1,211 @@
package test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
)
// TestModelAliasThinkingSuffix runs the 32 cases listed in docs/thinking_suffix_test_cases.md.
// Every case takes a model name carrying a thinking suffix, spelled either with the original
// model name or with an alias, and checks that the suffix ends up in the parsed metadata and,
// for Gemini 2.5 / Gemini 3 upstream models, in the CLI request payload.
func TestModelAliasThinkingSuffix(t *testing.T) {
	// Minimal payload that the Gemini application steps start from.
	const baseRequest = `{"request":{"contents":[]}}`

	// suffixCase is one row of the table. Note that provider is never read by the
	// test body below; it only labels the row.
	type suffixCase struct {
		id            int
		name          string
		provider      string
		requestModel  string
		suffixType    string
		expectedField string // "thinkingBudget", "thinkingLevel", "budget_tokens", "reasoning_effort", "enable_thinking"
		expectedValue any
		upstreamModel string // The upstream model after alias resolution
		isAlias       bool
	}

	cases := []suffixCase{
		// === 1. Antigravity Provider ===
		// 1.1 Budget-only models (Gemini 2.5)
		{1, "antigravity_original_numeric", "antigravity", "gemini-2.5-computer-use-preview-10-2025(1000)", "numeric", "thinkingBudget", 1000, "gemini-2.5-computer-use-preview-10-2025", false},
		{2, "antigravity_alias_numeric", "antigravity", "gp(1000)", "numeric", "thinkingBudget", 1000, "gemini-2.5-computer-use-preview-10-2025", true},
		// 1.2 Budget+Levels models (Gemini 3)
		{3, "antigravity_original_numeric_to_level", "antigravity", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
		{4, "antigravity_original_level", "antigravity", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
		{5, "antigravity_alias_numeric_to_level", "antigravity", "gf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
		{6, "antigravity_alias_level", "antigravity", "gf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
		// === 2. Gemini CLI Provider ===
		// 2.1 Budget-only models
		{7, "gemini_cli_original_numeric", "gemini-cli", "gemini-2.5-pro(8192)", "numeric", "thinkingBudget", 8192, "gemini-2.5-pro", false},
		{8, "gemini_cli_alias_numeric", "gemini-cli", "g25p(8192)", "numeric", "thinkingBudget", 8192, "gemini-2.5-pro", true},
		// 2.2 Budget+Levels models
		{9, "gemini_cli_original_numeric_to_level", "gemini-cli", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
		{10, "gemini_cli_original_level", "gemini-cli", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
		{11, "gemini_cli_alias_numeric_to_level", "gemini-cli", "gf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
		{12, "gemini_cli_alias_level", "gemini-cli", "gf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
		// === 3. Vertex Provider ===
		// 3.1 Budget-only models
		{13, "vertex_original_numeric", "vertex", "gemini-2.5-pro(16384)", "numeric", "thinkingBudget", 16384, "gemini-2.5-pro", false},
		{14, "vertex_alias_numeric", "vertex", "vg25p(16384)", "numeric", "thinkingBudget", 16384, "gemini-2.5-pro", true},
		// 3.2 Budget+Levels models
		{15, "vertex_original_numeric_to_level", "vertex", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
		{16, "vertex_original_level", "vertex", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
		{17, "vertex_alias_numeric_to_level", "vertex", "vgf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
		{18, "vertex_alias_level", "vertex", "vgf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
		// === 4. AI Studio Provider ===
		// 4.1 Budget-only models
		{19, "aistudio_original_numeric", "aistudio", "gemini-2.5-pro(12000)", "numeric", "thinkingBudget", 12000, "gemini-2.5-pro", false},
		{20, "aistudio_alias_numeric", "aistudio", "ag25p(12000)", "numeric", "thinkingBudget", 12000, "gemini-2.5-pro", true},
		// 4.2 Budget+Levels models
		{21, "aistudio_original_numeric_to_level", "aistudio", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
		{22, "aistudio_original_level", "aistudio", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
		{23, "aistudio_alias_numeric_to_level", "aistudio", "agf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
		{24, "aistudio_alias_level", "aistudio", "agf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
		// === 5. Claude Provider ===
		{25, "claude_original_numeric", "claude", "claude-sonnet-4-5-20250929(16384)", "numeric", "budget_tokens", 16384, "claude-sonnet-4-5-20250929", false},
		{26, "claude_alias_numeric", "claude", "cs45(16384)", "numeric", "budget_tokens", 16384, "claude-sonnet-4-5-20250929", true},
		// === 6. Codex Provider ===
		{27, "codex_original_level", "codex", "gpt-5(high)", "level", "reasoning_effort", "high", "gpt-5", false},
		{28, "codex_alias_level", "codex", "g5(high)", "level", "reasoning_effort", "high", "gpt-5", true},
		// === 7. Qwen Provider ===
		{29, "qwen_original_level", "qwen", "qwen3-coder-plus(high)", "level", "enable_thinking", true, "qwen3-coder-plus", false},
		{30, "qwen_alias_level", "qwen", "qcp(high)", "level", "enable_thinking", true, "qwen3-coder-plus", true},
		// === 8. iFlow Provider ===
		{31, "iflow_original_level", "iflow", "glm-4.7(high)", "level", "reasoning_effort", "high", "glm-4.7", false},
		{32, "iflow_alias_level", "iflow", "glm(high)", "level", "reasoning_effort", "high", "glm-4.7", true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Step 1: split the requested model into its bare name and the metadata
			// derived from the suffix (stands in for the SDK-layer normalization).
			bareModel, meta := util.NormalizeThinkingModel(tc.requestModel)

			// A numeric or level suffix must have produced metadata.
			suffixExpected := tc.suffixType == "numeric" || tc.suffixType == "level"
			if suffixExpected && meta == nil {
				t.Fatalf("Case #%d: NormalizeThinkingModel(%q) metadata is nil", tc.id, tc.requestModel)
			}

			// Step 2: for alias rows, record the requested (alias) name under the
			// model-mapping key, as the comments in the original test describe for
			// the OAuth model mapping step.
			if tc.isAlias {
				if meta == nil {
					meta = make(map[string]any)
				}
				meta[util.ModelMappingOriginalModelMetadataKey] = bareModel
			}

			// payloadMetadata builds a fresh map for the payload helpers: the alias
			// entry first (alias rows only), then every entry of the parsed metadata.
			payloadMetadata := func() map[string]any {
				dup := make(map[string]any)
				if tc.isAlias {
					dup[util.ModelMappingOriginalModelMetadataKey] = bareModel
				}
				for key, val := range meta {
					dup[key] = val
				}
				return dup
			}

			// Step 3: check what can be read back out of the metadata.
			switch tc.suffixType {
			case "numeric":
				budget, _, _, matched := util.ThinkingFromMetadata(meta)
				if !matched {
					t.Fatalf("Case #%d: ThinkingFromMetadata did not match", tc.id)
				}
				if budget == nil {
					t.Fatalf("Case #%d: expected budget in metadata", tc.id)
				}
				switch tc.expectedField {
				case "thinkingBudget", "budget_tokens":
					// The parsed budget is compared as-is.
					if want := tc.expectedValue.(int); *budget != want {
						t.Errorf("Case #%d: budget = %d, want %d", tc.id, *budget, want)
					}
				case "thinkingLevel":
					// Gemini 3 rows expect the numeric budget to convert to a level.
					level, ok := util.ThinkingBudgetToGemini3Level(tc.upstreamModel, *budget)
					if !ok {
						t.Fatalf("Case #%d: ThinkingBudgetToGemini3Level failed", tc.id)
					}
					if want := tc.expectedValue.(string); level != want {
						t.Errorf("Case #%d: converted level = %q, want %q", tc.id, level, want)
					}
				}
			case "level":
				_, _, effort, matched := util.ThinkingFromMetadata(meta)
				if !matched {
					t.Fatalf("Case #%d: ThinkingFromMetadata did not match", tc.id)
				}
				if effort == nil {
					t.Fatalf("Case #%d: expected effort in metadata", tc.id)
				}
				// Only these two fields carry a string to compare; for other fields
				// (e.g. "enable_thinking") the presence of an effort is all that is checked.
				switch tc.expectedField {
				case "thinkingLevel", "reasoning_effort":
					if want := tc.expectedValue.(string); *effort != want {
						t.Errorf("Case #%d: effort = %q, want %q", tc.id, *effort, want)
					}
				}
			}

			// Step 4: Gemini 3 upstream models - the level must land in the payload.
			if tc.expectedField == "thinkingLevel" && util.IsGemini3Model(tc.upstreamModel) {
				payload := util.ApplyGemini3ThinkingLevelFromMetadataCLI(tc.upstreamModel, payloadMetadata(), []byte(baseRequest))
				gotLevel := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingLevel")
				wantLevel := tc.expectedValue.(string)
				switch {
				case !gotLevel.Exists():
					t.Errorf("Case #%d: expected thinkingLevel in result", tc.id)
				case gotLevel.String() != wantLevel:
					t.Errorf("Case #%d: thinkingLevel = %q, want %q", tc.id, gotLevel.String(), wantLevel)
				}
			}

			// Step 5: Gemini 2.5 upstream models - the budget must land in the payload
			// when applied through the exported executor.ApplyThinkingMetadataCLI.
			if tc.expectedField == "thinkingBudget" && util.IsGemini25Model(tc.upstreamModel) {
				payload := executor.ApplyThinkingMetadataCLI([]byte(baseRequest), payloadMetadata(), tc.upstreamModel)
				gotBudget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget")
				wantBudget := tc.expectedValue.(int)
				switch {
				case !gotBudget.Exists():
					t.Errorf("Case #%d: expected thinkingBudget in result", tc.id)
				case int(gotBudget.Int()) != wantBudget:
					t.Errorf("Case #%d: thinkingBudget = %d, want %d", tc.id, int(gotBudget.Int()), wantBudget)
				}
			}
		})
	}
}