Compare commits

...

226 Commits
v6.6.40 ... amp

Author SHA1 Message Date
hkfires
5f7067ef0d test 2026-01-12 19:53:26 +08:00
hkfires
21ac161b21 fix(test): implement missing HttpRequest method in stream bootstrap mock 2026-01-12 16:33:43 +08:00
Luis Pater
94e979865e Fixed: #897
refactor(executor): remove `prompt_cache_retention` from request payloads
2026-01-12 10:46:47 +08:00
Luis Pater
6c324f2c8b Fixed: #936
feat(cliproxy): support multiple aliases for OAuth model mappings

- Updated mapping logic to allow multiple aliases per upstream model name.
- Adjusted `SanitizeOAuthModelMappings` to ensure aliases remain unique within channels.
- Added test cases to validate multi-alias scenarios.
- Updated example config to clarify multi-alias support.
2026-01-12 10:40:34 +08:00
Luis Pater
543dfd67e0 refactor(cache): remove max entries logic and extend signature TTL to 3 hours 2026-01-12 00:20:44 +08:00
Luis Pater
28bd1323a2 Merge pull request #971 from router-for-me/codex
feat(codex): add OpenCode instructions based on user agent
2026-01-11 16:01:13 +08:00
hkfires
220ca45f74 fix(codex): only override instructions when upstream provides them 2026-01-11 15:52:21 +08:00
hkfires
70a82d80ac fix(codex): only override instructions in responses for OpenCode UA 2026-01-11 15:19:37 +08:00
hkfires
ac626111ac feat(codex): add OpenCode instructions based on user agent 2026-01-11 13:36:35 +08:00
Luis Pater
8cfe26f10c Merge branch 'sdk' into dev 2026-01-10 16:26:23 +08:00
Luis Pater
80db2dc254 Merge pull request #955 from router-for-me/api
feat(codex): add subscription date fields to ID token claims
2026-01-10 16:26:07 +08:00
Luis Pater
e8e3bc8616 feat(executor): add HttpRequest support across executors for better http request handling 2026-01-10 16:25:25 +08:00
Luis Pater
bc3195c8d8 refactor(logger): remove unnecessary request details limit logic 2026-01-10 14:46:59 +08:00
hkfires
6494330c6b feat(codex): add subscription date fields to ID token claims 2026-01-10 11:15:20 +08:00
Luis Pater
4d7f389b69 Fixed: #941
fix(translator): ensure fallback to valid originalRequestRawJSON in response handling
2026-01-10 01:01:09 +08:00
Luis Pater
95f87d5669 Merge pull request #947 from pykancha/fix-memory-leak
Resolve memory leaks causing OOM in k8s deployment
2026-01-10 00:40:47 +08:00
Luis Pater
c83365a349 Merge pull request #938 from router-for-me/log
refactor(logging): clean up oauth logs and debugs
2026-01-10 00:02:45 +08:00
Luis Pater
6b3604cf2b Merge pull request #943 from ben-vargas/fix-tool-mappings
Fix Claude OAuth tool name mapping (proxy_)
2026-01-09 23:52:29 +08:00
Luis Pater
af6bdca14f Fixed: #942
fix(executor): ignore non-SSE lines in OpenAI-compatible streams
2026-01-09 23:41:50 +08:00
hemanta212
1c773c428f fix: Remove investigation artifacts 2026-01-09 17:47:59 +05:45
Ben Vargas
e785bfcd12 Use unprefixed Claude request for translation
Keep the upstream payload prefixed for OAuth while passing the unprefixed request body into response translators. This avoids proxy_ leaking into OpenAI Responses echoed tool metadata while preserving the Claude OAuth workaround.
2026-01-09 00:54:35 -07:00
hemanta212
47dacce6ea fix(server): resolve memory leaks causing OOM in k8s deployment
- usage/logger_plugin: cap modelStats.Details at 1000 entries per model
- cache/signature_cache: add background cleanup for expired sessions (10 min)
- management/handler: add background cleanup for stale IP rate-limit entries (1 hr)
- executor/cache_helpers: add mutex protection and TTL cleanup for codexCacheMap (15 min)
- executor/codex_executor: use thread-safe cache accessors

Add reproduction tests demonstrating leak behavior before/after fixes.

Amp-Thread-ID: https://ampcode.com/threads/T-019ba0fc-1d7b-7338-8e1d-ca0520412777
Co-authored-by: Amp <amp@ampcode.com>
2026-01-09 13:33:46 +05:45
Ben Vargas
dcac3407ab Fix Claude OAuth tool name mapping
Prefix tool names with proxy_ for Claude OAuth requests and strip the prefix from streaming and non-streaming responses to restore client-facing names.

Updates the Claude executor to:
- add prefixing for tools, tool_choice, and tool_use messages when using OAuth tokens
- strip the prefix from tool_use events in SSE and non-streaming payloads
- add focused unit tests for prefix/strip helpers
2026-01-09 00:10:38 -07:00
hkfires
7004295e1d build(docker): move stats export execution after image build 2026-01-09 11:24:00 +08:00
hkfires
ee62ef4745 refactor(logging): clean up oauth logs and debugs 2026-01-09 11:20:55 +08:00
Luis Pater
ef6bafbf7e fix(executor): handle context cancellation and deadline errors explicitly 2026-01-09 10:48:29 +08:00
Luis Pater
ed28b71e87 refactor(amp): remove duplicate comments in response rewriter 2026-01-09 08:21:13 +08:00
Luis Pater
d47b7dc79a refactor(response): enhance parameter handling for Codex to Claude conversion 2026-01-09 05:20:19 +08:00
Luis Pater
49b9709ce5 Merge pull request #787 from sususu98/fix/antigravity-429-retry-delay-parsing
fix(antigravity): parse retry-after delay from 429 response body
2026-01-09 04:45:25 +08:00
Luis Pater
a2eba2cdf5 Merge pull request #763 from mvelbaum/feature/improve-oauth-use-logging
feat(logging): disambiguate OAuth credential selection in debug logs
2026-01-09 04:43:21 +08:00
Luis Pater
3d01b3cfe8 Merge pull request #553 from XInTheDark/fix/builtin-tools-web-search
fix(translator): preserve built-in tools (web_search) to Responses API
2026-01-09 04:40:13 +08:00
Luis Pater
af2efa6f7e Merge pull request #605 from soilSpoon/feature/amp-compat
feature: Improves Amp client compatibility
2026-01-09 04:28:17 +08:00
Luis Pater
d73b61d367 Merge pull request #901 from uzhao/vscode-plugin
Vscode plugin
2026-01-08 22:22:27 +08:00
Luis Pater
59a448b645 feat(executor): centralize systemInstruction handling for Claude and Gemini-3-Pro models 2026-01-08 21:05:33 +08:00
Chén Mù
4adb9eed77 Merge pull request #921 from router-for-me/atgy
fix(executor): update gemini model identifier to gemini-3-pro-preview
2026-01-08 19:20:32 +08:00
hkfires
b6a0f7a07f fix(executor): update gemini model identifier to gemini-3-pro-preview
Update the model name check in `buildRequest` to target "gemini-3-pro-preview" instead of "gemini-3-pro" when applying specific system instruction handling.
2026-01-08 19:14:52 +08:00
Luis Pater
1b2f907671 feat(executor): update system instruction handling for Claude and Gemini-3-Pro models 2026-01-08 12:42:26 +08:00
Luis Pater
bda04eed8a feat(executor): add model-specific support for "gemini-3-pro" in execution and payload handling 2026-01-08 12:27:03 +08:00
Luis Pater
67985d8226 feat(executor): enhance Antigravity payload with user role and dynamic system instructions 2026-01-08 10:55:25 +08:00
Jianyang Zhao
cbcb061812 Update README_CN.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-07 20:07:01 -05:00
Jianyang Zhao
9fc2e1b3c8 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-07 20:06:55 -05:00
Jianyang Zhao
3b484aea9e Add Claude Proxy VSCode to README_CN.md
Added information about Claude Proxy VSCode extension.
2026-01-07 20:03:07 -05:00
Jianyang Zhao
963a0950fa Add Claude Proxy VSCode extension to README
Added Claude Proxy VSCode extension to the README.
2026-01-07 20:02:50 -05:00
Luis Pater
f4ba1ab910 fix(executor): remove unused tokenRefreshTimeout constant and pass zero timeout to HTTP client 2026-01-07 18:16:49 +08:00
Luis Pater
2662f91082 feat(management): add PostOAuthCallback handler to token requester interface 2026-01-07 10:47:32 +08:00
Luis Pater
c1db2c7d7c Merge pull request #888 from router-for-me/api-call-TOKEN-fix
fix(management): refresh antigravity token for api-call $TOKEN$
2026-01-07 01:19:24 +08:00
LTbinglingfeng
5e5d8142f9 fix(auth): error when antigravity refresh token missing during refresh 2026-01-07 01:09:50 +08:00
LTbinglingfeng
b01619b441 fix(management): refresh antigravity token for api-call $TOKEN$ 2026-01-07 00:14:02 +08:00
Luis Pater
f861bd6a94 docs: add 9Router to community projects in README 2026-01-06 23:15:28 +08:00
Luis Pater
6dbfdd140d Merge pull request #871 from decolua/patch-1
Update README.md
2026-01-06 22:58:53 +08:00
decolua
386ccffed4 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 20:54:33 +07:00
decolua
ffddd1c90a Update README.md 2026-01-05 20:29:26 +07:00
Luis Pater
8f8dfd081b Merge pull request #850 from can1357/main
feat(translator): add developer role support for Gemini translators
2026-01-05 11:27:24 +08:00
Luis Pater
9f1b445c7c docs: add ProxyPilot to community projects in Chinese README 2026-01-05 11:23:48 +08:00
Luis Pater
ae933dfe14 Merge pull request #858 from Finesssee/add-proxypilot
docs: add ProxyPilot to community projects
2026-01-05 11:20:52 +08:00
Luis Pater
e124db723b Merge pull request #862 from router-for-me/gemini
fix(gemini): abort default injection on existing thinking keys
2026-01-05 10:41:07 +08:00
hkfires
05444cf32d fix(gemini): abort default injection on existing thinking keys 2026-01-05 10:24:30 +08:00
Luis Pater
8edbda57cf feat(translator): add thoughtSignature to node parts for Gemini and Antigravity requests
Enhanced node structure by including `thoughtSignature` for inline data parts in Gemini OpenAI, Gemini CLI, and Antigravity request handlers to improve traceability of thought processes.
2026-01-05 09:25:17 +08:00
Finessse
821249a5ed docs: add ProxyPilot to community projects 2026-01-04 18:19:41 +07:00
Luis Pater
ee33863b47 Merge pull request #857 from router-for-me/management-update
Management update
2026-01-04 18:07:13 +08:00
Supra4E8C
cd22c849e2 feat(management): 更新OAuth模型映射的清理逻辑以增强数据安全性 2026-01-04 17:57:34 +08:00
Supra4E8C
f0e73efda2 feat(management): add vertex api key and oauth model mappings endpoints 2026-01-04 17:32:00 +08:00
Supra4E8C
3156109c71 feat(management): 支持管理接口调整日志大小/强制前缀/路由策略 2026-01-04 12:21:49 +08:00
can1357
6762e081f3 feat(translator): add developer role support for Gemini translators
Treat OpenAI's "developer" role the same as "system" role in request
translation for gemini, gemini-cli, and antigravity backends.
2026-01-03 21:01:01 +01:00
Luis Pater
7815ee338d fix(translator): adjust message_delta emission boundary in Claude-to-OpenAI conversion
Fixed incorrect boundary logic for `message_delta` emission, ensuring proper handling of usage updates and `emitMessageStopIfNeeded` within the response loop.
2026-01-04 01:36:51 +08:00
Luis Pater
44b6c872e2 feat(config): add support for Fork in OAuth model mappings with alias handling
Implemented `Fork` flag in `ModelNameMapping` to allow aliases as additional models while preserving the original model ID. Updated the `applyOAuthModelMappings` logic, added tests for `Fork` behavior, and updated documentation and examples accordingly.
2026-01-04 01:18:29 +08:00
Luis Pater
7a77b23f2d feat(executor): add token refresh timeout and improve context handling during refresh
Introduced `tokenRefreshTimeout` constant for token refresh operations and enhanced context propagation for `refreshToken` by embedding roundtrip information if available. Adjusted `refreshAuth` to ensure default context initialization and handle cancellation errors appropriately.
2026-01-04 00:26:08 +08:00
Luis Pater
672e8549c0 docs: reorganize README to adjust CodMate placement
Moved CodMate entry under ProxyPal in both English and Chinese README files for consistency in structure and better readability.
2026-01-03 21:31:53 +08:00
Luis Pater
66f5269a23 Merge pull request #837 from loocor/main
docs: add CodMate to community projects
2026-01-03 21:30:15 +08:00
Luis Pater
ebec293497 feat(api): integrate TokenStore for improved auth entry management
Replaced file-based auth entry counting with `TokenStore`-backed implementation, enhancing flexibility and context-aware token management. Updated related logic to reflect this change.
2026-01-03 04:53:47 +08:00
Luis Pater
e02ceecd35 feat(registry): introduce ModelRegistryHook for monitoring model registrations and unregistrations
Added support for external hooks to observe model registry events using the `ModelRegistryHook` interface. Implemented thread-safe, non-blocking execution of hooks with panic recovery. Comprehensive tests added to verify hook behavior during registration, unregistration, blocking, and panic scenarios.
2026-01-02 23:18:40 +08:00
Luis Pater
c8b33a8cc3 Merge pull request #824 from router-for-me/script
feat(script): add usage statistics preservation across container rebuilds
2026-01-02 20:42:25 +08:00
Loocor
dca8d5ded8 Add CodMate app information to README
Added CodMate section to README with app details.
2026-01-02 17:15:38 +08:00
Loocor
2a7fd1e897 Add CodMate description to README_CN.md
添加 CodMate 应用的描述,提供 CLI AI 会话管理功能。
2026-01-02 17:15:09 +08:00
Luis Pater
b9d1e70ac2 Merge pull request #830 from router-for-me/gemini
fix(util): disable default thinking for gemini-3 series
2026-01-02 10:59:24 +08:00
hkfires
fdf5720217 fix(gemini): remove default thinking for gemini 3 models 2026-01-02 10:55:59 +08:00
hkfires
f40bd0cd51 feat(script): add usage statistics preservation across container rebuilds 2026-01-02 10:01:20 +08:00
hkfires
e33676bb87 fix(util): disable default thinking for gemini-3 series 2026-01-02 09:43:40 +08:00
Luis Pater
2a663d5cba feat(executor): enhance payload translation with original request context
Refactored `applyPayloadConfig` to `applyPayloadConfigWithRoot`, adding support for default rule validation against the original payload when available. Updated all executors to use `applyPayloadConfigWithRoot` and incorporate an optional original request payload for translations.
2026-01-02 00:03:26 +08:00
Luis Pater
750b930679 Merge pull request #823 from router-for-me/translator
feat(translator): enhance Claude-to-OpenAI conversion with thinking block and tool result handling
2026-01-01 20:16:10 +08:00
hkfires
3902fd7501 fix(iflow): remove thinking field from request body in thinking config handler 2026-01-01 19:40:28 +08:00
hkfires
4fc3d5e935 refactor(iflow): simplify thinking config handling for GLM and MiniMax models 2026-01-01 19:31:08 +08:00
hkfires
2d2f4572a7 fix(translator): remove unnecessary whitespace trimming in reasoning text collection 2026-01-01 12:39:09 +08:00
hkfires
8f4c46f38d fix(translator): emit tool_result messages before user content in Claude-to-OpenAI conversion 2026-01-01 11:11:43 +08:00
hkfires
b6ba51bc2a feat(translator): add thinking block and tool result handling for Claude-to-OpenAI conversion 2026-01-01 09:41:25 +08:00
Luis Pater
6a66d32d37 Merge pull request #803 from HsnSaboor/fix-invalid-function-names-sanitization-v2
feat(translator): resolve invalid function name errors by sanitizing Claude tool names
2026-01-01 01:15:50 +08:00
Luis Pater
8d15723195 feat(registry): add GetAvailableModelsByProvider method for retrieving models by provider 2025-12-31 23:37:46 +08:00
Chén Mù
736e0aae86 Merge pull request #814 from router-for-me/aistudio
Fix model alias thinking suffix
2025-12-31 03:08:05 -08:00
hkfires
8bf3305b2b fix(thinking): fallback to upstream model for thinking support when alias not in registry 2025-12-31 18:07:13 +08:00
hkfires
d00e3ea973 feat(thinking): add numeric budget to thinkingLevel conversion fallback 2025-12-31 17:14:47 +08:00
hkfires
89db4e9481 fix(thinking): use model alias for thinking config resolution in mapped models 2025-12-31 17:09:22 +08:00
hkfires
e332419081 feat(registry): add thinking support for gemini-2.5-computer-use-preview model 2025-12-31 17:09:22 +08:00
Luis Pater
e998b1229a feat(updater): add fallback URL and logic for missing management asset 2025-12-31 11:51:20 +08:00
Luis Pater
bbed134bd1 feat(api): add GetAuthStatus method to ManagementTokenRequester interface 2025-12-31 09:40:48 +08:00
Saboor Hassan
47b9503112 chore: revert changes to internal/translator to comply with path guard
This commit reverts all modifications within internal/translator. A separate issue
will be created for the maintenance team to integrate SanitizeFunctionName into
the translators.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 02:19:26 +05:00
Saboor Hassan
3b9253c2be fix(translator): resolve invalid function name errors by sanitizing Claude tool names
This commit centralizes tool name sanitization in SanitizeFunctionName,
applying character compliance, starting character rules, and length limits.
It also fixes a regression in gemini_schema tests and preserves MCP-specific
shortening logic while ensuring compliance.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 02:14:46 +05:00
Saboor Hassan
d241359153 fix(translator): address PR feedback for tool name sanitization
- Pre-compile sanitization regex for better performance.
- Optimize SanitizeFunctionName for conciseness and correctness.
- Handle 64-char edge cases by truncating before prepending underscore.
- Fix bug in Antigravity translator (incorrect join index).
- Refactor Gemini translators to avoid redundant sanitization calls.
- Add comprehensive unit tests including 64-char edge cases.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 01:54:41 +05:00
Saboor Hassan
f4d4249ba5 feat(translator): sanitize tool/function names for upstream provider compatibility
Implemented SanitizeFunctionName utility to ensure Claude tool names meet
Gemini/Upstream strict naming conventions (alphanumeric, starts with letter/underscore, max 64 chars).
Applied sanitization to tool definitions and usage in all relevant translators.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 01:41:07 +05:00
Chén Mù
cb56cb250e Merge pull request #800 from router-for-me/modelmappings
feat(watcher): add model mappings change detection
2025-12-30 06:50:42 -08:00
hkfires
e0381a6ae0 refactor(watcher): extract model summary functions to dedicated file 2025-12-30 22:39:12 +08:00
hkfires
2c01b2ef64 feat(watcher): add Gemini models and OAuth model mappings change detection 2025-12-30 22:39:12 +08:00
Chén Mù
e947266743 Merge pull request #795 from router-for-me/modelmappings
refactor(executor): resolve upstream model at conductor level before execution
2025-12-30 05:31:19 -08:00
Luis Pater
c6b0e85b54 Fixed: #790
fix(gemini): include full text in response output events
2025-12-30 20:44:13 +08:00
hkfires
26efbed05c refactor(executor): remove redundant upstream model parameter from translateRequest 2025-12-30 20:20:42 +08:00
hkfires
96340bf136 refactor(executor): resolve upstream model at conductor level before execution 2025-12-30 19:31:54 +08:00
hkfires
b055e00c1a fix(executor): use upstream model for thinking config and payload translation 2025-12-30 17:49:44 +08:00
sususu
414db44c00 fix(antigravity): parse retry-after delay from 429 response body
When receiving HTTP 429 (Too Many Requests) responses, parse the retry
delay from the response body using parseRetryDelay and populate the
statusErr.retryAfter field. This allows upstream callers to respect
the server's requested retry timing.

Applied to all error paths in Execute, executeClaudeNonStream,
ExecuteStream, CountTokens, and refreshToken functions.
2025-12-30 16:07:32 +08:00
Chén Mù
857c880f99 Merge pull request #785 from router-for-me/gemini
feat(gemini): add per-key model alias support for Gemini provider
2025-12-29 23:32:40 -08:00
hkfires
ce7474d953 feat(cliproxy): propagate thinking support metadata to aliased models 2025-12-30 15:16:54 +08:00
hkfires
70fdd70b84 refactor(cliproxy): extract generic buildConfigModels function for model info generation 2025-12-30 13:35:22 +08:00
hkfires
08ab6a7d77 feat(gemini): add per-key model alias support for Gemini provider 2025-12-30 13:27:57 +08:00
Luis Pater
9fa2a7e9df Merge pull request #782 from router-for-me/modelmappings
refactor(config): rename model-name-mappings to oauth-model-mappings
2025-12-30 11:40:12 +08:00
hkfires
d443c86620 refactor(config): rename model mapping fields from from/to to name/alias 2025-12-30 11:07:59 +08:00
hkfires
7be3f1c36c refactor(config): rename model-name-mappings to oauth-model-mappings 2025-12-30 11:07:58 +08:00
Luis Pater
f6ab6d97b9 fix(logging): add isDirWritable utility to enhance log dir validation in ConfigureLogOutput 2025-12-30 10:48:25 +08:00
Luis Pater
bc866bac49 fix(logging): refactor ConfigureLogOutput to accept config object and adjust log directory handling 2025-12-30 10:28:25 +08:00
Luis Pater
50e6d845f4 feat(cliproxy): introduce global model name mappings for improved aliasing and routing 2025-12-30 08:13:06 +08:00
Luis Pater
a8cb01819d Merge pull request #772 from soffchen/main
fix: Implement fallback log directory for file logging on read-only system
2025-12-30 02:24:49 +08:00
Luis Pater
530273906b Merge pull request #776 from router-for-me/fix-ag-claude
fix(antigravity): inject required placeholder when properties exist w…
2025-12-30 00:37:01 +08:00
Supra4E8C
06ddf575d9 fix(antigravity): inject required placeholder when properties exist without required 2025-12-29 23:55:59 +08:00
hkfires
3099114cbb refactor(api): simplify codex id token claims extraction 2025-12-29 19:48:02 +08:00
Soff
44b63f0767 fix: Return an error if the user home directory cannot be determined for the fallback log path. 2025-12-29 18:46:15 +08:00
Soff Chen
6705d20194 fix: Implement fallback log directory for file logging on read-only systems. 2025-12-29 18:35:48 +08:00
Chén Mù
a38a9c0b0f Merge pull request #770 from router-for-me/api
feat(api): add id token claims extraction for codex auth entries
2025-12-29 00:44:41 -08:00
hkfires
8286caa366 feat(api): add id token claims extraction for codex auth entries 2025-12-29 16:34:16 +08:00
Chén Mù
bd1ec8424d Merge pull request #767 from router-for-me/amp
feat(amp): add per-client upstream API key mapping support
2025-12-28 22:10:11 -08:00
hkfires
225e2c6797 feat(amp): add per-client upstream API key mapping support 2025-12-29 12:26:25 +08:00
Luis Pater
d8fc485513 fix(translators): correct key path for system_instruction.parts in Claude request logic 2025-12-29 11:54:26 +08:00
hkfires
f137eb0ac4 chore: add codex, agents, and opencode dirs to ignore files 2025-12-29 08:42:29 +08:00
Chén Mù
f39a460487 Merge pull request #761 from router-for-me/log
fix(logging): improve request/response capture
2025-12-28 16:13:10 -08:00
Luis Pater
ee171bc563 feat(api): add ManagementTokenRequester interface for management token request endpoints 2025-12-29 02:42:29 +08:00
hkfires
a95428f204 fix(handlers): preserve upstream response logs before duplicate detection 2025-12-28 22:35:36 +08:00
Michael Velbaum
cb3bdffb43 refactor(logging): streamline auth selection debug messages
Reduce duplicate Debugf calls by appending proxy info via an optional suffix and keep the debug-level guard inside the helper.
2025-12-28 16:10:11 +02:00
Michael Velbaum
48f19aab51 refactor(logging): pass request entry into auth selection log
Avoid re-creating the request-scoped log entry in the helper and use a switch for account type dispatch.
2025-12-28 15:51:11 +02:00
Michael Velbaum
48f6d7abdf refactor(logging): dedupe auth selection debug logs
Extract repeated debug logging for selected auth credentials into a helper so execute, count, and stream paths stay consistent.
2025-12-28 15:42:35 +02:00
Michael Velbaum
79fbcb3ec4 fix(logging): quote OAuth account field
Use strconv.Quote when embedding the OAuth account in debug logs so unexpected characters (e.g. quotes) can't break key=value parsing.
2025-12-28 15:32:54 +02:00
Michael Velbaum
0e4148b229 feat(logging): disambiguate OAuth credential selection in debug logs
When multiple OAuth providers share an account email, the existing "Use OAuth" debug lines are ambiguous and hard to correlate with management usage stats. Include provider, auth file, and auth index in the selection log, and only compute these fields when debug logging is enabled to avoid impacting normal request performance.

Before:
[debug] Use OAuth user@example.com for model gemini-3-flash-preview
[debug] Use OAuth user@example.com (project-1234) for model gemini-3-flash-preview

After:
[debug] Use OAuth provider=antigravity auth_file=antigravity-user_example_com.json auth_index=1a2b3c4d5e6f7788 account="user@example.com" for model gemini-3-flash-preview
[debug] Use OAuth provider=gemini-cli auth_file=gemini-user@example.com-project-1234.json auth_index=99aabbccddeeff00 account="user@example.com (project-1234)" for model gemini-3-flash-preview
2025-12-28 15:22:36 +02:00
hkfires
3ca5fb1046 fix(handlers): match raw error text before JSON body for duplicate detection 2025-12-28 19:35:36 +08:00
hkfires
a091d12f4e fix(logging): improve request/response capture 2025-12-28 19:04:31 +08:00
Luis Pater
457924828a Merge pull request #757 from ben-vargas/fix-thinking-toolchoice-conflict
Fix: disable thinking when tool_choice forces tool use
2025-12-28 14:04:30 +08:00
Ben Vargas
aca2ef6359 Fix: disable thinking when tool_choice forces tool use
Anthropic API does not allow extended thinking when tool_choice is set
to "any" or a specific tool. This was causing 400 errors when using
features like Amp's /handoff command which forces tool_choice.

Added disableThinkingIfToolChoiceForced() that removes thinking config
when incompatible tool_choice is detected, applied to both streaming
and non-streaming paths.

Fixes router-for-me/CLIProxyAPI#630
2025-12-27 16:31:37 -07:00
Luis Pater
ade7194792 feat(management): add generic API call handler to management endpoints 2025-12-28 04:40:32 +08:00
Luis Pater
3a436e116a feat(cliproxy): implement model aliasing and hashing for Codex configurations, enhance request routing logic, and normalize Codex model entries 2025-12-28 03:06:51 +08:00
Luis Pater
336867853b Merge pull request #756 from leaph/check-ai-thinking-settings
feat(iflow): add model-specific thinking configs for GLM-4.7 and Mini…
2025-12-28 02:08:27 +08:00
leaph
6403ff4ec4 feat(iflow): add model-specific thinking configs for GLM-4.7 and MiniMax-M2.1
- GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
- MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
- Added preserveReasoningContentInMessages() to support re-injection of reasoning
  content in assistant message history for multi-turn conversations
- Added ThinkingSupport to MiniMax-M2.1 model definition
2025-12-27 18:39:15 +01:00
Luis Pater
d222469b44 Update issue templates 2025-12-28 01:22:42 +08:00
Luis Pater
7646a2b877 Fixed: #749
fix(translators): ensure `gjson.String` content is non-empty before setting `parts` in OpenAI request logic
2025-12-28 00:54:26 +08:00
Luis Pater
62090f2568 Merge pull request #750 from router-for-me/config
fix(config): preserve original config structure and avoid default value pollution
2025-12-27 22:10:01 +08:00
Luis Pater
c281f4cbaf Fixed: #747
fix(translators): rename and integrate `usageMetadata` as `cpaUsageMetadata` in Claude processing logic
2025-12-27 22:02:11 +08:00
hkfires
09455f9e85 fix(config): make streaming keepalive and retries ints 2025-12-27 20:56:47 +08:00
hkfires
c8e72ba0dc fix(config): smart merge writes non-default new keys only 2025-12-27 20:28:54 +08:00
hkfires
375ef252ab docs(config): clarify merge mapping behavior 2025-12-27 19:30:21 +08:00
hkfires
ee552f8720 chore(config): update ignore patterns 2025-12-27 19:13:14 +08:00
hkfires
2e88c4858e fix(config): avoid adding new keys when merging 2025-12-27 19:00:47 +08:00
Luis Pater
3f50da85c1 Merge pull request #745 from router-for-me/auth
fix(auth): make provider rotation atomic
2025-12-27 13:01:22 +08:00
hkfires
8be06255f7 fix(auth): make provider rotation atomic 2025-12-27 12:56:48 +08:00
Luis Pater
72274099aa Fixed: #738
fix(translators): refine prompt token calculation by incorporating cached tokens in Claude response handling
2025-12-27 03:56:11 +08:00
Luis Pater
dcae098e23 Fixed: #736
fix(translators): handle gjson string types in Claude request processing to ensure consistent content parsing
2025-12-27 01:25:47 +08:00
Luis Pater
2eb05ec640 Merge pull request #727 from nguyenphutrong/main
docs: add Quotio to community projects
2025-12-26 11:53:09 +08:00
Luis Pater
3ce0d76aa4 feat(usage): add import/export functionality for usage statistics and enhance deduplication logic 2025-12-26 11:49:51 +08:00
Trong Nguyen
a00b79d9be docs(readme): add Quotio to community projects section 2025-12-26 10:46:05 +07:00
Luis Pater
33e53a2a56 fix(translators): ensure correct handling and output of multimodal assistant content across request handlers 2025-12-26 05:08:04 +08:00
Luis Pater
cd5b80785f Merge pull request #722 from hungthai1401/bugfix/remove-extra-args
Fixed incorrect function signature call to `NewBaseAPIHandlers`
2025-12-26 02:56:56 +08:00
Thai Nguyen Hung
54f71aa273 fix(test): remove extra argument from ExecuteStreamWithAuthManager call 2025-12-25 21:55:35 +07:00
Luis Pater
3f949b7f84 Merge pull request #704 from tinyc0der/add-index
fix(openai): add index field to image response for LiteLLM compatibility
2025-12-25 21:35:12 +08:00
Luis Pater
443c4538bb feat(config): add commercial-mode to optimize HTTP middleware for lower memory usage 2025-12-25 21:05:01 +08:00
TinyCoder
a7fc2ee4cf refactor(image): avoid using json.Marshal 2025-12-25 14:21:01 +07:00
Luis Pater
8e749ac22d docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:59:48 +08:00
Luis Pater
69e09d9bc7 docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:46:27 +08:00
Luis Pater
06ad527e8c Fixed: #696
fix(translators): adjust prompt token calculation by subtracting cached tokens across Gemini, OpenAI, and Claude handlers
2025-12-24 23:29:18 +08:00
Luis Pater
b7409dd2de Merge pull request #706 from router-for-me/log
Log
2025-12-24 22:24:39 +08:00
hkfires
5ba325a8fc refactor(logging): standardize request id formatting and layout 2025-12-24 22:03:07 +08:00
Luis Pater
d502840f91 Merge pull request #695 from NguyenSiTrung/main
feat: add cached token parsing for Gemini , Antigravity API responses
2025-12-24 21:58:55 +08:00
hkfires
99238a4b59 fix(logging): normalize warning level to warn 2025-12-24 21:11:37 +08:00
hkfires
6d43a2ff9a refactor(logging): inline request id in log output 2025-12-24 21:07:18 +08:00
Luis Pater
3faa1ca9af Merge pull request #700 from router-for-me/log
refactor(sdk/auth): rename manager.go to conductor.go
2025-12-24 19:36:24 +08:00
Luis Pater
9d975e0375 feat(models): add support for GLM-4.7 and MiniMax-M2.1 2025-12-24 19:30:57 +08:00
hkfires
2a6d8b78d4 feat(api): add endpoint to retrieve request logs by ID 2025-12-24 19:24:51 +08:00
TinyCoder
671558a822 fix(openai): add index field to image response for LiteLLM compatibility
LiteLLM's Pydantic model requires an index field in each image object.
Without it, responses fail validation with "images.0.index Field required".
2025-12-24 17:43:31 +07:00
hkfires
26fbb77901 refactor(sdk/auth): rename manager.go to conductor.go 2025-12-24 15:21:03 +08:00
NguyenSiTrung
a277302262 Merge remote-tracking branch 'upstream/main' 2025-12-24 10:54:09 +07:00
NguyenSiTrung
969c1a5b72 refactor: extract parseGeminiFamilyUsageDetail helper to reduce duplication 2025-12-24 10:22:31 +07:00
NguyenSiTrung
872339bceb feat: add cached token parsing for Gemini API responses 2025-12-24 10:20:11 +07:00
Luis Pater
5dc0dbc7aa Merge pull request #697 from Cubence-com/main
docs(readme): add Cubence sponsor and fix PackyCode link
2025-12-24 11:19:32 +08:00
Luis Pater
2b7ba54a2f Merge pull request #688 from router-for-me/feature/request-id-tracking
feat(logging): implement request ID tracking and propagation
2025-12-24 10:54:13 +08:00
hkfires
007c3304f2 feat(logging): scope request ID tracking to AI API endpoints 2025-12-24 09:17:09 +08:00
hkfires
e76ba0ede9 feat(logging): implement request ID tracking and propagation 2025-12-24 08:32:17 +08:00
Luis Pater
c06ac07e23 Merge pull request #686 from ajkdrag/main
feat: regex support for model-mappings
2025-12-24 04:37:44 +08:00
Luis Pater
66769ec657 fix(translators): update role from tool to user in Gemini and Gemini-CLI requests 2025-12-24 04:24:07 +08:00
Luis Pater
f413feec61 refactor(handlers): streamline error and data channel handling in streaming logic
Improved consistency across OpenAI, Claude, and Gemini handlers by replacing initial `select` statement with a `for` loop for better readability and error-handling robustness.
2025-12-24 04:07:24 +08:00
Luis Pater
2e538e3486 Merge pull request #661 from jroth1111/fix/streaming-bootstrap-forwarder
fix: improve streaming bootstrap and forwarding
2025-12-24 03:51:40 +08:00
Luis Pater
9617a7b0d6 Merge pull request #621 from dacsang97/fix/antigravity-prompt-caching
Antigravity Prompt Caching Fix
2025-12-24 03:50:25 +08:00
Luis Pater
7569320770 Merge branch 'dev' into fix/antigravity-prompt-caching 2025-12-24 03:49:46 +08:00
Fetters
8d25cf0d75 fix(readme): update PackyCode sponsorship link and remove redundant tbody 2025-12-23 23:44:40 +08:00
Fetters
64e85e7019 docs(readme): add Cubence sponsor 2025-12-23 23:30:57 +08:00
Luis Pater
6d1e20e940 fix(claude_executor): update header logic for API key handling
Refined header assignment to use `x-api-key` for Anthropic API requests, ensuring correct authorization behavior based on request attributes and URL validation.
2025-12-23 22:30:25 +08:00
altamash
0c0aae1eac Robust change detection: replaced string concat with struct-based compare in hasModelMappingsChanged; removed boolTo01.
•  Performance: pre-allocate map and regex slice capacities in UpdateMappings.
   •  Verified with amp module tests (all passing)
2025-12-23 18:52:28 +05:30
altamash
5dcf7cb846 feat: regex support for model-mappings 2025-12-23 18:41:58 +05:30
Luis Pater
e52b542e22 Merge pull request #684 from packyme/main
docs(readme): add PackyCode sponsor
2025-12-23 17:19:25 +08:00
SmallL-U
8f6abb8a86 fix(readme): correct closing tbody tag 2025-12-23 17:17:57 +08:00
SmallL-U
ed8eaae964 docs(readme): add PackyCode sponsor 2025-12-23 17:11:34 +08:00
Luis Pater
4e572ec8b9 fix(translators): handle string system instructions in Claude translators
Updated Antigravity, Gemini, and Gemini-CLI translators to process `systemResult` of type `string` for system instructions. Ensures properly formatted JSON with dynamic content assignment.
2025-12-23 08:44:36 +08:00
Luis Pater
24bc9cba67 Fixed: #639
fix(antigravity): validate function arguments before serialization

Ensure `function.arguments` is a valid JSON before setting raw bytes, fallback to setting as parameterized content if invalid.
2025-12-23 03:49:45 +08:00
Luis Pater
1084b53fba Fixed: #655
refactor(antigravity): clean up tool key filtering and improve signature caching logic
2025-12-23 03:16:51 +08:00
Luis Pater
83b90e106f refactor(antigravity): add sandbox URL constant and update base URLs routine 2025-12-23 02:47:56 +08:00
Luis Pater
5106caf641 Fixed: #654
feat: handle array input for system instructions in translators

Enhanced Gemini, Gemini-CLI, and Antigravity translators to process array content for system instructions. Adds support for assigning roles and handling multiple content parts dynamically.
2025-12-23 02:24:26 +08:00
Luis Pater
b84ccc6e7a feat: add unit tests for routing strategies and implement dynamic selector updates
Added comprehensive tests for `FillFirstSelector` and `RoundRobinSelector` to ensure proper behavior, including deterministic, cyclical, and concurrent scenarios. Introduced dynamic routing strategy updates in `service.go`, normalizing strategies and seamlessly switching between `fill-first` and `round-robin`. Updated `Manager` to support selector changes via the new `SetSelector` method.
2025-12-22 22:52:23 +08:00
Luis Pater
e19ddb53e7 Merge pull request #663 from jroth1111/feat/fill-first-selector
feat: add fill-first routing strategy
2025-12-22 22:26:32 +08:00
gwizz
5bf89dd757 fix: keep streaming defaults legacy-safe 2025-12-23 00:53:18 +11:00
gwizz
2a0100b2d6 docs: add routing strategy example 2025-12-23 00:39:18 +11:00
gwizz
4442574e53 fix: stop streaming loop on context cancel 2025-12-23 00:37:55 +11:00
gwizz
c020fa60d0 fix: keep round-robin as default routing 2025-12-22 23:39:41 +11:00
gwizz
b078be4613 feat: add fill-first routing strategy 2025-12-22 23:38:10 +11:00
gwizz
71a6dffbb6 fix: improve streaming bootstrap and forwarding 2025-12-22 23:34:23 +11:00
Luis Pater
27b43ed63f Merge pull request #658 from moxi000/fix-responses-convert
Fix responses-format handling for chat completions(Support Cursor)
2025-12-22 16:47:46 +08:00
moxi
f6a3a1d0ba Remove compat test under translator per review 2025-12-22 16:44:50 +08:00
moxi
830fd8eac2 Fix responses-format handling for chat completions 2025-12-22 13:54:02 +08:00
Luis Pater
a86d501dc2 refactor: replace json.Marshal and json.Unmarshal with sjson and gjson
Optimized the handling of JSON serialization and deserialization by replacing redundant `json.Marshal` and `json.Unmarshal` calls with `sjson` and `gjson`. Introduced a `marshalJSONValue` utility for compact JSON encoding, improving performance and code simplicity. Removed unused `encoding/json` imports.
2025-12-22 11:44:06 +08:00
Evan Nguyen
24e8e20b59 Merge branch 'main' into fix/antigravity-prompt-caching 2025-12-21 19:43:24 +07:00
Evan Nguyen
a87f09bad2 feat(antigravity): add session ID generation and mutex for random source 2025-12-21 17:50:41 +07:00
evann
bc6c4cdbfc feat(antigravity): add logging for cached token setting errors in responses 2025-12-19 16:49:50 +07:00
evann
404546ce93 refactor(antigravity): regarding production endpoint caching 2025-12-19 16:36:54 +07:00
evann
6dd1cf1dd6 Merge branch 'main' into fix/antigravity-prompt-caching 2025-12-19 16:34:28 +07:00
evann
9058d406a3 feat(antigravity): enhance prompt caching support and update agent version 2025-12-19 16:33:41 +07:00
이대희
31bd90c748 feature: Improves Amp client compatibility
Ensures compatibility with the Amp client by suppressing
"thinking" blocks when "tool_use" blocks are also present in
the response.

The Amp client has issues rendering both types of blocks
simultaneously. This change filters out "thinking" blocks in
such cases, preventing rendering problems.
2025-12-19 08:18:27 +09:00
Muzhen Gaming
0b834fcb54 fix(translator): preserve built-in tools across openai<->responses
- Pass through non-function tool definitions like web_search

- Translate tool_choice for built-in tools and function tools

- Add regression tests for built-in tool passthrough
2025-12-15 21:18:54 +08:00
137 changed files with 10524 additions and 2871 deletions

View File

@@ -13,8 +13,6 @@ Dockerfile
docs/*
README.md
README_CN.md
MANAGEMENT_API.md
MANAGEMENT_API_CN.md
LICENSE
# Runtime data folders (should be mounted as volumes)
@@ -25,10 +23,14 @@ config.yaml
# Development/editor
bin/*
.claude/*
.vscode/*
.claude/*
.codex/*
.gemini/*
.serena/*
.agent/*
.agents/*
.opencode/*
.bmad/*
_bmad/*
_bmad-output/*

View File

@@ -7,6 +7,13 @@ assignees: ''
---
**Is it a request payload issue?**
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
[ ] No, it's another issue.
**If it's a request payload issue, you MUST know**
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
**Describe the bug**
A clear and concise description of what the bug is.

11
.gitignore vendored
View File

@@ -11,11 +11,15 @@ bin/*
logs/*
conv/*
temp/*
refs/*
# Storage backends
pgstore/*
gitstore/*
objectstore/*
# Static assets
static/*
refs/*
# Authentication data
auths/*
@@ -29,12 +33,17 @@ GEMINI.md
# Tooling metadata
.vscode/*
.codex/*
.claude/*
.gemini/*
.serena/*
.agent/*
.agents/*
.agents/*
.opencode/*
.bmad/*
_bmad/*
_bmad-output/*
# macOS
.DS_Store

View File

@@ -10,14 +10,29 @@ So you can use local or multi-account CLI access with OpenAI(include Responses)/
## Sponsor
[![z.ai](https://assets.router-for.me/english.png)](https://z.ai/subscribe?ic=8JVLJQFSKB)
[![z.ai](https://assets.router-for.me/english-4.7.png)](https://z.ai/subscribe?ic=8JVLJQFSKB)
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.7 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
Get 10% OFF GLM CODING PLANhttps://z.ai/subscribe?ic=8JVLJQFSKB
---
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the "cliproxyapi" promo code during recharge to get 10% off.</td>
</tr>
<tr>
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
<td>Thanks to Cubence for sponsoring this project! Cubence is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. Cubence provides special discounts for our software users: register using <a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">this link</a> and enter the "CLIPROXYAPI" promo code during recharge to get 10% off.</td>
</tr>
</tbody>
</table>
## Overview
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
@@ -99,9 +114,36 @@ CLI wrapper for instant switching between multiple Claude accounts and alternati
Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
### [Quotio](https://github.com/nguyenphutrong/quotio)
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
### [CodMate](https://github.com/loocor/CodMate)
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
## More choices
Those projects are ports of CLIProxyAPI or inspired by it:
### [9Router](https://github.com/decolua/9router)
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

View File

@@ -10,14 +10,30 @@
## 赞助商
[![bigmodel.cn](https://assets.router-for.me/chinese.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
[![bigmodel.cn](https://assets.router-for.me/chinese-4.7.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。
GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7,为开发者提供顶尖的编码体验。
智谱AI为本软件提供了特别优惠使用以下链接购买可以享受九折优惠https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
---
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>感谢 PackyCode 对本项目的赞助PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用<a href="https://www.packyapi.com/register?aff=cliproxyapi">此链接</a>注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。</td>
</tr>
<tr>
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
<td>感谢 Cubence 对本项目的赞助Cubence 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。Cubence 为本软件用户提供了特别优惠:使用<a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">此链接</a>注册,并在充值时输入 "CLIPROXYAPI" 优惠码即可享受九折优惠。</td>
</tr>
</tbody>
</table>
## 功能特性
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
@@ -97,9 +113,36 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
基于 macOS 平台的原生 CLIProxyAPI GUI配置供应商、模型映射以及OAuth端点无需 API 密钥。
### [Quotio](https://github.com/nguyenphutrong/quotio)
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
### [CodMate](https://github.com/loocor/CodMate)
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话Claude Code、Codex、Gemini CLI提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR拉取请求将其添加到此列表中。
## 更多选择
以下项目是 CLIProxyAPI 的移植版或受其启发:
### [9Router](https://github.com/decolua/9router)
基于 Next.js 的实现,灵感来自 CLIProxyAPI易于安装使用自研格式转换OpenAI/Claude/Gemini/Ollama、组合系统与自动回退、多账户管理指数退避、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

BIN
assets/cubence.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

BIN
assets/packycode.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

View File

@@ -405,7 +405,7 @@ func main() {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
if err = logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to configure log output: %v", err)
return
}

View File

@@ -35,10 +35,14 @@ auth-dir: "~/.cli-proxy-api"
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
# Enable debug logging
debug: false
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
# When true, write application logs to rotating files instead of stdout
logging-to-file: false
@@ -66,9 +70,18 @@ quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
# streaming:
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
# Gemini API keys
# gemini-api-key:
# - api-key: "AIzaSy...01"
@@ -77,6 +90,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
# excluded-models:
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
@@ -92,6 +108,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
# excluded-models:
# - "gpt-5.1" # exclude specific models (exact match)
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
@@ -109,7 +128,7 @@ ws-auth: false
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# excluded-models:
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
@@ -140,9 +159,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names
# - name: "gemini-2.0-flash" # upstream model name
# - name: "gemini-2.5-flash" # upstream model name
# alias: "vertex-flash" # client-visible alias
# - name: "gemini-1.5-pro"
# - name: "gemini-2.5-pro"
# alias: "vertex-pro"
# Amp Integration
@@ -151,6 +170,18 @@ ws-auth: false
# upstream-url: "https://ampcode.com"
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
# upstream-api-key: ""
# # Per-client upstream API key mapping
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
# # Useful when different clients need to use different Amp accounts/quotas.
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
# upstream-api-keys:
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
# api-keys: # Client keys that use this upstream key
# - "your-api-key-1"
# - "your-api-key-2"
# - upstream-api-key: "amp_key_for_team_b"
# api-keys:
# - "your-api-key-3"
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
# restrict-management-to-localhost: false
# # Force model mappings to run before checking local API keys (default: false)
@@ -160,12 +191,44 @@ ws-auth: false
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
# # but you have a similar model available (e.g., Claude Sonnet 4).
# model-mappings:
# - from: "claude-opus-4.5" # Model requested by Amp CLI
# to: "claude-sonnet-4" # Route to this available model instead
# - from: "gpt-5"
# to: "gemini-2.5-pro"
# - from: "claude-3-opus-20240229"
# to: "claude-3-5-sonnet-20241022"
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
# - from: "claude-sonnet-4-5-20250929"
# to: "gemini-claude-sonnet-4-5-thinking"
# - from: "claude-haiku-4-5-20251001"
# to: "gemini-2.5-flash"
# Global OAuth model name mappings (per channel)
# These mappings rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-mappings:
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
# vertex:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-preview"
# alias: "g3p"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
# codex:
# - name: "gpt-5"
# alias: "g5"
# qwen:
# - name: "qwen3-coder-plus"
# alias: "qwen-plus"
# iflow:
# - name: "glm-4.7"
# alias: "glm-god"
# OAuth provider excluded models
# oauth-excluded-models:

View File

@@ -5,9 +5,115 @@
# This script automates the process of building and running the Docker container
# with version information dynamically injected at build time.
# Exit immediately if a command exits with a non-zero status.
# Hidden feature: Preserve usage statistics across rebuilds
# Usage: ./docker-build.sh --with-usage
# First run prompts for management API key, saved to temp/stats/.api_secret
set -euo pipefail
STATS_DIR="temp/stats"
STATS_FILE="${STATS_DIR}/.usage_backup.json"
SECRET_FILE="${STATS_DIR}/.api_secret"
WITH_USAGE=false
get_port() {
if [[ -f "config.yaml" ]]; then
grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
else
echo "8317"
fi
}
export_stats_api_secret() {
if [[ -f "${SECRET_FILE}" ]]; then
API_SECRET=$(cat "${SECRET_FILE}")
else
if [[ ! -d "${STATS_DIR}" ]]; then
mkdir -p "${STATS_DIR}"
fi
echo "First time using --with-usage. Management API key required."
read -r -p "Enter management key: " -s API_SECRET
echo
echo "${API_SECRET}" > "${SECRET_FILE}"
chmod 600 "${SECRET_FILE}"
fi
}
check_container_running() {
local port
port=$(get_port)
if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
echo "Error: cli-proxy-api service is not responding at localhost:${port}"
echo "Please start the container first or use without --with-usage flag."
exit 1
fi
}
export_stats() {
local port
port=$(get_port)
if [[ ! -d "${STATS_DIR}" ]]; then
mkdir -p "${STATS_DIR}"
fi
check_container_running
echo "Exporting usage statistics..."
EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
"http://localhost:${port}/v0/management/usage/export")
HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
if [[ "${HTTP_CODE}" != "200" ]]; then
echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
exit 1
fi
echo "${RESPONSE_BODY}" > "${STATS_FILE}"
echo "Statistics exported to ${STATS_FILE}"
}
import_stats() {
local port
port=$(get_port)
echo "Importing usage statistics..."
IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
-H "X-Management-Key: ${API_SECRET}" \
-H "Content-Type: application/json" \
-d @"${STATS_FILE}" \
"http://localhost:${port}/v0/management/usage/import")
IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
if [[ "${IMPORT_CODE}" == "200" ]]; then
echo "Statistics imported successfully"
else
echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
fi
rm -f "${STATS_FILE}"
}
wait_for_service() {
local port
port=$(get_port)
echo "Waiting for service to be ready..."
for i in {1..30}; do
if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
break
fi
sleep 1
done
sleep 2
}
if [[ "${1:-}" == "--with-usage" ]]; then
WITH_USAGE=true
export_stats_api_secret
fi
# --- Step 1: Choose Environment ---
echo "Please select an option:"
echo "1) Run using Pre-built Image (Recommended)"
@@ -18,7 +124,14 @@ read -r -p "Enter choice [1-2]: " choice
case "$choice" in
1)
echo "--- Running with Pre-built Image ---"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
docker compose up -d --remove-orphans --no-build
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Services are starting from remote image."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -38,16 +151,25 @@ case "$choice" in
# Build and start the services with a local-only image tag
export CLI_PROXY_IMAGE="cli-proxy-api:local"
echo "Building the Docker image..."
docker compose build \
--build-arg VERSION="${VERSION}" \
--build-arg COMMIT="${COMMIT}" \
--build-arg BUILD_DATE="${BUILD_DATE}"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
echo "Starting the services..."
docker compose up -d --remove-orphans --pull never
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Build complete. Services are starting."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -55,4 +177,4 @@ case "$choice" in
echo "Invalid choice. Please enter 1 or 2."
exit 1
;;
esac
esac

View File

@@ -14,6 +14,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
@@ -122,7 +123,9 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
httpReq.Header.Set("Content-Type", "application/json")
// Inject credentials via PrepareRequest hook.
_ = (MyExecutor{}).PrepareRequest(httpReq, a)
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
return clipexec.Response{}, errPrep
}
resp, errDo := client.Do(httpReq)
if errDo != nil {
@@ -130,13 +133,28 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
// Best-effort close; log if needed in real projects.
fmt.Fprintf(os.Stderr, "close response body error: %v\n", errClose)
}
}()
body, _ := io.ReadAll(resp.Body)
return clipexec.Response{Payload: body}, nil
}
func (MyExecutor) HttpRequest(ctx context.Context, a *coreauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("myprov executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
return nil, errPrep
}
client := buildHTTPClient(a)
return client.Do(httpReq)
}
func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("count tokens not implemented")
}
@@ -199,8 +217,8 @@ func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
panic(err)
if errRun := svc.Run(ctx); errRun != nil && !errors.Is(errRun, context.Canceled) {
panic(errRun)
}
_ = os.Stderr // keep os import used (demo only)
_ = time.Second

View File

@@ -0,0 +1,140 @@
// Package main demonstrates how to use coreauth.Manager.HttpRequest/NewHttpRequest
// to execute arbitrary HTTP requests with provider credentials injected.
//
// This example registers a minimal custom executor that injects an Authorization
// header from auth.Attributes["api_key"], then performs two requests against
// httpbin.org to show the injected headers.
package main
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
const providerKey = "echo"
// EchoExecutor is a minimal provider implementation for demonstration purposes.
type EchoExecutor struct{}
func (EchoExecutor) Identifier() string { return providerKey }
func (EchoExecutor) PrepareRequest(req *http.Request, auth *coreauth.Auth) error {
if req == nil || auth == nil {
return nil
}
if auth.Attributes != nil {
if apiKey := strings.TrimSpace(auth.Attributes["api_key"]); apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
}
return nil
}
func (EchoExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("echo executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if errPrep := (EchoExecutor{}).PrepareRequest(httpReq, auth); errPrep != nil {
return nil, errPrep
}
return http.DefaultClient.Do(httpReq)
}
func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
}
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
return nil, errors.New("echo executor: ExecuteStream not implemented")
}
func (EchoExecutor) Refresh(context.Context, *coreauth.Auth) (*coreauth.Auth, error) {
return nil, errors.New("echo executor: Refresh not implemented")
}
func (EchoExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("echo executor: CountTokens not implemented")
}
func main() {
log.SetLevel(log.InfoLevel)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
core := coreauth.NewManager(nil, nil, nil)
core.RegisterExecutor(EchoExecutor{})
auth := &coreauth.Auth{
ID: "demo-echo",
Provider: providerKey,
Attributes: map[string]string{
"api_key": "demo-api-key",
},
}
// Example 1: Build a prepared request and execute it using your own http.Client.
reqPrepared, errReqPrepared := core.NewHttpRequest(
ctx,
auth,
http.MethodGet,
"https://httpbin.org/anything",
nil,
http.Header{"X-Example": []string{"prepared"}},
)
if errReqPrepared != nil {
panic(errReqPrepared)
}
respPrepared, errDoPrepared := http.DefaultClient.Do(reqPrepared)
if errDoPrepared != nil {
panic(errDoPrepared)
}
defer func() {
if errClose := respPrepared.Body.Close(); errClose != nil {
log.Errorf("close response body error: %v", errClose)
}
}()
bodyPrepared, errReadPrepared := io.ReadAll(respPrepared.Body)
if errReadPrepared != nil {
panic(errReadPrepared)
}
fmt.Printf("Prepared request status: %d\n%s\n\n", respPrepared.StatusCode, bodyPrepared)
// Example 2: Execute a raw request via core.HttpRequest (auto inject + do).
rawBody := []byte(`{"hello":"world"}`)
rawReq, errRawReq := http.NewRequestWithContext(ctx, http.MethodPost, "https://httpbin.org/anything", bytes.NewReader(rawBody))
if errRawReq != nil {
panic(errRawReq)
}
rawReq.Header.Set("Content-Type", "application/json")
rawReq.Header.Set("X-Example", "executed")
respExec, errDoExec := core.HttpRequest(ctx, auth, rawReq)
if errDoExec != nil {
panic(errDoExec)
}
defer func() {
if errClose := respExec.Body.Close(); errClose != nil {
log.Errorf("close response body error: %v", errClose)
}
}()
bodyExec, errReadExec := io.ReadAll(respExec.Body)
if errReadExec != nil {
panic(errReadExec)
}
fmt.Printf("Manager HttpRequest status: %d\n%s\n", respExec.StatusCode, bodyExec)
}

View File

@@ -0,0 +1,704 @@
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const defaultAPICallTimeout = 60 * time.Second
const (
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOAuthClientSecret = "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"
)
var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
const (
antigravityOAuthClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityOAuthClientSecret = "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"
)
var antigravityOAuthTokenURL = "https://oauth2.googleapis.com/token"
type apiCallRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
Method string `json:"method"`
URL string `json:"url"`
Header map[string]string `json:"header"`
Data string `json:"data"`
}
type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
// It is protected by the management middleware.
//
// Endpoint:
//
// POST /v0/management/api-call
//
// Authentication:
//
// Same as other management APIs (requires a management key and remote-management rules).
// You can provide the key via:
// - Authorization: Bearer <key>
// - X-Management-Key: <key>
//
// Request JSON:
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
// - header (optional): Request headers map.
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
// 1) metadata.access_token
// 2) attributes.api_key
// 3) metadata.token / metadata.id_token / metadata.cookie
// Example: {"Authorization":"Bearer $TOKEN$"}.
// Note: if you need to override the HTTP Host header, set header["Host"].
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
//
// Proxy selection (highest priority first):
// 1. Selected credential proxy_url
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
//
// Example:
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer 831227" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
var body apiCallRequest
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
if method == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
return
}
urlStr := strings.TrimSpace(body.URL)
if urlStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
return
}
parsedURL, errParseURL := url.Parse(urlStr)
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
auth := h.authByIndex(authIndex)
reqHeaders := body.Header
if reqHeaders == nil {
reqHeaders = map[string]string{}
}
var hostOverride string
var token string
var tokenResolved bool
var tokenErr error
for key, value := range reqHeaders {
if !strings.Contains(value, "$TOKEN$") {
continue
}
if !tokenResolved {
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
tokenResolved = true
}
if auth != nil && token == "" {
if tokenErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
return
}
if token == "" {
continue
}
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
if errNewRequest != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
return
}
for key, value := range reqHeaders {
if strings.EqualFold(key, "host") {
hostOverride = strings.TrimSpace(value)
continue
}
req.Header.Set(key, value)
}
if hostOverride != "" {
req.Host = hostOverride
}
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
}
httpClient.Transport = h.apiCallTransport(auth)
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("management APICall request failed")
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
respBody, errReadAll := io.ReadAll(resp.Body)
if errReadAll != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
return
}
c.JSON(http.StatusOK, apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
})
}
func firstNonEmptyString(values ...*string) string {
for _, v := range values {
if v == nil {
continue
}
if out := strings.TrimSpace(*v); out != "" {
return out
}
}
return ""
}
func tokenValueForAuth(auth *coreauth.Auth) string {
if auth == nil {
return ""
}
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
return v
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return v
}
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
return v
}
}
return ""
}
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider == "gemini-cli" {
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
return token, errToken
}
if provider == "antigravity" {
token, errToken := h.refreshAntigravityOAuthAccessToken(ctx, auth)
return token, errToken
}
return tokenValueForAuth(auth), nil
}
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata, updater := geminiOAuthMetadata(auth)
if len(metadata) == 0 {
return "", fmt.Errorf("gemini oauth metadata missing")
}
base := make(map[string]any)
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
base = cloneMap(tokenRaw)
}
var token oauth2.Token
if len(base) > 0 {
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
_ = json.Unmarshal(raw, &token)
}
}
if token.AccessToken == "" {
token.AccessToken = stringValue(metadata, "access_token")
}
if token.RefreshToken == "" {
token.RefreshToken = stringValue(metadata, "refresh_token")
}
if token.TokenType == "" {
token.TokenType = stringValue(metadata, "token_type")
}
if token.Expiry.IsZero() {
if expiry := stringValue(metadata, "expiry"); expiry != "" {
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
token.Expiry = ts
}
}
}
conf := &oauth2.Config{
ClientID: geminiOAuthClientID,
ClientSecret: geminiOAuthClientSecret,
Scopes: geminiOAuthScopes,
Endpoint: google.Endpoint,
}
ctxToken := ctx
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
src := conf.TokenSource(ctxToken, &token)
currentToken, errToken := src.Token()
if errToken != nil {
return "", errToken
}
merged := buildOAuthTokenMap(base, currentToken)
fields := buildOAuthTokenFields(currentToken, merged)
if updater != nil {
updater(fields)
}
return strings.TrimSpace(currentToken.AccessToken), nil
}
func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata := auth.Metadata
if len(metadata) == 0 {
return "", fmt.Errorf("antigravity oauth metadata missing")
}
current := strings.TrimSpace(tokenValueFromMetadata(metadata))
if current != "" && !antigravityTokenNeedsRefresh(metadata) {
return current, nil
}
refreshToken := stringValue(metadata, "refresh_token")
if refreshToken == "" {
return "", fmt.Errorf("antigravity refresh token missing")
}
tokenURL := strings.TrimSpace(antigravityOAuthTokenURL)
if tokenURL == "" {
tokenURL = "https://oauth2.googleapis.com/token"
}
form := url.Values{}
form.Set("client_id", antigravityOAuthClientID)
form.Set("client_secret", antigravityOAuthClientSecret)
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
if errReq != nil {
return "", errReq
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
resp, errDo := httpClient.Do(req)
if errDo != nil {
return "", errDo
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", errRead
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("antigravity oauth token refresh failed: status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
return "", errUnmarshal
}
if strings.TrimSpace(tokenResp.AccessToken) == "" {
return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token")
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
now := time.Now()
auth.Metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
if strings.TrimSpace(tokenResp.RefreshToken) != "" {
auth.Metadata["refresh_token"] = strings.TrimSpace(tokenResp.RefreshToken)
}
if tokenResp.ExpiresIn > 0 {
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
auth.Metadata["timestamp"] = now.UnixMilli()
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
}
auth.Metadata["type"] = "antigravity"
if h != nil && h.authManager != nil {
auth.LastRefreshedAt = now
auth.UpdatedAt = now
_, _ = h.authManager.Update(ctx, auth)
}
return strings.TrimSpace(tokenResp.AccessToken), nil
}
func antigravityTokenNeedsRefresh(metadata map[string]any) bool {
// Refresh a bit early to avoid requests racing token expiry.
const skew = 30 * time.Second
if metadata == nil {
return true
}
if expStr, ok := metadata["expired"].(string); ok {
if ts, errParse := time.Parse(time.RFC3339, strings.TrimSpace(expStr)); errParse == nil {
return !ts.After(time.Now().Add(skew))
}
}
expiresIn := int64Value(metadata["expires_in"])
timestampMs := int64Value(metadata["timestamp"])
if expiresIn > 0 && timestampMs > 0 {
exp := time.UnixMilli(timestampMs).Add(time.Duration(expiresIn) * time.Second)
return !exp.After(time.Now().Add(skew))
}
return true
}
func int64Value(raw any) int64 {
switch typed := raw.(type) {
case int:
return int64(typed)
case int32:
return int64(typed)
case int64:
return typed
case uint:
return int64(typed)
case uint32:
return int64(typed)
case uint64:
if typed > uint64(^uint64(0)>>1) {
return 0
}
return int64(typed)
case float32:
return int64(typed)
case float64:
return int64(typed)
case json.Number:
if i, errParse := typed.Int64(); errParse == nil {
return i
}
case string:
if s := strings.TrimSpace(typed); s != "" {
if i, errParse := json.Number(s).Int64(); errParse == nil {
return i
}
}
}
return 0
}
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
if auth == nil {
return nil, nil
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
snapshot := shared.MetadataSnapshot()
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
}
return auth.Metadata, func(fields map[string]any) {
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
for k, v := range fields {
auth.Metadata[k] = v
}
}
}
func stringValue(metadata map[string]any, key string) string {
if len(metadata) == 0 || key == "" {
return ""
}
if v, ok := metadata[key].(string); ok {
return strings.TrimSpace(v)
}
return ""
}
func cloneMap(in map[string]any) map[string]any {
if len(in) == 0 {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
merged := cloneMap(base)
if merged == nil {
merged = make(map[string]any)
}
if tok == nil {
return merged
}
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
var tokenMap map[string]any
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
for k, v := range tokenMap {
merged[k] = v
}
}
}
return merged
}
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
fields := make(map[string]any, 5)
if tok != nil && tok.AccessToken != "" {
fields["access_token"] = tok.AccessToken
}
if tok != nil && tok.TokenType != "" {
fields["token_type"] = tok.TokenType
}
if tok != nil && tok.RefreshToken != "" {
fields["refresh_token"] = tok.RefreshToken
}
if tok != nil && !tok.Expiry.IsZero() {
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
}
if len(merged) > 0 {
fields["token"] = cloneMap(merged)
}
return fields
}
func tokenValueFromMetadata(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
switch typed := tokenRaw.(type) {
case string:
if v := strings.TrimSpace(typed); v != "" {
return v
}
case map[string]any:
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
case map[string]string:
if v := strings.TrimSpace(typed["access_token"]); v != "" {
return v
}
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
return v
}
}
}
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
return ""
}
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
authIndex = strings.TrimSpace(authIndex)
if authIndex == "" || h == nil || h.authManager == nil {
return nil
}
auths := h.authManager.List()
for _, auth := range auths {
if auth == nil {
continue
}
auth.EnsureIndex()
if auth.Index == authIndex {
return auth
}
}
return nil
}
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
var proxyCandidates []string
if auth != nil {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
for _, proxyStr := range proxyCandidates {
if transport := buildProxyTransport(proxyStr); transport != nil {
return transport
}
}
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok || transport == nil {
return &http.Transport{Proxy: nil}
}
clone := transport.Clone()
clone.Proxy = nil
return clone
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
}

View File

@@ -0,0 +1,173 @@
package management
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
_ = ctx
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, a := range s.items {
out = append(out, a.Clone())
}
return out, nil
}
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
_ = ctx
if auth == nil {
return "", nil
}
s.mu.Lock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth.Clone()
s.mu.Unlock()
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
_ = ctx
s.mu.Lock()
delete(s.items, id)
s.mu.Unlock()
return nil
}
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
t.Fatalf("unexpected content-type: %s", ct)
}
bodyBytes, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
values, err := url.ParseQuery(string(bodyBytes))
if err != nil {
t.Fatalf("parse form: %v", err)
}
if values.Get("grant_type") != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
}
if values.Get("refresh_token") != "rt" {
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
}
if values.Get("client_id") != antigravityOAuthClientID {
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
}
if values.Get("client_secret") != antigravityOAuthClientSecret {
t.Fatalf("unexpected client_secret")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-token",
"refresh_token": "rt2",
"expires_in": int64(3600),
"token_type": "Bearer",
})
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
auth := &coreauth.Auth{
ID: "antigravity-test.json",
FileName: "antigravity-test.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "old-token",
"refresh_token": "rt",
"expires_in": int64(3600),
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
h := &Handler{authManager: manager}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
}
if token != "new-token" {
t.Fatalf("expected refreshed token, got %q", token)
}
if callCount != 1 {
t.Fatalf("expected 1 refresh call, got %d", callCount)
}
updated, ok := manager.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth in manager after update")
}
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
t.Fatalf("expected manager metadata updated, got %q", got)
}
}
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
auth := &coreauth.Auth{
ID: "antigravity-valid.json",
FileName: "antigravity-valid.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "ok-token",
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
h := &Handler{}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
}
if token != "ok-token" {
t.Fatalf("expected existing token, got %q", token)
}
if callCount != 0 {
t.Fatalf("expected no refresh calls, got %d", callCount)
}
}

View File

@@ -427,9 +427,52 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
log.WithError(err).Warnf("failed to stat auth file %s", path)
}
}
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
return entry
}
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
if auth == nil || auth.Metadata == nil {
return nil
}
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
return nil
}
idTokenRaw, ok := auth.Metadata["id_token"].(string)
if !ok {
return nil
}
idToken := strings.TrimSpace(idTokenRaw)
if idToken == "" {
return nil
}
claims, err := codex.ParseJWTToken(idToken)
if err != nil || claims == nil {
return nil
}
result := gin.H{}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
result["chatgpt_account_id"] = v
}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
result["plan_type"] = v
}
if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveStart; v != nil {
result["chatgpt_subscription_active_start"] = v
}
if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveUntil; v != nil {
result["chatgpt_subscription_active_until"] = v
}
if len(result) == 0 {
return nil
}
return result
}
func authEmail(auth *coreauth.Auth) string {
if auth == nil {
return ""

View File

@@ -202,6 +202,26 @@ func (h *Handler) PutLoggingToFile(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
}
// LogsMaxTotalSizeMB
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
c.JSON(200, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB})
}
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
var body struct {
Value *int `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
value := *body.Value
if value < 0 {
value = 0
}
h.cfg.LogsMaxTotalSizeMB = value
h.persist(c)
}
// Request log
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
func (h *Handler) PutRequestLog(c *gin.Context) {
@@ -232,6 +252,52 @@ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
}
// ForceModelPrefix
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
c.JSON(200, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix})
}
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.ForceModelPrefix = v })
}
func normalizeRoutingStrategy(strategy string) (string, bool) {
normalized := strings.ToLower(strings.TrimSpace(strategy))
switch normalized {
case "", "round-robin", "roundrobin", "rr":
return "round-robin", true
case "fill-first", "fillfirst", "ff":
return "fill-first", true
default:
return "", false
}
}
// RoutingStrategy
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
strategy, ok := normalizeRoutingStrategy(h.cfg.Routing.Strategy)
if !ok {
c.JSON(200, gin.H{"strategy": strings.TrimSpace(h.cfg.Routing.Strategy)})
return
}
c.JSON(200, gin.H{"strategy": strategy})
}
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
var body struct {
Value *string `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
normalized, ok := normalizeRoutingStrategy(*body.Value)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
return
}
h.cfg.Routing.Strategy = normalized
h.persist(c)
}
// Proxy URL
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
func (h *Handler) PutProxyURL(c *gin.Context) {

View File

@@ -487,6 +487,137 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
c.JSON(400, gin.H{"error": "missing name or index"})
}
// vertex-api-key: []VertexCompatKey
func (h *Handler) GetVertexCompatKeys(c *gin.Context) {
c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey})
}
func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var arr []config.VertexCompatKey
if err = json.Unmarshal(data, &arr); err != nil {
var obj struct {
Items []config.VertexCompatKey `json:"items"`
}
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
arr = obj.Items
}
for i := range arr {
normalizeVertexCompatKey(&arr[i])
}
h.cfg.VertexCompatAPIKey = arr
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
}
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
type vertexCompatPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
Models *[]config.VertexCompatModel `json:"models"`
}
var body struct {
Index *int `json:"index"`
Match *string `json:"match"`
Value *vertexCompatPatch `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
targetIndex := -1
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) {
targetIndex = *body.Index
}
if targetIndex == -1 && body.Match != nil {
match := strings.TrimSpace(*body.Match)
if match != "" {
for i := range h.cfg.VertexCompatAPIKey {
if h.cfg.VertexCompatAPIKey[i].APIKey == match {
targetIndex = i
break
}
}
}
}
if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.VertexCompatAPIKey[targetIndex]
if body.Value.APIKey != nil {
trimmed := strings.TrimSpace(*body.Value.APIKey)
if trimmed == "" {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
entry.APIKey = trimmed
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
trimmed := strings.TrimSpace(*body.Value.BaseURL)
if trimmed == "" {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
entry.BaseURL = trimmed
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.Models != nil {
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
}
normalizeVertexCompatKey(&entry)
h.cfg.VertexCompatAPIKey[targetIndex] = entry
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
}
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
for _, v := range h.cfg.VertexCompatAPIKey {
if v.APIKey != val {
out = append(out, v)
}
}
h.cfg.VertexCompatAPIKey = out
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
if idxStr := c.Query("index"); idxStr != "" {
var idx int
_, errScan := fmt.Sscanf(idxStr, "%d", &idx)
if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
}
c.JSON(400, gin.H{"error": "missing api-key or index"})
}
// oauth-excluded-models: map[string][]string
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
@@ -572,6 +703,103 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
h.persist(c)
}
// oauth-model-mappings: map[string][]ModelNameMapping
func (h *Handler) GetOAuthModelMappings(c *gin.Context) {
c.JSON(200, gin.H{"oauth-model-mappings": sanitizedOAuthModelMappings(h.cfg.OAuthModelMappings)})
}
func (h *Handler) PutOAuthModelMappings(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var entries map[string][]config.ModelNameMapping
if err = json.Unmarshal(data, &entries); err != nil {
var wrapper struct {
Items map[string][]config.ModelNameMapping `json:"items"`
}
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
entries = wrapper.Items
}
h.cfg.OAuthModelMappings = sanitizedOAuthModelMappings(entries)
h.persist(c)
}
func (h *Handler) PatchOAuthModelMappings(c *gin.Context) {
var body struct {
Provider *string `json:"provider"`
Channel *string `json:"channel"`
Mappings []config.ModelNameMapping `json:"mappings"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
channelRaw := ""
if body.Channel != nil {
channelRaw = *body.Channel
} else if body.Provider != nil {
channelRaw = *body.Provider
}
channel := strings.ToLower(strings.TrimSpace(channelRaw))
if channel == "" {
c.JSON(400, gin.H{"error": "invalid channel"})
return
}
normalizedMap := sanitizedOAuthModelMappings(map[string][]config.ModelNameMapping{channel: body.Mappings})
normalized := normalizedMap[channel]
if len(normalized) == 0 {
if h.cfg.OAuthModelMappings == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelMappings[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelMappings, channel)
if len(h.cfg.OAuthModelMappings) == 0 {
h.cfg.OAuthModelMappings = nil
}
h.persist(c)
return
}
if h.cfg.OAuthModelMappings == nil {
h.cfg.OAuthModelMappings = make(map[string][]config.ModelNameMapping)
}
h.cfg.OAuthModelMappings[channel] = normalized
h.persist(c)
}
func (h *Handler) DeleteOAuthModelMappings(c *gin.Context) {
channel := strings.ToLower(strings.TrimSpace(c.Query("channel")))
if channel == "" {
channel = strings.ToLower(strings.TrimSpace(c.Query("provider")))
}
if channel == "" {
c.JSON(400, gin.H{"error": "missing channel"})
return
}
if h.cfg.OAuthModelMappings == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelMappings[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelMappings, channel)
if len(h.cfg.OAuthModelMappings) == 0 {
h.cfg.OAuthModelMappings = nil
}
h.persist(c)
}
// codex-api-key: []CodexKey
func (h *Handler) GetCodexKeys(c *gin.Context) {
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
@@ -597,11 +825,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
filtered := make([]config.CodexKey, 0, len(arr))
for i := range arr {
entry := arr[i]
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
normalizeCodexKey(&entry)
if entry.BaseURL == "" {
continue
}
@@ -613,12 +837,13 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
}
func (h *Handler) PatchCodexKey(c *gin.Context) {
type codexKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Models *[]config.CodexModel `json:"models"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct {
Index *int `json:"index"`
@@ -667,12 +892,16 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Models != nil {
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
normalizeCodexKey(&entry)
h.cfg.CodexKey[targetIndex] = entry
h.cfg.SanitizeCodexKeys()
h.persist(c)
@@ -762,6 +991,79 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
entry.Models = normalized
}
func normalizeCodexKey(entry *config.CodexKey) {
if entry == nil {
return
}
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.Prefix = strings.TrimSpace(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if len(entry.Models) == 0 {
return
}
normalized := make([]config.CodexModel, 0, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
model.Name = strings.TrimSpace(model.Name)
model.Alias = strings.TrimSpace(model.Alias)
if model.Name == "" && model.Alias == "" {
continue
}
normalized = append(normalized, model)
}
entry.Models = normalized
}
func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
if entry == nil {
return
}
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.Prefix = strings.TrimSpace(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
if len(entry.Models) == 0 {
return
}
normalized := make([]config.VertexCompatModel, 0, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
model.Name = strings.TrimSpace(model.Name)
model.Alias = strings.TrimSpace(model.Alias)
if model.Name == "" || model.Alias == "" {
continue
}
normalized = append(normalized, model)
}
entry.Models = normalized
}
func sanitizedOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping {
if len(entries) == 0 {
return nil
}
copied := make(map[string][]config.ModelNameMapping, len(entries))
for channel, mappings := range entries {
if len(mappings) == 0 {
continue
}
copied[channel] = append([]config.ModelNameMapping(nil), mappings...)
}
if len(copied) == 0 {
return nil
}
cfg := config.Config{OAuthModelMappings: copied}
cfg.SanitizeOAuthModelMappings()
if len(cfg.OAuthModelMappings) == 0 {
return nil
}
return cfg.OAuthModelMappings
}
// GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil {
@@ -913,3 +1215,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
}
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
return
}
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
}
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
// Normalize entries: trim whitespace, filter empty
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
h.cfg.AmpCode.UpstreamAPIKeys = normalized
h.persist(c)
}
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
// Matching is done by upstream-api-key value.
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
existing := make(map[string]int)
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
}
for _, newEntry := range body.Value {
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
}
if idx, ok := existing[upstreamKey]; ok {
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
} else {
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
}
}
h.persist(c)
}
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
// If "value" is an empty array, clears all entries.
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Value == nil {
c.JSON(400, gin.H{"error": "missing value"})
return
}
// Empty array means clear all
if len(body.Value) == 0 {
h.cfg.AmpCode.UpstreamAPIKeys = nil
h.persist(c)
return
}
toRemove := make(map[string]bool)
for _, key := range body.Value {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
toRemove[trimmed] = true
}
if len(toRemove) == 0 {
c.JSON(400, gin.H{"error": "empty value"})
return
}
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
newEntries = append(newEntries, entry)
}
}
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
h.persist(c)
}
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
if len(entries) == 0 {
return nil
}
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
apiKeys := normalizeAPIKeysList(entry.APIKeys)
out = append(out, config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: apiKeys,
})
}
if len(out) == 0 {
return nil
}
return out
}
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
func normalizeAPIKeysList(keys []string) []string {
if len(keys) == 0 {
return nil
}
out := make([]string, 0, len(keys))
for _, k := range keys {
trimmed := strings.TrimSpace(k)
if trimmed != "" {
out = append(out, trimmed)
}
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -24,8 +24,15 @@ import (
type attemptInfo struct {
count int
blockedUntil time.Time
lastActivity time.Time // track last activity for cleanup
}
// attemptCleanupInterval controls how often stale IP entries are purged
const attemptCleanupInterval = 1 * time.Hour
// attemptMaxIdleTime controls how long an IP can be idle before cleanup
const attemptMaxIdleTime = 2 * time.Hour
// Handler aggregates config reference, persistence path and helpers.
type Handler struct {
cfg *config.Config
@@ -47,7 +54,7 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
envSecret = strings.TrimSpace(envSecret)
return &Handler{
h := &Handler{
cfg: cfg,
configFilePath: configFilePath,
failedAttempts: make(map[string]*attemptInfo),
@@ -57,6 +64,43 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
allowRemoteOverride: envSecret != "",
envSecret: envSecret,
}
h.startAttemptCleanup()
return h
}
// startAttemptCleanup launches a background goroutine that periodically
// removes stale IP entries from failedAttempts to prevent memory leaks.
func (h *Handler) startAttemptCleanup() {
go func() {
ticker := time.NewTicker(attemptCleanupInterval)
defer ticker.Stop()
for range ticker.C {
h.purgeStaleAttempts()
}
}()
}
// purgeStaleAttempts removes IP entries that have been idle beyond attemptMaxIdleTime
// and whose ban (if any) has expired.
func (h *Handler) purgeStaleAttempts() {
now := time.Now()
h.attemptsMu.Lock()
defer h.attemptsMu.Unlock()
for ip, ai := range h.failedAttempts {
// Skip if still banned
if !ai.blockedUntil.IsZero() && now.Before(ai.blockedUntil) {
continue
}
// Remove if idle too long
if now.Sub(ai.lastActivity) > attemptMaxIdleTime {
delete(h.failedAttempts, ip)
}
}
}
// NewHandler creates a new management handler instance.
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
return NewHandler(cfg, "", manager)
}
// SetConfig updates the in-memory config reference when the server hot-reloads.
@@ -144,6 +188,7 @@ func (h *Handler) Middleware() gin.HandlerFunc {
h.failedAttempts[clientIP] = aip
}
aip.count++
aip.lastActivity = time.Now()
if aip.count >= maxFailures {
aip.blockedUntil = time.Now().Add(banDuration)
aip.count = 0

View File

@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"files": files})
}
// GetRequestLogByID finds and downloads a request log file by its request ID.
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
func (h *Handler) GetRequestLogByID(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
if h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
return
}
dir := h.logDirectory()
if strings.TrimSpace(dir) == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
return
}
requestID := strings.TrimSpace(c.Param("id"))
if requestID == "" {
requestID = strings.TrimSpace(c.Query("id"))
}
if requestID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
return
}
if strings.ContainsAny(requestID, "/\\") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
return
}
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
return
}
suffix := "-" + requestID + ".log"
var matchedFile string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, suffix) {
matchedFile = name
break
}
}
if matchedFile == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
return
}
dirAbs, errAbs := filepath.Abs(dir)
if errAbs != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
return
}
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
prefix := dirAbs + string(os.PathSeparator)
if !strings.HasPrefix(fullPath, prefix) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
return
}
info, errStat := os.Stat(fullPath)
if errStat != nil {
if os.IsNotExist(errStat) {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
return
}
if info.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
return
}
c.FileAttachment(fullPath, matchedFile)
}
// DownloadRequestErrorLog downloads a specific error request log file by name.
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
if h == nil {

View File

@@ -1,12 +1,25 @@
package management
import (
"encoding/json"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
)
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
"failed_requests": snapshot.FailureCount,
})
}
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, usageExportPayload{
Version: 1,
ExportedAt: time.Now().UTC(),
Usage: snapshot,
})
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
}

View File

@@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
}
return &RequestInfo{
URL: url,
Method: method,
Headers: headers,
Body: body,
URL: url,
Method: method,
Headers: headers,
Body: body,
RequestID: logging.GetGinRequestID(c),
}, nil
}
@@ -114,7 +115,7 @@ func shouldLogRequest(path string) bool {
}
if strings.HasPrefix(path, "/api") {
return strings.HasPrefix(path, "/api/provider")
return strings.HasPrefix(path, "/api/provider") || strings.HasPrefix(path, "/api/internal")
}
return true

View File

@@ -15,10 +15,11 @@ import (
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
RequestID string // RequestID is the unique identifier for the request.
}
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
@@ -149,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
w.requestInfo.Method,
w.requestInfo.Headers,
w.requestInfo.Body,
w.requestInfo.RequestID,
)
if err == nil {
w.streamWriter = streamWriter
@@ -346,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
@@ -360,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiResponseBody,
apiResponseErrors,
forceLog,
w.requestInfo.RequestID,
)
}
@@ -374,5 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiRequestBody,
apiResponseBody,
apiResponseErrors,
w.requestInfo.RequestID,
)
}

View File

@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
}
}
// Check API key change
// Check API key change (both default and per-client mappings)
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
if apiKeyChanged {
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
if apiKeyChanged || upstreamAPIKeysChanged {
if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
if apiKeyChanged {
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
if upstreamAPIKeysChanged {
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
}
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
if m.secretSource == nil {
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
mappedSource := NewMappedSecretSource(defaultSource)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
ms.UpdateMappings(settings.UpstreamAPIKeys)
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
mappedSource := NewMappedSecretSource(ms)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
}
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
@@ -279,16 +300,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
return true
}
// Build map for efficient comparison
oldMap := make(map[string]string, len(old.ModelMappings))
// Build map for efficient and robust comparison
type mappingInfo struct {
to string
regex bool
}
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
for _, mapping := range old.ModelMappings {
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
to: strings.TrimSpace(mapping.To),
regex: mapping.Regex,
}
}
for _, mapping := range new.ModelMappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
return true
}
}
@@ -306,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
return oldKey != newKey
}
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
if old == nil {
return len(new.UpstreamAPIKeys) > 0
}
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
return true
}
// Build map for comparison: upstreamKey -> set of clientKeys
type entryInfo struct {
upstreamKey string
clientKeys map[string]struct{}
}
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
for i, entry := range old.UpstreamAPIKeys {
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
for _, k := range entry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
clientKeys[trimmed] = struct{}{}
}
oldEntries[i] = entryInfo{
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
clientKeys: clientKeys,
}
}
for i, newEntry := range new.UpstreamAPIKeys {
if i >= len(oldEntries) {
return true
}
oldE := oldEntries[i]
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
return true
}
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
for _, k := range newEntry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
newKeys[trimmed] = struct{}{}
}
if len(newKeys) != len(oldE.clientKeys) {
return true
}
for k := range newKeys {
if _, ok := oldE.clientKeys[k]; !ok {
return true
}
}
}
return false
}
// GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper

View File

@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
})
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
},
}
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
},
}
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected no change when only whitespace/empty entries differ")
}
}

View File

@@ -3,6 +3,7 @@
package amp
import (
"regexp"
"strings"
"sync"
@@ -26,13 +27,15 @@ type ModelMapper interface {
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
type DefaultModelMapper struct {
mu sync.RWMutex
mappings map[string]string // from -> to (normalized lowercase keys)
mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order
}
// NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{
mappings: make(map[string]string),
regexps: nil,
}
m.UpdateMappings(mappings)
return m
@@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
// Check for direct mapping
targetModel, exists := m.mappings[normalizedRequest]
if !exists {
return ""
// Try regex mappings in order
base, _ := util.NormalizeThinkingModel(requestedModel)
for _, rm := range m.regexps {
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
targetModel = rm.to
exists = true
break
}
}
if !exists {
return ""
}
}
// Verify target model has available providers
@@ -78,6 +92,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
// Clear and rebuild mappings
m.mappings = make(map[string]string, len(mappings))
m.regexps = make([]regexMapping, 0, len(mappings))
for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From)
@@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
continue
}
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
if mapping.Regex {
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
pattern := "(?i)" + from
re, err := regexp.Compile(pattern)
if err != nil {
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
continue
}
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
} else {
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
}
}
if len(m.mappings) > 0 {
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
}
if n := len(m.regexps); n > 0 {
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
}
}
// GetMappings returns a copy of current mappings (for debugging/status).
@@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
}
return result
}
type regexMapping struct {
re *regexp.Regexp
to string
}

View File

@@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
t.Error("Original map was modified")
}
}
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-1")
mappings := []config.AmpModelMapping{
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
}
mapper := NewModelMapper(mappings)
// Incoming model has reasoning suffix but should match base via regex
result := mapper.MapModel("gpt-5(high)")
if result != "gemini-2.5-pro" {
t.Errorf("Expected gemini-2.5-pro, got %s", result)
}
}
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-2")
defer reg.UnregisterClient("test-client-regex-3")
mappings := []config.AmpModelMapping{
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
}
mapper := NewModelMapper(mappings)
// Exact match should win over regex
result := mapper.MapModel("gpt-5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
// Invalid regex should be skipped and not cause panic
mappings := []config.AmpModelMapping{
{From: "(", To: "target", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("anything")
if result != "" {
t.Errorf("Expected empty result due to invalid regex, got %s", result)
}
}
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-regex-4")
mappings := []config.AmpModelMapping{
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}

View File

@@ -15,6 +15,33 @@ import (
log "github.com/sirupsen/logrus"
)
func removeQueryValuesMatching(req *http.Request, key string, match string) {
if req == nil || req.URL == nil || match == "" {
return
}
q := req.URL.Query()
values, ok := q[key]
if !ok || len(values) == 0 {
return
}
kept := make([]string, 0, len(values))
for _, v := range values {
if v == match {
continue
}
kept = append(kept, v)
}
if len(kept) == 0 {
q.Del(key)
} else {
q[key] = kept
}
req.URL.RawQuery = q.Encode()
}
// readCloser wraps a reader and forwards Close to a separate closer.
// Used to restore peeked bytes while preserving upstream body Close behavior.
type readCloser struct {
@@ -45,6 +72,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// We will set our own Authorization using the configured upstream-api-key
req.Header.Del("Authorization")
req.Header.Del("X-Api-Key")
req.Header.Del("X-Goog-Api-Key")
// Remove query-based credentials if they match the authenticated client API key.
// This prevents leaking client auth material to the Amp upstream while avoiding
// breaking unrelated upstream query parameters.
clientKey := getClientAPIKeyFromContext(req.Context())
removeQueryValuesMatching(req, "key", clientKey)
removeQueryValuesMatching(req, "auth_token", clientKey)
// Preserve correlation headers for debugging
if req.Header.Get("X-Request-ID") == "" {

View File

@@ -3,11 +3,15 @@ package amp
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Helper: compress data with gzip
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
}
}
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
type captured struct {
headers http.Header
query string
}
got := make(chan captured, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer client-key")
req.Header.Set("X-Api-Key", "client-key")
req.Header.Set("X-Goog-Api-Key", "client-key")
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
c := <-got
// These are client-provided credentials and must not reach the upstream.
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
}
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
}
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
}
// Query-based credentials should be stripped only when they match the authenticated client key.
// Should keep unrelated values and parameters.
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
}
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
}
}
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "u1" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer u1" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "default" {
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer default" {
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_ErrorHandler(t *testing.T) {
// Point proxy to a non-routable address to trigger error
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))

View File

@@ -69,7 +69,30 @@ func (rw *ResponseRewriter) Flush() {
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
// The Amp client struggles when both thinking and tool_use blocks are present
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
if filtered.Exists() {
originalCount := gjson.GetBytes(data, "content.#").Int()
filteredCount := filtered.Get("#").Int()
if originalCount > filteredCount {
var err error
data, err = sjson.SetBytes(data, "content", filtered.Value())
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
} else {
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
// Log the result for verification
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
}
}
}
}
if rw.originalModel == "" {
return data
}

View File

@@ -1,6 +1,7 @@
package amp
import (
"context"
"errors"
"net"
"net/http"
@@ -16,6 +17,37 @@ import (
log "github.com/sirupsen/logrus"
)
// clientAPIKeyContextKey is the context key used to pass the client API key
// from gin.Context to the request context for SecretSource lookup.
type clientAPIKeyContextKey struct{}
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
// into the request context so that SecretSource can look it up for per-client upstream routing.
func clientAPIKeyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Extract the client API key from gin context (set by AuthMiddleware)
if apiKey, exists := c.Get("apiKey"); exists {
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
// Inject into request context for SecretSource.Get(ctx) to read
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
c.Request = c.Request.WithContext(ctx)
}
}
c.Next()
}
}
// getClientAPIKeyFromContext retrieves the client API key from request context.
// Returns empty string if not present.
func getClientAPIKeyFromContext(ctx context.Context) string {
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
if keyStr, ok := val.(string); ok {
return keyStr
}
}
return ""
}
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
}
// Inject client API key into request context for per-client upstream routing
ampAPI.Use(clientAPIKeyMiddleware())
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
if authWithBypass != nil {
rootMiddleware = append(rootMiddleware, authWithBypass)
}
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
if auth != nil {
ampProviders.Use(auth)
}
// Inject client API key into request context for per-client upstream routing
ampProviders.Use(clientAPIKeyMiddleware())
provider := ampProviders.Group("/:provider")

View File

@@ -9,6 +9,9 @@ import (
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// SecretSource provides Amp API keys with configurable precedence and caching
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
return s.key, nil
}
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
// When a request context contains a client API key that matches a configured mapping,
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
type MappedSecretSource struct {
defaultSource SecretSource
mu sync.RWMutex
lookup map[string]string // clientKey -> upstreamKey
}
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
return &MappedSecretSource{
defaultSource: defaultSource,
lookup: make(map[string]string),
}
}
// Get retrieves the Amp API key, checking per-client mappings first.
// If the request context contains a client API key that matches a configured mapping,
// returns the corresponding upstream key. Otherwise, falls back to the default source.
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
// Try to get client API key from request context
clientKey := getClientAPIKeyFromContext(ctx)
if clientKey != "" {
s.mu.RLock()
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
s.mu.RUnlock()
return upstreamKey, nil
}
s.mu.RUnlock()
}
// Fall back to default source
return s.defaultSource.Get(ctx)
}
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
// If the same client key appears in multiple entries, logs a warning and uses the first one.
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
newLookup := make(map[string]string)
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
for _, clientKey := range entry.APIKeys {
trimmedKey := strings.TrimSpace(clientKey)
if trimmedKey == "" {
continue
}
if _, exists := newLookup[trimmedKey]; exists {
// Log warning for duplicate client key, first one wins
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
continue
}
newLookup[trimmedKey] = upstreamKey
}
}
s.mu.Lock()
s.lookup = newLookup
s.mu.Unlock()
}
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(key)
}
}
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) InvalidateCache() {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.InvalidateCache()
}
}

View File

@@ -8,6 +8,10 @@ import (
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
)
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
}
}
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1, got %q", got)
}
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
got, err = s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "default" {
t.Fatalf("want default fallback, got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1 (first wins), got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
hook := test.NewLocal(log.StandardLogger())
defer hook.Reset()
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
foundWarning := false
for _, entry := range hook.AllEntries() {
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
foundWarning = true
break
}
}
if !foundWarning {
t.Fatal("expected warning log for duplicate client key, but none was found")
}
}

View File

@@ -33,6 +33,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
@@ -209,13 +210,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Resolve logs directory relative to the configuration file directory.
var requestLogger logging.RequestLogger
var toggle func(bool)
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
if !cfg.CommercialMode {
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
}
}
}
@@ -474,6 +477,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -487,6 +492,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
@@ -496,6 +505,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
mgmt.POST("/api-call", s.mgmt.APICall)
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
@@ -518,6 +529,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
@@ -544,6 +556,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
@@ -552,6 +568,14 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
@@ -567,11 +591,21 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
mgmt.GET("/oauth-model-mappings", s.mgmt.GetOAuthModelMappings)
mgmt.PUT("/oauth-model-mappings", s.mgmt.PutOAuthModelMappings)
mgmt.PATCH("/oauth-model-mappings", s.mgmt.PatchOAuthModelMappings)
mgmt.DELETE("/oauth-model-mappings", s.mgmt.DeleteOAuthModelMappings)
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
@@ -845,7 +879,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
}
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
if err := logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to reconfigure log output: %v", err)
} else {
if oldCfg == nil {
@@ -955,8 +989,12 @@ func (s *Server) UpdateClients(cfg *config.Config) {
log.Warnf("amp module is nil, skipping config update")
}
// Count client sources from configuration and auth directory
authFiles := util.CountAuthFiles(cfg.AuthDir)
// Count client sources from configuration and auth store.
tokenStore := sdkAuth.GetTokenStore()
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
dirSetter.SetBaseDir(cfg.AuthDir)
}
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
geminiAPIKeyCount := len(cfg.GeminiKey)
claudeAPIKeyCount := len(cfg.ClaudeKey)
codexAPIKeyCount := len(cfg.CodexKey)
@@ -967,10 +1005,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
openAICompatCount += len(entry.APIKeyEntries)
}
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total,
authFiles,
authEntries,
geminiAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,

View File

@@ -3,7 +3,6 @@ package cache
import (
"crypto/sha256"
"encoding/hex"
"sort"
"sync"
"time"
)
@@ -16,21 +15,24 @@ type SignatureEntry struct {
const (
// SignatureCacheTTL is how long signatures are valid
SignatureCacheTTL = 1 * time.Hour
// MaxEntriesPerSession limits memory usage per session
MaxEntriesPerSession = 100
SignatureCacheTTL = 3 * time.Hour
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
SignatureTextHashLen = 16
// MinValidSignatureLen is the minimum length for a signature to be considered valid
MinValidSignatureLen = 50
// SessionCleanupInterval controls how often stale sessions are purged
SessionCleanupInterval = 10 * time.Minute
)
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
var signatureCache sync.Map
// sessionCleanupOnce ensures the background cleanup goroutine starts only once
var sessionCleanupOnce sync.Once
// sessionCache is the inner map type
type sessionCache struct {
mu sync.RWMutex
@@ -45,6 +47,9 @@ func hashText(text string) string {
// getOrCreateSession gets or creates a session cache
func getOrCreateSession(sessionID string) *sessionCache {
// Start background cleanup on first access
sessionCleanupOnce.Do(startSessionCleanup)
if val, ok := signatureCache.Load(sessionID); ok {
return val.(*sessionCache)
}
@@ -53,6 +58,40 @@ func getOrCreateSession(sessionID string) *sessionCache {
return actual.(*sessionCache)
}
// startSessionCleanup launches a background goroutine that periodically
// removes sessions where all entries have expired.
func startSessionCleanup() {
go func() {
ticker := time.NewTicker(SessionCleanupInterval)
defer ticker.Stop()
for range ticker.C {
purgeExpiredSessions()
}
}()
}
// purgeExpiredSessions removes sessions with no valid (non-expired) entries.
func purgeExpiredSessions() {
now := time.Now()
signatureCache.Range(func(key, value any) bool {
sc := value.(*sessionCache)
sc.mu.Lock()
// Remove expired entries
for k, entry := range sc.entries {
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, k)
}
}
isEmpty := len(sc.entries) == 0
sc.mu.Unlock()
// Remove session if empty
if isEmpty {
signatureCache.Delete(key)
}
return true
})
}
// CacheSignature stores a thinking signature for a given session and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(sessionID, text, signature string) {
@@ -69,43 +108,6 @@ func CacheSignature(sessionID, text, signature string) {
sc.mu.Lock()
defer sc.mu.Unlock()
// Evict expired entries if at capacity
if len(sc.entries) >= MaxEntriesPerSession {
now := time.Now()
for key, entry := range sc.entries {
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, key)
}
}
// If still at capacity, remove oldest entries
if len(sc.entries) >= MaxEntriesPerSession {
// Find and remove oldest quarter
oldest := make([]struct {
key string
ts time.Time
}, 0, len(sc.entries))
for key, entry := range sc.entries {
oldest = append(oldest, struct {
key string
ts time.Time
}{key, entry.Timestamp})
}
// Sort by timestamp (oldest first) using sort.Slice
sort.Slice(oldest, func(i, j int) bool {
return oldest[i].ts.Before(oldest[j].ts)
})
toRemove := len(oldest) / 4
if toRemove < 1 {
toRemove = 1
}
for i := 0; i < toRemove; i++ {
delete(sc.entries, oldest[i].key)
}
}
}
sc.entries[textHash] = SignatureEntry{
Signature: signature,
Timestamp: time.Now(),
@@ -127,22 +129,25 @@ func GetCachedSignature(sessionID, text string) string {
textHash := hashText(text)
sc.mu.RLock()
entry, exists := sc.entries[textHash]
sc.mu.RUnlock()
now := time.Now()
sc.mu.Lock()
entry, exists := sc.entries[textHash]
if !exists {
sc.mu.Unlock()
return ""
}
// Check if expired
if time.Since(entry.Timestamp) > SignatureCacheTTL {
sc.mu.Lock()
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, textHash)
sc.mu.Unlock()
return ""
}
// Refresh TTL on access (sliding expiration).
entry.Timestamp = now
sc.entries[textHash] = entry
sc.mu.Unlock()
return entry.Signature
}

View File

@@ -39,6 +39,9 @@ type Config struct {
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug" json:"debug"`
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
// LoggingToFile controls whether application logs are written to rotating files or stdout.
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
@@ -60,6 +63,9 @@ type Config struct {
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
// Routing controls credential selection behavior.
Routing RoutingConfig `yaml:"routing" json:"routing"`
// WebsocketAuth enables or disables authentication for the WebSocket API.
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
@@ -85,6 +91,14 @@ type Config struct {
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
// These mappings affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
@@ -124,6 +138,23 @@ type QuotaExceeded struct {
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
}
// RoutingConfig configures how credentials are selected for requests.
type RoutingConfig struct {
// Strategy selects the credential selection strategy.
// Supported values: "round-robin" (default), "fill-first".
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// ModelNameMapping defines a model ID mapping for a specific channel.
// It maps the upstream model name (Name) to the client-visible alias (Alias).
// When Fork is true, the alias is added as an additional model in listings while
// keeping the original model ID available.
type ModelNameMapping struct {
Name string `yaml:"name" json:"name"`
Alias string `yaml:"alias" json:"alias"`
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available.
@@ -134,6 +165,11 @@ type AmpModelMapping struct {
// To is the target model name to route to (e.g., "claude-sonnet-4").
// The target model must have available providers in the registry.
To string `yaml:"to" json:"to"`
// Regex indicates whether the 'from' field should be interpreted as a regular
// expression for matching model names. When true, this mapping is evaluated
// after exact matches and in the order provided. Defaults to false (exact match).
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
}
// AmpCode groups Amp CLI integration settings including upstream routing,
@@ -145,6 +181,11 @@ type AmpCode struct {
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
@@ -160,6 +201,17 @@ type AmpCode struct {
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
}
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
type AmpUpstreamAPIKeyEntry struct {
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads.
type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload.
@@ -219,6 +271,9 @@ type ClaudeModel struct {
Alias string `yaml:"alias" json:"alias"`
}
func (m ClaudeModel) GetName() string { return m.Name }
func (m ClaudeModel) GetAlias() string { return m.Alias }
// CodexKey represents the configuration for a Codex API key,
// including the API key itself and an optional base URL for the API endpoint.
type CodexKey struct {
@@ -235,6 +290,9 @@ type CodexKey struct {
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// Models defines upstream model names and aliases for request routing.
Models []CodexModel `yaml:"models" json:"models"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -242,6 +300,18 @@ type CodexKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// CodexModel describes a mapping between an alias and the actual upstream model name.
type CodexModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
func (m CodexModel) GetName() string { return m.Name }
func (m CodexModel) GetAlias() string { return m.Alias }
// GeminiKey represents the configuration for a Gemini API key,
// including optional overrides for upstream base URL, proxy routing, and headers.
type GeminiKey struct {
@@ -257,6 +327,9 @@ type GeminiKey struct {
// ProxyURL optionally overrides the global proxy for this API key.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
// Models defines upstream model names and aliases for request routing.
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -264,6 +337,18 @@ type GeminiKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// GeminiModel describes a mapping between an alias and the actual upstream model name.
type GeminiModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
func (m GeminiModel) GetName() string { return m.Name }
func (m GeminiModel) GetAlias() string { return m.Alias }
// OpenAICompatibility represents the configuration for OpenAI API compatibility
// with external providers, allowing model aliases to be routed through OpenAI API format.
type OpenAICompatibility struct {
@@ -415,6 +500,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Normalize OAuth provider model exclusion map.
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
// Normalize global OAuth model name mappings.
cfg.SanitizeOAuthModelMappings()
if cfg.legacyMigrationPending {
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
if !optional && configFile != "" {
@@ -431,6 +519,44 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
return &cfg, nil
}
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
func (cfg *Config) SanitizeOAuthModelMappings() {
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
return
}
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
for rawChannel, mappings := range cfg.OAuthModelMappings {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(mappings) == 0 {
continue
}
seenAlias := make(map[string]struct{}, len(mappings))
clean := make([]ModelNameMapping, 0, len(mappings))
for _, mapping := range mappings {
name := strings.TrimSpace(mapping.Name)
alias := strings.TrimSpace(mapping.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
aliasKey := strings.ToLower(alias)
if _, ok := seenAlias[aliasKey]; ok {
continue
}
seenAlias[aliasKey] = struct{}{}
clean = append(clean, ModelNameMapping{Name: name, Alias: alias, Fork: mapping.Fork})
}
if len(clean) > 0 {
out[channel] = clean
}
}
cfg.OAuthModelMappings = out
}
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
// not actionable, specifically those missing a BaseURL. It trims whitespace before
// evaluation and preserves the relative order of remaining entries.
@@ -799,8 +925,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
}
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
// key order and comments of existing keys in dst. Unknown keys from src are appended
// to dst at the end, copying their node structure from src.
// key order and comments of existing keys in dst. New keys are only added if their
// value is non-zero to avoid polluting the config with defaults.
func mergeMappingPreserve(dst, src *yaml.Node) {
if dst == nil || src == nil {
return
@@ -811,20 +937,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
copyNodeShallow(dst, src)
return
}
// Build a lookup of existing keys in dst
for i := 0; i+1 < len(src.Content); i += 2 {
sk := src.Content[i]
sv := src.Content[i+1]
idx := findMapKeyIndex(dst, sk.Value)
if idx >= 0 {
// Merge into existing value node
// Merge into existing value node (always update, even to zero values)
dv := dst.Content[idx+1]
mergeNodePreserve(dv, sv)
} else {
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) {
// New key: only add if value is non-zero to avoid polluting config with defaults
if isZeroValueNode(sv) {
continue
}
// Append new key/value pair by deep-copying from src
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
}
}
@@ -907,32 +1032,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
return -1
}
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
switch key {
case "generative-language-api-key",
"gemini-api-key",
"vertex-api-key",
"claude-api-key",
"codex-api-key",
"openai-compatibility":
return isEmptyCollectionNode(node)
default:
return false
}
}
func isEmptyCollectionNode(node *yaml.Node) bool {
// isZeroValueNode returns true if the YAML node represents a zero/default value
// that should not be written as a new key to preserve config cleanliness.
// For mappings and sequences, recursively checks if all children are zero values.
func isZeroValueNode(node *yaml.Node) bool {
if node == nil {
return true
}
switch node.Kind {
case yaml.SequenceNode:
return len(node.Content) == 0
case yaml.ScalarNode:
return node.Tag == "!!null"
default:
return false
switch node.Tag {
case "!!bool":
return node.Value == "false"
case "!!int", "!!float":
return node.Value == "0" || node.Value == "0.0"
case "!!str":
return node.Value == ""
case "!!null":
return true
}
case yaml.SequenceNode:
if len(node.Content) == 0 {
return true
}
// Check if all elements are zero values
for _, child := range node.Content {
if !isZeroValueNode(child) {
return false
}
}
return true
case yaml.MappingNode:
if len(node.Content) == 0 {
return true
}
// Check if all values are zero values (values are at odd indices)
for i := 1; i < len(node.Content); i += 2 {
if !isZeroValueNode(node.Content[i]) {
return false
}
}
return true
}
return false
}
// deepCopyNode creates a deep copy of a yaml.Node graph.

View File

@@ -0,0 +1,56 @@
package config
import "testing"
func TestSanitizeOAuthModelMappings_PreservesForkFlag(t *testing.T) {
cfg := &Config{
OAuthModelMappings: map[string][]ModelNameMapping{
" CoDeX ": {
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
{Name: "gpt-6", Alias: "g6"},
},
},
}
cfg.SanitizeOAuthModelMappings()
mappings := cfg.OAuthModelMappings["codex"]
if len(mappings) != 2 {
t.Fatalf("expected 2 sanitized mappings, got %d", len(mappings))
}
if mappings[0].Name != "gpt-5" || mappings[0].Alias != "g5" || !mappings[0].Fork {
t.Fatalf("expected first mapping to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", mappings[0].Name, mappings[0].Alias, mappings[0].Fork)
}
if mappings[1].Name != "gpt-6" || mappings[1].Alias != "g6" || mappings[1].Fork {
t.Fatalf("expected second mapping to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", mappings[1].Name, mappings[1].Alias, mappings[1].Fork)
}
}
func TestSanitizeOAuthModelMappings_AllowsMultipleAliasesForSameName(t *testing.T) {
cfg := &Config{
OAuthModelMappings: map[string][]ModelNameMapping{
"antigravity": {
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
},
},
}
cfg.SanitizeOAuthModelMappings()
mappings := cfg.OAuthModelMappings["antigravity"]
expected := []ModelNameMapping{
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
}
if len(mappings) != len(expected) {
t.Fatalf("expected %d sanitized mappings, got %d", len(expected), len(mappings))
}
for i, exp := range expected {
if mappings[i].Name != exp.Name || mappings[i].Alias != exp.Alias || mappings[i].Fork != exp.Fork {
t.Fatalf("expected mapping %d to be name=%q alias=%q fork=%v, got name=%q alias=%q fork=%v", i, exp.Name, exp.Alias, exp.Fork, mappings[i].Name, mappings[i].Alias, mappings[i].Fork)
}
}
}

View File

@@ -22,6 +22,21 @@ type SDKConfig struct {
// Access holds request authentication provider configuration.
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
}
// StreamingConfig holds server streaming behavior configuration.
type StreamingConfig struct {
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
// <= 0 disables keep-alives. Default is 0.
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
// to allow auth rotation / transient recovery.
// <= 0 disables bootstrap retries. Default is 0.
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
}
// AccessConfig groups request authentication providers.

View File

@@ -42,6 +42,9 @@ type VertexCompatModel struct {
Alias string `yaml:"alias" json:"alias"`
}
func (m VertexCompatModel) GetName() string { return m.Name }
func (m VertexCompatModel) GetAlias() string { return m.Alias }
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
func (cfg *Config) SanitizeVertexCompatKeys() {
if cfg == nil {

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -14,11 +15,24 @@ import (
log "github.com/sirupsen/logrus"
)
// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
var aiAPIPrefixes = []string{
"/v1/chat/completions",
"/v1/completions",
"/v1/messages",
"/v1/responses",
"/v1beta/models/",
"/api/provider/",
}
const skipGinLogKey = "__gin_skip_request_logging__"
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
// using logrus. It captures request details including method, path, status code, latency,
// client IP, and any error messages, formatting them in a Gin-style log format.
// client IP, and any error messages. Request ID is only added for AI API requests.
//
// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ...
// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ...
//
// Returns:
// - gin.HandlerFunc: A middleware handler for request logging
@@ -28,6 +42,15 @@ func GinLogrusLogger() gin.HandlerFunc {
path := c.Request.URL.Path
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
// Only generate request ID for AI API paths
var requestID string
if isAIAPIPath(path) {
requestID = GenerateRequestID()
SetGinRequestID(c, requestID)
ctx := WithRequestID(c.Request.Context(), requestID)
c.Request = c.Request.WithContext(ctx)
}
c.Next()
if shouldSkipGinRequestLogging(c) {
@@ -49,23 +72,38 @@ func GinLogrusLogger() gin.HandlerFunc {
clientIP := c.ClientIP()
method := c.Request.Method
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path)
if requestID == "" {
requestID = "--------"
}
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
if errorMessage != "" {
logLine = logLine + " | " + errorMessage
}
entry := log.WithField("request_id", requestID)
switch {
case statusCode >= http.StatusInternalServerError:
log.Error(logLine)
entry.Error(logLine)
case statusCode >= http.StatusBadRequest:
log.Warn(logLine)
entry.Warn(logLine)
default:
log.Info(logLine)
entry.Info(logLine)
}
}
}
// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
func isAIAPIPath(path string) bool {
for _, prefix := range aiAPIPrefixes {
if strings.HasPrefix(path, prefix) {
return true
}
}
return false
}
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
// them using logrus. When a panic occurs, it captures the panic value, stack trace,
// and request path, then returns a 500 Internal Server Error response to the client.

View File

@@ -10,6 +10,7 @@ import (
"sync"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
@@ -24,7 +25,8 @@ var (
)
// LogFormatter defines a custom log format for logrus.
// This formatter adds timestamp, level, and source location to each log entry.
// This formatter adds timestamp, level, request ID, and source location to each log entry.
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
type LogFormatter struct{}
// Format renders a single log entry with custom formatting.
@@ -39,11 +41,22 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05")
message := strings.TrimRight(entry.Message, "\r\n")
reqID := "--------"
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
reqID = id
}
level := entry.Level.String()
if level == "warning" {
level = "warn"
}
levelStr := fmt.Sprintf("%-5s", level)
var formatted string
if entry.Caller != nil {
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
} else {
formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message)
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
}
buffer.WriteString(formatted)
@@ -71,10 +84,30 @@ func SetupBaseLogger() {
})
}
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
func isDirWritable(dir string) bool {
info, err := os.Stat(dir)
if err != nil || !info.IsDir() {
return false
}
testFile := filepath.Join(dir, ".perm_test")
f, err := os.Create(testFile)
if err != nil {
return false
}
defer func() {
_ = f.Close()
_ = os.Remove(testFile)
}()
return true
}
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
// until the total size is within the limit.
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
func ConfigureLogOutput(cfg *config.Config) error {
SetupBaseLogger()
writerMu.Lock()
@@ -83,10 +116,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
logDir := "logs"
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
} else if !isDirWritable(logDir) {
logDir = filepath.Join(cfg.AuthDir, "logs")
}
protectedPath := ""
if loggingToFile {
if cfg.LoggingToFile {
if err := os.MkdirAll(logDir, 0o755); err != nil {
return fmt.Errorf("logging: failed to create log directory: %w", err)
}
@@ -110,7 +145,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
log.SetOutput(os.Stdout)
}
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
return nil
}

View File

@@ -43,10 +43,11 @@ type RequestLogger interface {
// - response: The raw response data
// - apiRequest: The API request data
// - apiResponse: The API response data
// - requestID: Optional request ID for log file naming
//
// Returns:
// - error: An error if logging fails, nil otherwise
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
//
@@ -55,11 +56,12 @@ type RequestLogger interface {
// - method: The HTTP method
// - headers: The request headers
// - body: The request body
// - requestID: Optional request ID for log file naming
//
// Returns:
// - StreamingLogWriter: A writer for streaming response chunks
// - error: An error if logging initialization fails, nil otherwise
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)
// IsEnabled returns whether request logging is currently enabled.
//
@@ -177,20 +179,21 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
// - response: The raw response data
// - apiRequest: The API request data
// - apiResponse: The API response data
// - requestID: Optional request ID for log file naming
//
// Returns:
// - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
}
// LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
}
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
if !l.enabled && !force {
return nil
}
@@ -200,10 +203,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
}
// Generate filename
filename := l.generateFilename(url)
// Generate filename with request ID
filename := l.generateFilename(url, requestID)
if force && !l.enabled {
filename = l.generateErrorFilename(url)
filename = l.generateErrorFilename(url, requestID)
}
filePath := filepath.Join(l.logsDir, filename)
@@ -271,11 +274,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
// - method: The HTTP method
// - headers: The request headers
// - body: The request body
// - requestID: Optional request ID for log file naming
//
// Returns:
// - StreamingLogWriter: A writer for streaming response chunks
// - error: An error if logging initialization fails, nil otherwise
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
if !l.enabled {
return &NoOpStreamingLogWriter{}, nil
}
@@ -285,8 +289,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
return nil, fmt.Errorf("failed to create logs directory: %w", err)
}
// Generate filename
filename := l.generateFilename(url)
// Generate filename with request ID
filename := l.generateFilename(url, requestID)
filePath := filepath.Join(l.logsDir, filename)
requestHeaders := make(map[string][]string, len(headers))
@@ -330,8 +334,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
}
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
func (l *FileRequestLogger) generateErrorFilename(url string) string {
return fmt.Sprintf("error-%s", l.generateFilename(url))
func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...))
}
// ensureLogsDir creates the logs directory if it doesn't exist.
@@ -346,13 +350,15 @@ func (l *FileRequestLogger) ensureLogsDir() error {
}
// generateFilename creates a sanitized filename from the URL path and current timestamp.
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
//
// Parameters:
// - url: The request URL
// - requestID: Optional request ID to include in filename
//
// Returns:
// - string: A sanitized filename for the log file
func (l *FileRequestLogger) generateFilename(url string) string {
func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
// Extract path from URL
path := url
if strings.Contains(url, "?") {
@@ -368,12 +374,18 @@ func (l *FileRequestLogger) generateFilename(url string) string {
sanitized := l.sanitizeForFilename(path)
// Add timestamp
timestamp := time.Now().Format("2006-01-02T150405-.000000000")
timestamp = strings.Replace(timestamp, ".", "", -1)
timestamp := time.Now().Format("2006-01-02T150405")
id := requestLogID.Add(1)
// Use request ID if provided, otherwise use sequential ID
var idPart string
if len(requestID) > 0 && requestID[0] != "" {
idPart = requestID[0]
} else {
id := requestLogID.Add(1)
idPart = fmt.Sprintf("%d", id)
}
return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id)
return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
}
// sanitizeForFilename replaces characters that are not safe for filenames.

View File

@@ -0,0 +1,61 @@
package logging
import (
"context"
"crypto/rand"
"encoding/hex"
"github.com/gin-gonic/gin"
)
// requestIDKey is the context key for storing/retrieving request IDs.
type requestIDKey struct{}
// ginRequestIDKey is the Gin context key for request IDs.
const ginRequestIDKey = "__request_id__"
// GenerateRequestID creates a new 8-character hex request ID.
func GenerateRequestID() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
return "00000000"
}
return hex.EncodeToString(b)
}
// WithRequestID returns a new context with the request ID attached.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID)
}
// GetRequestID retrieves the request ID from the context.
// Returns empty string if not found.
func GetRequestID(ctx context.Context) string {
if ctx == nil {
return ""
}
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
return id
}
return ""
}
// SetGinRequestID stores the request ID in the Gin context.
func SetGinRequestID(c *gin.Context, requestID string) {
if c != nil {
c.Set(ginRequestIDKey, requestID)
}
}
// GetGinRequestID retrieves the request ID from the Gin context.
func GetGinRequestID(c *gin.Context) string {
if c == nil {
return ""
}
if id, exists := c.Get(ginRequestIDKey); exists {
if s, ok := id.(string); ok {
return s
}
}
return ""
}

View File

@@ -24,10 +24,11 @@ import (
)
const (
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
)
// ManagementFileName exposes the control panel asset filename.
@@ -198,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
return
}
localPath := filepath.Join(staticDir, managementAssetName)
localFileMissing := false
if _, errStat := os.Stat(localPath); errStat != nil {
if errors.Is(errStat, os.ErrNotExist) {
localFileMissing = true
} else {
log.WithError(errStat).Debug("failed to stat local management asset")
}
}
// Rate limiting: check only once every 3 hours
lastUpdateCheckMu.Lock()
now := time.Now()
@@ -210,15 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
lastUpdateCheckTime = now
lastUpdateCheckMu.Unlock()
if err := os.MkdirAll(staticDir, 0o755); err != nil {
log.WithError(err).Warn("failed to prepare static directory for management asset")
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
return
}
releaseURL := resolveReleaseURL(panelRepository)
client := newHTTPClient(proxyURL)
localPath := filepath.Join(staticDir, managementAssetName)
localHash, err := fileSHA256(localPath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
@@ -229,6 +239,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to fetch latest management release information")
return
}
@@ -240,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to download management asset, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to download management asset")
return
}
@@ -256,6 +280,22 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
}
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
if err != nil {
log.WithError(err).Warn("failed to download fallback management control panel page")
return false
}
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to persist fallback management control panel page")
return false
}
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
return true
}
func resolveReleaseURL(repo string) string {
repo = strings.TrimSpace(repo)
if repo == "" {

View File

@@ -7,12 +7,77 @@ import (
"embed"
_ "embed"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
//go:embed codex_instructions
var codexInstructionsDir embed.FS
func CodexInstructionsForModel(modelName, systemInstructions string) (bool, string) {
//go:embed opencode_codex_instructions.txt
var opencodeCodexInstructions string
const (
codexUserAgentKey = "__cpa_user_agent"
userAgentOpenAISDK = "ai-sdk/openai/"
)
func InjectCodexUserAgent(raw []byte, userAgent string) []byte {
if len(raw) == 0 {
return raw
}
trimmed := strings.TrimSpace(userAgent)
if trimmed == "" {
return raw
}
updated, err := sjson.SetBytes(raw, codexUserAgentKey, trimmed)
if err != nil {
return raw
}
return updated
}
func ExtractCodexUserAgent(raw []byte) string {
if len(raw) == 0 {
return ""
}
return strings.TrimSpace(gjson.GetBytes(raw, codexUserAgentKey).String())
}
func StripCodexUserAgent(raw []byte) []byte {
if len(raw) == 0 {
return raw
}
if !gjson.GetBytes(raw, codexUserAgentKey).Exists() {
return raw
}
updated, err := sjson.DeleteBytes(raw, codexUserAgentKey)
if err != nil {
return raw
}
return updated
}
func codexInstructionsForOpenCode(systemInstructions string) (bool, string) {
if opencodeCodexInstructions == "" {
return false, ""
}
if strings.HasPrefix(systemInstructions, opencodeCodexInstructions) {
return true, ""
}
return false, opencodeCodexInstructions
}
func useOpenCodeInstructions(userAgent string) bool {
return strings.Contains(strings.ToLower(userAgent), userAgentOpenAISDK)
}
func IsOpenCodeUserAgent(userAgent string) bool {
return useOpenCodeInstructions(userAgent)
}
func codexInstructionsForCodex(modelName, systemInstructions string) (bool, string) {
entries, _ := codexInstructionsDir.ReadDir("codex_instructions")
lastPrompt := ""
@@ -57,3 +122,10 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
return false, lastPrompt
}
}
func CodexInstructionsForModel(modelName, systemInstructions, userAgent string) (bool, string) {
if IsOpenCodeUserAgent(userAgent) {
return codexInstructionsForOpenCode(systemInstructions)
}
return codexInstructionsForCodex(modelName, systemInstructions)
}

View File

@@ -0,0 +1,318 @@
You are a coding agent running in the opencode, a terminal-based coding assistant. opencode is an open source project. You are expected to be precise, safe, and helpful.
Your capabilities:
- Receive user prompts and other context provided by the harness, such as files in the workspace.
- Communicate with the user by streaming thinking & responses, and by making & updating plans.
- Emit function calls to run terminal commands and apply edits. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section.
Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI).
# How you work
## Personality
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
# AGENTS.md spec
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
- Instructions in AGENTS.md files:
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
## Responsiveness
### Preamble messages
Before making tool calls, send a brief preamble to the user explaining what youre about to do. When sending preamble messages, follow these principles and examples:
- **Logically group related actions**: if youre about to run several related commands, describe them together in one preamble rather than sending a separate note for each.
- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (812 words for quick updates).
- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with whats been done so far and create a sense of momentum and clarity for the user to understand your next actions.
- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging.
- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless its part of a larger grouped action.
**Examples:**
- “Ive explored the repo; now checking the API route definitions.”
- “Next, Ill patch the config and update the related tests.”
- “Im about to scaffold the CLI commands and helper functions.”
- “Ok cool, so Ive wrapped my head around the repo. Now digging into the API routes.”
- “Configs looking tidy. Next up is editing helpers to keep things in sync.”
- “Finished poking at the DB gateway. I will now chase down error handling.”
- “Alright, build pipeline order is interesting. Checking how it reports failures.”
- “Spotted a clever caching util; now hunting where it gets used.”
## Planning
You have access to an `todowrite` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go.
Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately.
Do not repeat the full contents of the plan after an `todowrite` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step.
Before running a command, consider whether or not you have completed the
previous step, and make sure to mark it as completed before moving on to the
next step. It may be the case that you complete all steps in your plan after a
single pass of implementation. If this is the case, you can simply mark all the
planned steps as completed. Sometimes, you may need to change plans in the
middle of a task: call `todowrite` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
Use a plan when:
- The task is non-trivial and will require multiple actions over a long time horizon.
- There are logical phases or dependencies where sequencing matters.
- The work has ambiguity that benefits from outlining high-level goals.
- You want intermediate checkpoints for feedback and validation.
- When the user asked you to do more than one thing in a single prompt
- The user has asked you to use the plan tool (aka "TODOs")
- You generate additional steps while working, and plan to do them before yielding to the user
### Examples
**High-quality plans**
Example 1:
1. Add CLI entry with file args
2. Parse Markdown via CommonMark library
3. Apply semantic HTML template
4. Handle code blocks, images, links
5. Add error handling for invalid files
Example 2:
1. Define CSS variables for colors
2. Add toggle with localStorage state
3. Refactor components to use variables
4. Verify all views for readability
5. Add smooth theme-change transition
Example 3:
1. Set up Node.js + WebSocket server
2. Add join/leave broadcast events
3. Implement messaging with timestamps
4. Add usernames + mention highlighting
5. Persist messages in lightweight DB
6. Add typing indicators + unread count
**Low-quality plans**
Example 1:
1. Create CLI tool
2. Add Markdown parser
3. Convert to HTML
Example 2:
1. Add dark mode toggle
2. Save preference
3. Make styles look good
Example 3:
1. Create single-file HTML game
2. Run quick sanity check
3. Summarize usage instructions
If you need to write a plan, only write high quality plans, not low quality ones.
## Task execution
You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer.
You MUST adhere to the following criteria when solving queries:
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
- Analyzing code for vulnerabilities is allowed.
- Showing user code and tool call details is allowed.
- Use the `edit` tool to edit files
If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines:
- Fix the problem at the root cause rather than applying surface-level patches, when possible.
- Avoid unneeded complexity in your solution.
- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
- Update documentation as necessary.
- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
- Use `git log` and `git blame` to search the history of the codebase if additional context is required.
- NEVER add copyright or license headers unless specifically requested.
- Do not waste tokens by re-reading files after calling `edit` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc.
- Do not `git commit` your changes or create new git branches unless explicitly requested.
- Do not add inline comments within code unless explicitly requested.
- Do not use one-letter variable names unless explicitly requested.
- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor.
## Sandbox and approvals
The Codex CLI harness supports several different sandboxing, and approval configurations that the user can choose from.
Filesystem sandboxing prevents you from editing files without user approval. The options are:
- **read-only**: You can only read files.
- **workspace-write**: You can read files. You can write to files in your workspace folder, but not outside it.
- **danger-full-access**: No filesystem sandboxing.
Network sandboxing prevents you from accessing network without approval. Options are
- **restricted**
- **enabled**
Approvals are your mechanism to get user consent to perform more privileged actions. Although they introduce friction to the user because your work is paused until the user responds, you should leverage them to accomplish your important work. Do not let these settings or the sandbox deter you from attempting to accomplish the user's task. Approval options are
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is pared with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
When you are running with approvals `on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /tmp)
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval.
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
- (For all of these, you should weigh alternative paths that do not require approval.)
Note that when sandboxing is set to read-only, you'll need to request approval for any command that isn't a read.
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing ON, and approval on-failure.
## Validating your work
If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete.
When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests.
Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one.
For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance:
- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task.
- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first.
- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task.
## Ambition vs. precision
For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation.
If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature.
You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified.
## Sharing progress updates
For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next.
Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why.
The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along.
## Presenting your work and final message
Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the users style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges.
You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multisection structured responses for results that need grouping or explanation.
The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `edit`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path.
If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If theres something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly.
Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding.
### Final answer structure and style guidelines
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
**Section Headers**
- Use only when they improve clarity — they are not mandatory for every answer.
- Choose descriptive names that fit the content
- Keep headers short (13 words) and in `**Title Case**`. Always start headers with `**` and end with `**`
- Leave no blank line before the first bullet under a header.
- Section headers should only be used where they genuinely improve scannability; avoid fragmenting the answer.
**Bullets**
- Use `-` followed by a space for every bullet.
- Merge related points when possible; avoid a bullet for every trivial detail.
- Keep bullets to one line unless breaking for clarity is unavoidable.
- Group into short lists (46 bullets) ordered by importance.
- Use consistent keyword phrasing and formatting across sections.
**Monospace**
- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``).
- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.
- Never mix monospace and bold markers; choose one based on whether its a keyword (`**`) or inline code/path (`` ` ``).
**File References**
When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a standalone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Line/column (1based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
**Structure**
- Place related bullets together; dont mix unrelated concepts in the same section.
- Order sections from general → specific → supporting info.
- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it.
- Match structure to complexity:
- Multi-part or detailed results → use clear headers and grouped bullets.
- Simple results → minimal headers, possibly just a short list or paragraph.
**Tone**
- Keep the voice collaborative and natural, like a coding partner handing off work.
- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition
- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”).
- Keep descriptions self-contained; dont refer to “above” or “below”.
- Use parallel structure in lists for consistency.
**Dont**
- Dont use literal words “bold” or “monospace” in the content.
- Dont nest bullets or create deep hierarchies.
- Dont output ANSI escape codes directly — the CLI renderer applies them.
- Dont cram unrelated keywords into a single bullet; split for clarity.
- Dont let keyword lists run long — wrap or reformat for scannability.
Generally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with whats needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable.
For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting.
# Tool Guidelines
## Shell commands
When using the shell, you must adhere to the following guidelines:
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
- Read files in chunks with a max chunk size of 250 lines. Do not use python scripts to attempt to output larger chunks of a file. Command line output will be truncated after 10 kilobytes or 256 lines of output, regardless of the command used.
## `todowrite`
A tool named `todowrite` is available to you. You can use it to keep an uptodate, stepbystep plan for the task.
To create a new plan, call `todowrite` with a short list of 1sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`).
When steps have been completed, use `todowrite` to mark each finished step as
`completed` and the next step you are working on as `in_progress`. There should
always be exactly one `in_progress` step until everything is done. You can mark
multiple items as complete in a single `todowrite` call.
If all steps are complete, ensure you call `todowrite` to mark all steps as `completed`.

View File

@@ -727,6 +727,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
@@ -739,7 +740,8 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {
@@ -771,7 +773,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
return map[string]*AntigravityModelConfig{
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
"gemini-2.5-computer-use-preview-10-2025": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-2.5-computer-use-preview-10-2025"},
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
@@ -779,3 +781,29 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
}
}
// LookupStaticModelInfo searches all static model definitions for a model by ID.
// Returns nil if no matching model is found.
func LookupStaticModelInfo(modelID string) *ModelInfo {
if modelID == "" {
return nil
}
allModels := [][]*ModelInfo{
GetClaudeModels(),
GetGeminiModels(),
GetGeminiVertexModels(),
GetGeminiCLIModels(),
GetAIStudioModels(),
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
}
for _, models := range allModels {
for _, m := range models {
if m != nil && m.ID == modelID {
return m
}
}
}
return nil
}

View File

@@ -4,6 +4,7 @@
package registry
import (
"context"
"fmt"
"sort"
"strings"
@@ -84,6 +85,13 @@ type ModelRegistration struct {
SuspendedClients map[string]string
}
// ModelRegistryHook provides optional callbacks for external integrations to track model list changes.
// Hook implementations must be non-blocking and resilient; calls are executed asynchronously and panics are recovered.
type ModelRegistryHook interface {
OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo)
OnModelsUnregistered(ctx context.Context, provider, clientID string)
}
// ModelRegistry manages the global registry of available models
type ModelRegistry struct {
// models maps model ID to registration information
@@ -97,6 +105,8 @@ type ModelRegistry struct {
clientProviders map[string]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
// hook is an optional callback sink for model registration changes
hook ModelRegistryHook
}
// Global model registry instance
@@ -117,6 +127,53 @@ func GetGlobalRegistry() *ModelRegistry {
return globalRegistry
}
// SetHook sets an optional hook for observing model registration changes.
func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
if r == nil {
return
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.hook = hook
}
const defaultModelRegistryHookTimeout = 5 * time.Second
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
hook := r.hook
if hook == nil {
return
}
modelsCopy := cloneModelInfosUnique(models)
go func() {
defer func() {
if recovered := recover(); recovered != nil {
log.Errorf("model registry hook OnModelsRegistered panic: %v", recovered)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
defer cancel()
hook.OnModelsRegistered(ctx, provider, clientID, modelsCopy)
}()
}
func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
hook := r.hook
if hook == nil {
return
}
go func() {
defer func() {
if recovered := recover(); recovered != nil {
log.Errorf("model registry hook OnModelsUnregistered panic: %v", recovered)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
defer cancel()
hook.OnModelsUnregistered(ctx, provider, clientID)
}()
}
// RegisterClient registers a client and its supported models
// Parameters:
// - clientID: Unique identifier for the client
@@ -177,6 +234,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
} else {
delete(r.clientProviders, clientID)
}
r.triggerModelsRegistered(provider, clientID, models)
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
misc.LogCredentialSeparator()
return
@@ -310,6 +368,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientProviders, clientID)
}
r.triggerModelsRegistered(provider, clientID, models)
if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output.
return
@@ -400,6 +459,25 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
return &copyModel
}
func cloneModelInfosUnique(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
cloned := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil || model.ID == "" {
continue
}
if _, exists := seen[model.ID]; exists {
continue
}
seen[model.ID] = struct{}{}
cloned = append(cloned, cloneModelInfo(model))
}
return cloned
}
// UnregisterClient removes a client and decrements counts for its models
// Parameters:
// - clientID: Unique identifier for the client to remove
@@ -460,6 +538,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
log.Debugf("Unregistered client %s", clientID)
// Separator line after completing client unregistration (after the summary line)
misc.LogCredentialSeparator()
r.triggerModelsUnregistered(provider, clientID)
}
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
@@ -625,6 +704,131 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
return models
}
// GetAvailableModelsByProvider returns models available for the given provider identifier.
// Parameters:
// - provider: Provider identifier (e.g., "codex", "gemini", "antigravity")
//
// Returns:
// - []*ModelInfo: List of available models for the provider
func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" {
return nil
}
r.mutex.RLock()
defer r.mutex.RUnlock()
type providerModel struct {
count int
info *ModelInfo
}
providerModels := make(map[string]*providerModel)
for clientID, clientProvider := range r.clientProviders {
if clientProvider != provider {
continue
}
modelIDs := r.clientModels[clientID]
if len(modelIDs) == 0 {
continue
}
clientInfos := r.clientModelInfos[clientID]
for _, modelID := range modelIDs {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
entry := providerModels[modelID]
if entry == nil {
entry = &providerModel{}
providerModels[modelID] = entry
}
entry.count++
if entry.info == nil {
if clientInfos != nil {
if info := clientInfos[modelID]; info != nil {
entry.info = info
}
}
if entry.info == nil {
if reg, ok := r.models[modelID]; ok && reg != nil && reg.Info != nil {
entry.info = reg.Info
}
}
}
}
}
if len(providerModels) == 0 {
return nil
}
quotaExpiredDuration := 5 * time.Minute
now := time.Now()
result := make([]*ModelInfo, 0, len(providerModels))
for modelID, entry := range providerModels {
if entry == nil || entry.count <= 0 {
continue
}
registration, ok := r.models[modelID]
expiredClients := 0
cooldownSuspended := 0
otherSuspended := 0
if ok && registration != nil {
if registration.QuotaExceededClients != nil {
for clientID, quotaTime := range registration.QuotaExceededClients {
if clientID == "" {
continue
}
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
expiredClients++
}
}
}
if registration.SuspendedClients != nil {
for clientID, reason := range registration.SuspendedClients {
if clientID == "" {
continue
}
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if strings.EqualFold(reason, "quota") {
cooldownSuspended++
continue
}
otherSuspended++
}
}
}
availableClients := entry.count
effectiveClients := availableClients - expiredClients - otherSuspended
if effectiveClients < 0 {
effectiveClients = 0
}
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
if entry.info != nil {
result = append(result, entry.info)
continue
}
if ok && registration != nil && registration.Info != nil {
result = append(result, registration.Info)
}
}
}
return result
}
// GetModelCount returns the number of available clients for a specific model
// Parameters:
// - modelID: The model ID to check

View File

@@ -0,0 +1,204 @@
package registry
import (
"context"
"sync"
"testing"
"time"
)
func newTestModelRegistry() *ModelRegistry {
return &ModelRegistry{
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
mutex: &sync.RWMutex{},
}
}
type registeredCall struct {
provider string
clientID string
models []*ModelInfo
}
type unregisteredCall struct {
provider string
clientID string
}
type capturingHook struct {
registeredCh chan registeredCall
unregisteredCh chan unregisteredCall
}
func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
h.registeredCh <- registeredCall{provider: provider, clientID: clientID, models: models}
}
func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
h.unregisteredCh <- unregisteredCall{provider: provider, clientID: clientID}
}
func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) {
r := newTestModelRegistry()
hook := &capturingHook{
registeredCh: make(chan registeredCall, 1),
unregisteredCh: make(chan unregisteredCall, 1),
}
r.SetHook(hook)
inputModels := []*ModelInfo{
{ID: "m1", DisplayName: "Model One"},
{ID: "m2", DisplayName: "Model Two"},
}
r.RegisterClient("client-1", "OpenAI", inputModels)
select {
case call := <-hook.registeredCh:
if call.provider != "openai" {
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
}
if call.clientID != "client-1" {
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
}
if len(call.models) != 2 {
t.Fatalf("models length mismatch: got %d, want %d", len(call.models), 2)
}
if call.models[0] == nil || call.models[0].ID != "m1" {
t.Fatalf("models[0] mismatch: got %#v, want ID=%q", call.models[0], "m1")
}
if call.models[1] == nil || call.models[1].ID != "m2" {
t.Fatalf("models[1] mismatch: got %#v, want ID=%q", call.models[1], "m2")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
}
func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) {
r := newTestModelRegistry()
hook := &capturingHook{
registeredCh: make(chan registeredCall, 1),
unregisteredCh: make(chan unregisteredCall, 1),
}
r.SetHook(hook)
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
select {
case <-hook.registeredCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
r.UnregisterClient("client-1")
select {
case call := <-hook.unregisteredCh:
if call.provider != "openai" {
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
}
if call.clientID != "client-1" {
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
}
}
type blockingHook struct {
started chan struct{}
unblock chan struct{}
}
func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
select {
case <-h.started:
default:
close(h.started)
}
<-h.unblock
}
func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {}
func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) {
r := newTestModelRegistry()
hook := &blockingHook{
started: make(chan struct{}),
unblock: make(chan struct{}),
}
r.SetHook(hook)
defer close(hook.unblock)
done := make(chan struct{})
go func() {
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
close(done)
}()
select {
case <-hook.started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for hook to start")
}
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatal("RegisterClient appears to be blocked by hook")
}
if !r.ClientSupportsModel("client-1", "m1") {
t.Fatal("model registration failed; expected client to support model")
}
}
type panicHook struct {
registeredCalled chan struct{}
unregisteredCalled chan struct{}
}
func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
if h.registeredCalled != nil {
h.registeredCalled <- struct{}{}
}
panic("boom")
}
func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
if h.unregisteredCalled != nil {
h.unregisteredCalled <- struct{}{}
}
panic("boom")
}
func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) {
r := newTestModelRegistry()
hook := &panicHook{
registeredCalled: make(chan struct{}, 1),
unregisteredCalled: make(chan struct{}, 1),
}
r.SetHook(hook)
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
select {
case <-hook.registeredCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
if !r.ClientSupportsModel("client-1", "m1") {
t.Fatal("model registration failed; expected client to support model")
}
r.UnregisterClient("client-1")
select {
case <-hook.unregisteredCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
}
}

View File

@@ -8,6 +8,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
@@ -50,6 +51,64 @@ func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth)
return nil
}
// HttpRequest forwards an arbitrary HTTP request through the websocket relay.
func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("aistudio executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
if e.relay == nil {
return nil, fmt.Errorf("aistudio executor: ws relay is nil")
}
if auth == nil || auth.ID == "" {
return nil, fmt.Errorf("aistudio executor: missing auth")
}
httpReq := req.WithContext(ctx)
if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
return nil, fmt.Errorf("aistudio executor: request URL is empty")
}
var body []byte
if httpReq.Body != nil {
b, errRead := io.ReadAll(httpReq.Body)
if errRead != nil {
return nil, errRead
}
body = b
httpReq.Body = io.NopCloser(bytes.NewReader(b))
}
wsReq := &wsrelay.HTTPRequest{
Method: httpReq.Method,
URL: httpReq.URL.String(),
Headers: httpReq.Header.Clone(),
Body: body,
}
wsResp, errRelay := e.relay.NonStream(ctx, auth.ID, wsReq)
if errRelay != nil {
return nil, errRelay
}
if wsResp == nil {
return nil, fmt.Errorf("aistudio executor: ws response is nil")
}
statusText := http.StatusText(wsResp.Status)
if statusText == "" {
statusText = "Unknown"
}
resp := &http.Response{
StatusCode: wsResp.Status,
Status: fmt.Sprintf("%d %s", wsResp.Status, statusText),
Header: wsResp.Headers.Clone(),
Body: io.NopCloser(bytes.NewReader(wsResp.Body)),
ContentLength: int64(len(wsResp.Body)),
Request: httpReq,
}
return resp, nil
}
// Execute performs a non-streaming request to the AI Studio API.
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
@@ -59,6 +118,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
if err != nil {
return resp, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
@@ -113,6 +173,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
if err != nil {
return nil, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
@@ -321,6 +382,11 @@ type translatedPayload struct {
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, stream)
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
@@ -329,7 +395,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiImageAspectRatio(req.Model, payload)
payload = applyPayloadConfig(e.cfg, req.Model, payload)
payload = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", payload, originalTranslated)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")

View File

@@ -10,6 +10,7 @@ import (
"crypto/sha256"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
@@ -17,6 +18,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
@@ -32,21 +34,25 @@ import (
)
const (
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityCountTokensPath = "/v1internal:countTokens"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityCountTokensPath = "/v1internal:countTokens"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
)
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
)
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
@@ -67,12 +73,42 @@ func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
// Identifier returns the executor identifier.
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
// PrepareRequest prepares the HTTP request for execution (no-op for Antigravity).
func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// PrepareRequest injects Antigravity credentials into the outgoing HTTP request.
func (e *AntigravityExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
token, _, errToken := e.ensureAccessToken(req.Context(), auth)
if errToken != nil {
return errToken
}
if strings.TrimSpace(token) == "" {
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
}
req.Header.Set("Authorization", "Bearer "+token)
return nil
}
// HttpRequest injects Antigravity credentials into the request and executes it.
func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("antigravity executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if strings.Contains(req.Model, "claude") {
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if isClaude || strings.Contains(req.Model, "gemini-3-pro") {
return e.executeClaudeNonStream(ctx, auth, req, opts)
}
@@ -89,13 +125,18 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -114,6 +155,9 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
@@ -146,7 +190,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err
}
@@ -160,7 +210,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
switch {
case lastStatus != 0:
err = statusErr{code: lastStatus, msg: string(lastBody)}
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
@@ -184,13 +240,18 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, true)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -209,6 +270,9 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
@@ -227,6 +291,14 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return resp, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return resp, err
}
lastStatus = 0
lastBody = nil
lastErr = errRead
@@ -245,7 +317,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err
}
@@ -310,7 +388,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
switch {
case lastStatus != 0:
err = statusErr{code: lastStatus, msg: string(lastBody)}
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
@@ -516,15 +600,22 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -543,6 +634,9 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
@@ -561,6 +655,14 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return nil, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return nil, err
}
lastStatus = 0
lastBody = nil
lastErr = errRead
@@ -579,7 +681,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return nil, err
}
@@ -634,7 +742,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
switch {
case lastStatus != 0:
err = statusErr{code: lastStatus, msg: string(lastBody)}
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
@@ -672,6 +786,8 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("antigravity")
respCtx := context.WithValue(ctx, "alt", opts.Alt)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -688,9 +804,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
for idx, baseURL := range baseURLs {
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
payload = normalizeAntigravityThinking(req.Model, payload)
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, payload)
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
@@ -735,6 +851,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return cliproxyexecutor.Response{}, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
@@ -769,12 +888,24 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
return cliproxyexecutor.Response{}, sErr
}
switch {
case lastStatus != 0:
return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
return cliproxyexecutor.Response{}, sErr
case lastErr != nil:
return cliproxyexecutor.Response{}, lastErr
default:
@@ -811,6 +942,9 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
@@ -890,7 +1024,13 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr
if accessToken != "" && expiry.After(time.Now().Add(refreshSkew)) {
return accessToken, nil, nil
}
updated, errRefresh := e.refreshToken(ctx, auth.Clone())
refreshCtx := context.Background()
if ctx != nil {
if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil {
refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", rt)
}
}
updated, errRefresh := e.refreshToken(refreshCtx, auth.Clone())
if errRefresh != nil {
return "", nil, errRefresh
}
@@ -937,7 +1077,13 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
return auth, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
return auth, sErr
}
var tokenResp struct {
@@ -1017,6 +1163,19 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
payload = []byte(strJSON)
}
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-preview") {
systemInstructionPartsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", systemInstruction)
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
for _, partResult := range systemInstructionPartsResult.Array() {
payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(partResult.Raw))
}
}
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
if errReq != nil {
return nil, errReq
@@ -1151,8 +1310,8 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
return []string{base}
}
return []string{
antigravitySandboxBaseURLDaily,
antigravityBaseURLDaily,
// antigravityBaseURLAutopush,
antigravityBaseURLProd,
}
}
@@ -1180,6 +1339,7 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
template, _ := sjson.Set(string(payload), "model", modelName)
template, _ = sjson.Set(template, "userAgent", "antigravity")
template, _ = sjson.Set(template, "requestType", "agent")
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" {
@@ -1224,7 +1384,9 @@ func generateRequestID() string {
}
func generateSessionID() string {
randSourceMutex.Lock()
n := randSource.Int63n(9_000_000_000_000_000_000)
randSourceMutex.Unlock()
return "-" + strconv.FormatInt(n, 10)
}
@@ -1248,8 +1410,10 @@ func generateStableSessionID(payload []byte) string {
func generateProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randSourceMutex.Lock()
adj := adjectives[randSource.Intn(len(adjectives))]
noun := nouns[randSource.Intn(len(nouns))]
randSourceMutex.Unlock()
randomPart := strings.ToLower(uuid.NewString())[:5]
return adj + "-" + noun + "-" + randomPart
}
@@ -1300,7 +1464,7 @@ func alias2ModelName(modelName string) string {
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
// For Claude models, it additionally ensures thinking budget < max_tokens.
func normalizeAntigravityThinking(model string, payload []byte) []byte {
func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
payload = util.StripThinkingConfigIfUnsupported(model, payload)
if !util.ModelSupportsThinking(model) {
return payload
@@ -1312,7 +1476,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
raw := int(budget.Int())
normalized := util.NormalizeThinkingBudget(model, raw)
isClaude := strings.Contains(strings.ToLower(model), "claude")
if isClaude {
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
if effectiveMax > 0 && normalized >= effectiveMax {

View File

@@ -1,10 +1,68 @@
package executor
import "time"
import (
"sync"
"time"
)
type codexCache struct {
ID string
Expire time.Time
}
var codexCacheMap = map[string]codexCache{}
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
// Protected by codexCacheMu. Entries expire after 1 hour.
var (
codexCacheMap = make(map[string]codexCache)
codexCacheMu sync.RWMutex
)
// codexCacheCleanupInterval controls how often expired entries are purged.
const codexCacheCleanupInterval = 15 * time.Minute
// codexCacheCleanupOnce ensures the background cleanup goroutine starts only once.
var codexCacheCleanupOnce sync.Once
// startCodexCacheCleanup launches a background goroutine that periodically
// removes expired entries from codexCacheMap to prevent memory leaks.
func startCodexCacheCleanup() {
go func() {
ticker := time.NewTicker(codexCacheCleanupInterval)
defer ticker.Stop()
for range ticker.C {
purgeExpiredCodexCache()
}
}()
}
// purgeExpiredCodexCache removes entries that have expired.
func purgeExpiredCodexCache() {
now := time.Now()
codexCacheMu.Lock()
defer codexCacheMu.Unlock()
for key, cache := range codexCacheMap {
if cache.Expire.Before(now) {
delete(codexCacheMap, key)
}
}
}
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired.
func getCodexCache(key string) (codexCache, bool) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.RLock()
cache, ok := codexCacheMap[key]
codexCacheMu.RUnlock()
if !ok || cache.Expire.Before(time.Now()) {
return codexCache{}, false
}
return cache, true
}
// setCodexCache stores a cache entry.
func setCodexCache(key string, cache codexCache) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.Lock()
codexCacheMap[key] = cache
codexCacheMu.Unlock()
}

View File

@@ -35,11 +35,53 @@ type ClaudeExecutor struct {
cfg *config.Config
}
const claudeToolPrefix = "proxy_"
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
func (e *ClaudeExecutor) Identifier() string { return "claude" }
func (e *ClaudeExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// PrepareRequest injects Claude credentials into the outgoing HTTP request.
func (e *ClaudeExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
apiKey, _ := claudeCreds(auth)
if strings.TrimSpace(apiKey) == "" {
return nil
}
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := req.URL != nil && strings.EqualFold(req.URL.Scheme, "https") && strings.EqualFold(req.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
req.Header.Del("Authorization")
req.Header.Set("x-api-key", apiKey)
} else {
req.Header.Del("x-api-key")
req.Header.Set("Authorization", "Bearer "+apiKey)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest injects Claude credentials into the request and executes it.
func (e *ClaudeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("claude executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
apiKey, baseURL := claudeCreds(auth)
@@ -49,40 +91,46 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, stream)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
body = e.injectThinkingConfig(model, req.Metadata, body)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body)
body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
if err != nil {
return resp, err
}
@@ -97,7 +145,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Body: bodyForUpstream,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
@@ -151,8 +199,20 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.publish(ctx, parseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
out := sdktranslator.TranslateNonStream(
ctx,
to,
from,
req.Model,
bytes.Clone(opts.OriginalRequest),
bodyForTranslation,
data,
&param,
)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
@@ -167,33 +227,39 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
body = e.injectThinkingConfig(model, req.Metadata, body)
body = checkSystemInstructions(body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body)
body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
if err != nil {
return nil, err
}
@@ -208,7 +274,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Body: bodyForUpstream,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
@@ -261,6 +327,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
// Forward the line as-is to preserve SSE format
cloned := make([]byte, len(line)+1)
copy(cloned, line)
@@ -285,7 +354,19 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
if isClaudeOAuthToken(apiKey) {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
chunks := sdktranslator.TranslateStream(
ctx,
to,
from,
req.Model,
bytes.Clone(opts.OriginalRequest),
bodyForTranslation,
bytes.Clone(line),
&param,
)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
@@ -310,27 +391,23 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", model)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
}
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -461,6 +538,19 @@ func (e *ClaudeExecutor) injectThinkingConfig(modelName string, metadata map[str
return util.ApplyClaudeThinkingConfig(body, budget)
}
// disableThinkingIfToolChoiceForced checks if tool_choice forces tool use and disables thinking.
// Anthropic API does not allow thinking when tool_choice is set to "any" or a specific tool.
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
func disableThinkingIfToolChoiceForced(body []byte) []byte {
toolChoiceType := gjson.GetBytes(body, "tool_choice.type").String()
// "auto" is allowed with thinking, but "any" or "tool" (specific tool) are not
if toolChoiceType == "any" || toolChoiceType == "tool" {
// Remove thinking configuration entirely to avoid API error
body, _ = sjson.DeleteBytes(body, "thinking")
}
return body
}
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
// Anthropic API requires this constraint; violating it returns a 400 error.
// This function should be called after all thinking configuration is finalized.
@@ -662,7 +752,14 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
r.Header.Set("Authorization", "Bearer "+apiKey)
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
r.Header.Del("Authorization")
r.Header.Set("x-api-key", apiKey)
} else {
r.Header.Set("Authorization", "Bearer "+apiKey)
}
r.Header.Set("Content-Type", "application/json")
var ginHeaders http.Header
@@ -755,3 +852,107 @@ func checkSystemInstructions(payload []byte) []byte {
}
return payload
}
func isClaudeOAuthToken(apiKey string) bool {
return strings.Contains(apiKey, "sk-ant-oat")
}
func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if prefix == "" {
return body
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
name := tool.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("tools.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
}
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
name := gjson.GetBytes(body, "tool_choice.name").String()
if name != "" && !strings.HasPrefix(name, prefix) {
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
}
}
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
messages.ForEach(func(msgIndex, msg gjson.Result) bool {
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
return true
}
content.ForEach(func(contentIndex, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
}
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
return true
})
}
return body
}
func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
if prefix == "" {
return body
}
content := gjson.GetBytes(body, "content")
if !content.Exists() || !content.IsArray() {
return body
}
content.ForEach(func(index, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
}
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
return true
})
return body
}
func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
if prefix == "" {
return line
}
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return line
}
contentBlock := gjson.GetBytes(payload, "content_block")
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
return line
}
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
return line
}
trimmed := bytes.TrimSpace(line)
if bytes.HasPrefix(trimmed, []byte("data:")) {
return append([]byte("data: "), updated...)
}
return updated
}

View File

@@ -0,0 +1,51 @@
package executor
import (
"bytes"
"testing"
"github.com/tidwall/gjson"
)
func TestApplyClaudeToolPrefix(t *testing.T) {
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_alpha" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_alpha")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_bravo" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_bravo")
}
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "proxy_charlie" {
t.Fatalf("tool_choice.name = %q, want %q", got, "proxy_charlie")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_delta" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_delta")
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
if got := gjson.GetBytes(out, "content.0.name").String(); got != "alpha" {
t.Fatalf("content.0.name = %q, want %q", got, "alpha")
}
if got := gjson.GetBytes(out, "content.1.name").String(); got != "bravo" {
t.Fatalf("content.1.name = %q, want %q", got, "bravo")
}
}
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
payload := bytes.TrimSpace(out)
if bytes.HasPrefix(payload, []byte("data:")) {
payload = bytes.TrimSpace(payload[len("data:"):])
}
if got := gjson.GetBytes(payload, "content_block.name").String(); got != "alpha" {
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
}
}

View File

@@ -38,7 +38,38 @@ func NewCodexExecutor(cfg *config.Config) *CodexExecutor { return &CodexExecutor
func (e *CodexExecutor) Identifier() string { return "codex" }
func (e *CodexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// PrepareRequest injects Codex credentials into the outgoing HTTP request.
func (e *CodexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
apiKey, _ := codexCreds(auth)
if strings.TrimSpace(apiKey) != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest injects Codex credentials into the request and executes it.
func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("codex executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
apiKey, baseURL := codexCreds(auth)
@@ -49,20 +80,33 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
userAgent := codexUserAgent(ctx)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalPayload = misc.InjectCodexUserAgent(originalPayload, userAgent)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
body = sdktranslator.TranslateRequest(from, to, model, body, false)
body = misc.StripCodexUserAgent(body)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return resp, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -129,7 +173,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, line, &param)
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, line, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
@@ -146,20 +190,33 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
userAgent := codexUserAgent(ctx)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalPayload = misc.InjectCodexUserAgent(originalPayload, userAgent)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
body = sdktranslator.TranslateRequest(from, to, model, body, true)
body = misc.StripCodexUserAgent(body)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return nil, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.SetBytes(body, "model", model)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -231,7 +288,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
@@ -246,20 +303,25 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
userAgent := codexUserAgent(ctx)
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
body = sdktranslator.TranslateRequest(from, to, model, body, false)
body = misc.StripCodexUserAgent(body)
modelForCounting := req.Model
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.SetBytes(body, "stream", false)
enc, err := tokenizerForCodexModel(modelForCounting)
enc, err := tokenizerForCodexModel(model)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
}
@@ -440,14 +502,14 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() {
var hasKey bool
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) {
var ok bool
if cache, ok = getCodexCache(key); !ok {
cache = codexCache{
ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour),
}
codexCacheMap[key] = cache
setCodexCache(key, cache)
}
}
} else if from == "openai-response" {
@@ -505,6 +567,16 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string) {
util.ApplyCustomHeadersFromAttrs(r, attrs)
}
func codexUserAgent(ctx context.Context) string {
if ctx == nil {
return ""
}
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
return strings.TrimSpace(ginCtx.Request.UserAgent())
}
return ""
}
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
if a == nil {
return "", ""
@@ -520,3 +592,87 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
}
return
}
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveCodexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.CodexKey {
entry := &e.cfg.CodexKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.CodexKey {
entry := &e.cfg.CodexKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}

View File

@@ -63,8 +63,42 @@ func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor {
// Identifier returns the executor identifier.
func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" }
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini CLI).
func (e *GeminiCLIExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// PrepareRequest injects Gemini CLI credentials into the outgoing HTTP request.
func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
tokenSource, _, errSource := prepareGeminiCLITokenSource(req.Context(), e.cfg, auth)
if errSource != nil {
return errSource
}
tok, errTok := tokenSource.Token()
if errTok != nil {
return errTok
}
if strings.TrimSpace(tok.AccessToken) == "" {
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
}
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(req)
return nil
}
// HttpRequest injects Gemini CLI credentials into the request and executes it.
func (e *GeminiCLIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("gemini-cli executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
@@ -77,14 +111,19 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
action := "generateContent"
if req.Metadata != nil {
@@ -216,14 +255,19 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
projectID := resolveGeminiProjectID(auth)
@@ -318,7 +362,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response, reqBody []byte, attempt string) {
go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
@@ -336,14 +380,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reporter.publish(ctx, detail)
}
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
@@ -365,12 +409,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
@@ -417,15 +461,17 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
var lastStatus int
var lastBody []byte
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
// Gemini CLI endpoint when iterating fallback variants.
for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload)
payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
tok, errTok := tokenSource.Token()
if errTok != nil {

View File

@@ -55,8 +55,38 @@ func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor {
// Identifier returns the executor identifier.
func (e *GeminiExecutor) Identifier() string { return "gemini" }
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini).
func (e *GeminiExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// PrepareRequest injects Gemini credentials into the outgoing HTTP request.
func (e *GeminiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
apiKey, bearer := geminiCreds(auth)
if apiKey != "" {
req.Header.Set("x-goog-api-key", apiKey)
req.Header.Del("Authorization")
} else if bearer != "" {
req.Header.Set("Authorization", "Bearer "+bearer)
req.Header.Del("x-goog-api-key")
}
applyGeminiHeaders(req, auth)
return nil
}
// HttpRequest injects Gemini credentials into the request and executes it.
func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("gemini executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request to the Gemini API.
// It translates the request to Gemini format, sends it to the API, and translates
@@ -77,19 +107,27 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
// Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent"
if req.Metadata != nil {
@@ -98,7 +136,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
}
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -173,21 +211,29 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -287,19 +333,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, bearer := geminiCreds(auth)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens")
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
requestBody := bytes.NewReader(translatedReq)
@@ -398,6 +450,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
return base
}
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveGeminiConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.GeminiKey {
entry := &e.cfg.GeminiKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.GeminiKey {
entry := &e.cfg.GeminiKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
var attrs map[string]string
if auth != nil {

View File

@@ -50,11 +50,49 @@ func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
// Identifier returns the executor identifier.
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
// PrepareRequest prepares the HTTP request for execution (no-op for Vertex).
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
// PrepareRequest injects Vertex credentials into the outgoing HTTP request.
func (e *GeminiVertexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
apiKey, _ := vertexAPICreds(auth)
if strings.TrimSpace(apiKey) != "" {
req.Header.Set("x-goog-api-key", apiKey)
req.Header.Del("Authorization")
return nil
}
_, _, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return errCreds
}
token, errToken := vertexAccessToken(req.Context(), e.cfg, auth, saJSON)
if errToken != nil {
return errToken
}
if strings.TrimSpace(token) == "" {
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Del("x-goog-api-key")
return nil
}
// HttpRequest injects Vertex credentials into the request and executes it.
func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("vertex executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request to the Vertex AI API.
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
// Try API key authentication first
@@ -120,10 +158,13 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
@@ -136,8 +177,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", req.Model)
action := "generateContent"
if req.Metadata != nil {
@@ -146,7 +187,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
}
}
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -220,24 +261,32 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent"
if req.Metadata != nil {
@@ -250,7 +299,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action)
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -321,10 +370,13 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
@@ -337,11 +389,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", req.Model)
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -438,30 +490,38 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -552,8 +612,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
// countTokensWithServiceAccount counts tokens using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
@@ -566,14 +624,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
@@ -641,21 +699,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
// countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -665,7 +726,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
@@ -808,3 +869,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
}
return tok.AccessToken, nil
}
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
// It matches the requested model alias against configured models and returns the actual upstream name.
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveVertexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}

View File

@@ -37,8 +37,33 @@ func NewIFlowExecutor(cfg *config.Config) *IFlowExecutor { return &IFlowExecutor
// Identifier returns the provider key.
func (e *IFlowExecutor) Identifier() string { return "iflow" }
// PrepareRequest implements ProviderExecutor but requires no preprocessing.
func (e *IFlowExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// PrepareRequest injects iFlow credentials into the outgoing HTTP request.
func (e *IFlowExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
apiKey, _ := iflowCreds(auth)
if strings.TrimSpace(apiKey) != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
return nil
}
// HttpRequest injects iFlow credentials into the request and executes it.
func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("iflow executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming chat completion request.
func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
@@ -56,18 +81,21 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return resp, errValidate
}
body = applyIFlowThinkingConfig(body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = preserveReasoningContentInMessages(body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -147,24 +175,27 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return nil, errValidate
}
body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
toolsResult := gjson.GetBytes(body, "tools")
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body)
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -445,20 +476,85 @@ func ensureToolsArray(body []byte) []byte {
return updated
}
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
// preserveReasoningContentInMessages checks if reasoning_content from assistant messages
// is preserved in conversation history for iFlow models that support thinking.
// This is helpful for multi-turn conversations where the model may benefit from seeing
// its previous reasoning to maintain coherent thought chains.
//
// For GLM-4.6/4.7 and MiniMax M2/M2.1, it is recommended to include the full assistant
// response (including reasoning_content) in message history for better context continuity.
func preserveReasoningContentInMessages(body []byte) []byte {
model := strings.ToLower(gjson.GetBytes(body, "model").String())
// Only apply to models that support thinking with history preservation
needsPreservation := strings.HasPrefix(model, "glm-4") || strings.HasPrefix(model, "minimax-m2")
if !needsPreservation {
return body
}
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
// Check if any assistant message already has reasoning_content preserved
hasReasoningContent := false
messages.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
if role == "assistant" {
rc := msg.Get("reasoning_content")
if rc.Exists() && rc.String() != "" {
hasReasoningContent = true
return false // stop iteration
}
}
return true
})
// If reasoning content is already present, the messages are properly formatted
// No need to modify - the client has correctly preserved reasoning in history
if hasReasoningContent {
log.Debugf("iflow executor: reasoning_content found in message history for %s", model)
}
return body
}
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
// This should be called after NormalizeThinkingConfig has processed the payload.
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking.
//
// Model-specific handling:
// - GLM-4.6/4.7: Uses chat_template_kwargs.enable_thinking (boolean) and chat_template_kwargs.clear_thinking=false
// - MiniMax M2/M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
func applyIFlowThinkingConfig(body []byte) []byte {
effort := gjson.GetBytes(body, "reasoning_effort")
if !effort.Exists() {
return body
}
model := strings.ToLower(gjson.GetBytes(body, "model").String())
val := strings.ToLower(strings.TrimSpace(effort.String()))
enableThinking := val != "none" && val != ""
// Remove reasoning_effort as we'll convert to model-specific format
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
body, _ = sjson.DeleteBytes(body, "thinking")
// GLM-4.6/4.7: Use chat_template_kwargs
if strings.HasPrefix(model, "glm-4") {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
if enableThinking {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.clear_thinking", false)
}
return body
}
// MiniMax M2/M2.1: Use reasoning_split
if strings.HasPrefix(model, "minimax-m2") {
body, _ = sjson.SetBytes(body, "reasoning_split", enableThinking)
return body
}
return body
}

View File

@@ -304,11 +304,7 @@ func formatAuthInfo(info upstreamRequestLog) string {
parts = append(parts, "type=api_key")
}
case "oauth":
if authValue != "" {
parts = append(parts, fmt.Sprintf("type=oauth account=%s", authValue))
} else {
parts = append(parts, "type=oauth")
}
parts = append(parts, "type=oauth")
default:
if authType != "" {
if authValue != "" {

View File

@@ -35,11 +35,39 @@ func NewOpenAICompatExecutor(provider string, cfg *config.Config) *OpenAICompatE
// Identifier implements cliproxyauth.ProviderExecutor.
func (e *OpenAICompatExecutor) Identifier() string { return e.provider }
// PrepareRequest is a no-op for now (credentials are added via headers at execution time).
func (e *OpenAICompatExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
// PrepareRequest injects OpenAI-compatible credentials into the outgoing HTTP request.
func (e *OpenAICompatExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
_, apiKey := e.resolveCredentials(auth)
if strings.TrimSpace(apiKey) != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest injects OpenAI-compatible credentials into the request and executes it.
func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("openai compat executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -53,20 +81,21 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate inbound request to OpenAI format
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
modelOverride := e.resolveUpstreamModel(req.Model, auth)
if modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" && modelOverride == "" {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
return resp, errValidate
}
@@ -149,20 +178,21 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
modelOverride := e.resolveUpstreamModel(req.Model, auth)
if modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" && modelOverride == "" {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
return nil, errValidate
}
@@ -239,6 +269,11 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
// Pass through translator; it yields one or more chunks for the target schema.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), &param)

View File

@@ -14,32 +14,54 @@ import (
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
// Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata)
// Determine which model to use for thinking support check.
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload
}
if !util.ModelSupportsThinking(model) {
if !util.ModelSupportsThinking(thinkingModel) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
}
// applyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
// ApplyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
func ApplyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
// Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata)
// Determine which model to use for thinking support check.
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload
}
if !util.ModelSupportsThinking(model) {
if !util.ModelSupportsThinking(thinkingModel) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
@@ -82,17 +104,11 @@ func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model
return payload
}
// applyPayloadConfig applies payload default and override rules from configuration
// to the given JSON payload for the specified model.
// Defaults only fill missing fields, while overrides always overwrite existing values.
func applyPayloadConfig(cfg *config.Config, model string, payload []byte) []byte {
return applyPayloadConfigWithRoot(cfg, model, "", "", payload)
}
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
// and restricts matches to the given protocol when supplied.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload []byte) []byte {
// and restricts matches to the given protocol when supplied. Defaults are checked
// against the original payload when provided.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte) []byte {
if cfg == nil || len(payload) == 0 {
return payload
}
@@ -105,6 +121,11 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
return payload
}
out := payload
source := original
if len(source) == 0 {
source = payload
}
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
@@ -116,7 +137,10 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
if fullPath == "" {
continue
}
if gjson.GetBytes(out, fullPath).Exists() {
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
@@ -124,6 +148,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply override rules: last write wins per field across all matching rules.

View File

@@ -12,7 +12,6 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -37,7 +36,33 @@ func NewQwenExecutor(cfg *config.Config) *QwenExecutor { return &QwenExecutor{cf
func (e *QwenExecutor) Identifier() string { return "qwen" }
func (e *QwenExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// PrepareRequest injects Qwen credentials into the outgoing HTTP request.
func (e *QwenExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
token, _ := qwenCreds(auth)
if strings.TrimSpace(token) != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
return nil
}
// HttpRequest injects Qwen credentials into the request and executes it.
func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("qwen executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
token, baseURL := qwenCreds(auth)
@@ -50,17 +75,19 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return resp, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -129,15 +156,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return nil, errValidate
}
toolsResult := gjson.GetBytes(body, "tools")
@@ -147,7 +176,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
}
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))

View File

@@ -19,7 +19,7 @@ type usageReporter struct {
provider string
model string
authID string
authIndex uint64
authIndex string
apiKey string
source string
requestedAt time.Time
@@ -275,6 +275,20 @@ func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true
}
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
CachedTokens: node.Get("cachedContentTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
}
func parseGeminiCLIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata")
@@ -284,16 +298,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiUsage(data []byte) usage.Detail {
@@ -305,16 +310,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
@@ -329,16 +325,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
@@ -353,16 +340,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
func parseAntigravityUsage(data []byte) usage.Detail {
@@ -377,16 +355,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
@@ -404,16 +373,7 @@ func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
var stopChunkWithoutUsage sync.Map
@@ -522,12 +482,16 @@ func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
cleaned := jsonBytes
var changed bool
if gjson.GetBytes(cleaned, "usageMetadata").Exists() {
if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
changed = true
}
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() {
if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
changed = true
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -86,6 +85,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
hasSystemInstruction = true
}
}
} else if systemResult.Type == gjson.String {
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
hasSystemInstruction = true
}
// contents
@@ -121,32 +124,31 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Use GetThinkingText to handle wrapped thinking objects
thinkingText := util.GetThinkingText(contentResult)
signatureResult := contentResult.Get("signature")
clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" {
clientSignature = signatureResult.String()
}
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if sessionID != "" && thinkingText != "" {
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
signature = cachedSig
log.Debugf("Using cached signature for thinking block")
clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" {
clientSignature = signatureResult.String()
}
}
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" && cache.HasValidSignature(clientSignature) {
signature = clientSignature
log.Debugf("Using client-provided signature for thinking block")
}
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if sessionID != "" && thinkingText != "" {
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
signature = cachedSig
// log.Debugf("Using cached signature for thinking block")
}
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(signature) {
currentMessageThinkingSignature = signature
}
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" && cache.HasValidSignature(clientSignature) {
signature = clientSignature
// log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(signature) {
currentMessageThinkingSignature = signature
}
// Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(signature)
@@ -155,8 +157,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
// Converting to text would break this requirement
if isUnsigned {
// TypeScript plugin approach: drop unsigned thinking blocks entirely
log.Debugf("Dropping unsigned thinking block (no valid signature)")
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
continue
}
@@ -180,7 +181,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
// NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected.
// The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies.
functionName := contentResult.Get("name").String()
argsResult := contentResult.Get("input")
@@ -321,6 +321,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// tools
toolsJSON := ""
toolDeclCount := 0
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
toolsJSON = `[{"functionDeclarations":[]}]`
@@ -333,10 +334,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict")
tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control")
for toolKey := range gjson.Parse(tool).Map() {
if util.InArray(allowedToolKeys, toolKey) {
continue
}
tool, _ = sjson.Delete(tool, toolKey)
}
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
toolDeclCount++
}

View File

@@ -35,6 +35,7 @@ type Params struct {
CandidatesTokenCount int64 // Cached candidate token count from usage metadata
ThoughtsTokenCount int64 // Cached thinking token count from usage metadata
TotalTokenCount int64 // Cached total token count from usage metadata
CachedTokenCount int64 // Cached content token count (indicates prompt caching)
HasSentFinalEvents bool // Indicates if final content/message events have been sent
HasToolUse bool // Indicates if tool use was observed in the stream
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
@@ -98,6 +99,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
}
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
}
// Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
@@ -127,11 +136,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
log.Debug("Branch: signature_delta")
// log.Debug("Branch: signature_delta")
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset()
}
@@ -270,7 +279,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
params.HasUsageMetadata = true
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int()
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
@@ -322,6 +332,14 @@ func appendFinalEvents(params *Params, output *string, force bool) {
*output = *output + "event: message_delta\n"
*output = *output + "data: "
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if params.CachedTokenCount > 0 {
var err error
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
}
*output = *output + delta + "\n\n\n"
params.HasSentFinalEvents = true
@@ -361,6 +379,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int()
thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int()
totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int()
cachedTokens := root.Get("response.usageMetadata.cachedContentTokenCount").Int()
outputTokens := candidateTokens + thoughtTokens
if outputTokens == 0 && totalTokens > 0 {
outputTokens = totalTokens - promptTokens
@@ -374,6 +393,14 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if cachedTokens > 0 {
var err error
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
}
contentArrayInitialized := false
ensureContentArray := func() {
@@ -452,7 +479,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) {
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
}

View File

@@ -7,7 +7,6 @@ package gemini
import (
"bytes"
"encoding/json"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
@@ -135,41 +134,31 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}
// parseFunctionResponse attempts to unmarshal a function response part.
// Falls back to gjson extraction if standard json.Unmarshal fails.
func parseFunctionResponse(response gjson.Result) map[string]interface{} {
var responseMap map[string]interface{}
err := json.Unmarshal([]byte(response.Raw), &responseMap)
if err == nil {
return responseMap
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
// Falls back to a minimal "functionResponse" object when parsing fails.
func parseFunctionResponseRaw(response gjson.Result) string {
if response.IsObject() && gjson.Valid(response.Raw) {
return response.Raw
}
log.Debugf("unmarshal function response failed, using fallback: %v", err)
log.Debugf("parse function response failed, using fallback")
funcResp := response.Get("functionResponse")
if funcResp.Exists() {
fr := map[string]interface{}{
"name": funcResp.Get("name").String(),
"response": map[string]interface{}{
"result": funcResp.Get("response").String(),
},
}
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
if id := funcResp.Get("id").String(); id != "" {
fr["id"] = id
fr, _ = sjson.Set(fr, "functionResponse.id", id)
}
return map[string]interface{}{"functionResponse": fr}
return fr
}
return map[string]interface{}{
"functionResponse": map[string]interface{}{
"name": "unknown",
"response": map[string]interface{}{"result": response.String()},
},
}
fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
return fr
}
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -196,7 +185,7 @@ func fixCLIToolResponse(input string) (string, error) {
}
// Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array
contentsWrapper := `{"contents":[]}`
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -228,17 +217,16 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
var responseParts []interface{}
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
responseParts = append(responseParts, parseFunctionResponse(response))
partRaw := parseFunctionResponseRaw(response)
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
// Remove this group as it's been satisfied
@@ -252,50 +240,42 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
functionCallsCount := 0
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
functionCallsCount++
}
return true
})
if len(functionCallsInThisModel) > 0 {
if functionCallsCount > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
if !value.IsObject() {
log.Warnf("failed to parse model content")
return true
}
newContents = append(newContents, contentMap)
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
ResponsesNeeded: functionCallsCount,
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
if !value.IsObject() {
log.Warnf("failed to parse content")
return true
}
newContents = append(newContents, contentMap)
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
if !value.IsObject() {
log.Warnf("failed to parse content")
return true
}
newContents = append(newContents, contentMap)
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
}
return true
@@ -307,25 +287,23 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{}
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
responseParts = append(responseParts, parseFunctionResponse(response))
partRaw := parseFunctionResponseRaw(response)
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
return result, nil
}

View File

@@ -184,7 +184,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> request.systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
@@ -192,8 +192,16 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -215,6 +223,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}
@@ -239,10 +248,31 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`)
p := 0
if content.Type == gjson.String {
if content.Type == gjson.String && content.String() != "" {
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}
}
}
}
// Tool calls -> single model content with functionCall parts
@@ -258,7 +288,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
if gjson.Valid(fargs) {
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
} else {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", []byte(fargs))
}
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
if fid != "" {
@@ -293,6 +327,8 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
}
}
}
@@ -319,7 +355,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
@@ -334,7 +370,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue

View File

@@ -8,12 +8,13 @@ package chat_completions
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -86,18 +87,27 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
}
}
}
// Process the main content part of the response.
@@ -171,21 +181,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload))
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}
}
}

View File

@@ -194,7 +194,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if name := fc.Get("name"); name.Exists() {
toolUse, _ = sjson.Set(toolUse, "name", name.String())
}
if args := fc.Get("args"); args.Exists() {
if args := fc.Get("args"); args.Exists() && args.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
@@ -314,11 +314,11 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if mode := funcCalling.Get("mode"); mode.Exists() {
switch mode.String() {
case "AUTO":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"})
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "NONE":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "none"})
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
case "ANY":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"})
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
}
}
}

View File

@@ -263,51 +263,6 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
}
// convertArrayToJSON converts []interface{} to JSON array string
func convertArrayToJSON(arr []interface{}) string {
result := "[]"
for _, item := range arr {
switch itemData := item.(type) {
case map[string]interface{}:
itemJSON := convertMapToJSON(itemData)
result, _ = sjson.SetRaw(result, "-1", itemJSON)
case string:
result, _ = sjson.Set(result, "-1", itemData)
case bool:
result, _ = sjson.Set(result, "-1", itemData)
case float64, int, int64:
result, _ = sjson.Set(result, "-1", itemData)
default:
result, _ = sjson.Set(result, "-1", itemData)
}
}
return result
}
// convertMapToJSON converts map[string]interface{} to JSON object string
func convertMapToJSON(m map[string]interface{}) string {
result := "{}"
for key, value := range m {
switch val := value.(type) {
case map[string]interface{}:
nestedJSON := convertMapToJSON(val)
result, _ = sjson.SetRaw(result, key, nestedJSON)
case []interface{}:
arrayJSON := convertArrayToJSON(val)
result, _ = sjson.SetRaw(result, key, arrayJSON)
case string:
result, _ = sjson.Set(result, key, val)
case bool:
result, _ = sjson.Set(result, key, val)
case float64, int, int64:
result, _ = sjson.Set(result, key, val)
default:
result, _ = sjson.Set(result, key, val)
}
}
return result
}
// ConvertClaudeResponseToGeminiNonStream converts a non-streaming Claude Code response to a non-streaming Gemini response.
// This function processes the complete Claude Code response and transforms it into a single Gemini-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
@@ -356,8 +311,8 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
}
// Process each streaming event and collect parts
var allParts []interface{}
var finalUsage map[string]interface{}
var allParts []string
var finalUsageJSON string
var responseID string
var createdAt int64
@@ -407,16 +362,14 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
if text := delta.Get("text"); text.Exists() && text.String() != "" {
partJSON := `{"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
part := gjson.Parse(partJSON).Value().(map[string]interface{})
allParts = append(allParts, part)
allParts = append(allParts, partJSON)
}
case "thinking_delta":
// Process reasoning/thinking content
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
partJSON := `{"thought":true,"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
part := gjson.Parse(partJSON).Value().(map[string]interface{})
allParts = append(allParts, part)
allParts = append(allParts, partJSON)
}
case "input_json_delta":
// accumulate args partial_json for this index
@@ -456,9 +409,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
if argsTrim != "" {
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
}
// Parse back to interface{} for allParts
functionCall := gjson.Parse(functionCallJSON).Value().(map[string]interface{})
allParts = append(allParts, functionCall)
allParts = append(allParts, functionCallJSON)
// cleanup used state for this index
if newParam.ToolUseArgs != nil {
delete(newParam.ToolUseArgs, idx)
@@ -501,8 +452,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set traffic type (required by Gemini API)
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
// Convert to map[string]interface{} using gjson
finalUsage = gjson.Parse(usageJSON).Value().(map[string]interface{})
finalUsageJSON = usageJSON
}
}
}
@@ -520,12 +470,16 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set the consolidated parts array
if len(consolidatedParts) > 0 {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", convertToJSONString(consolidatedParts))
partsJSON := "[]"
for _, partJSON := range consolidatedParts {
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
}
// Set usage metadata
if finalUsage != nil {
template, _ = sjson.SetRaw(template, "usageMetadata", convertToJSONString(finalUsage))
if finalUsageJSON != "" {
template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
}
return template
@@ -539,12 +493,12 @@ func GeminiTokenCount(ctx context.Context, count int64) string {
// This function processes the parts array to combine adjacent text elements and thinking elements
// into single consolidated parts, which results in a more readable and efficient response structure.
// Tool calls and other non-text parts are preserved as separate elements.
func consolidateParts(parts []interface{}) []interface{} {
func consolidateParts(parts []string) []string {
if len(parts) == 0 {
return parts
}
var consolidated []interface{}
var consolidated []string
var currentTextPart strings.Builder
var currentThoughtPart strings.Builder
var hasText, hasThought bool
@@ -554,8 +508,7 @@ func consolidateParts(parts []interface{}) []interface{} {
if hasText && currentTextPart.Len() > 0 {
textPartJSON := `{"text":""}`
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
textPart := gjson.Parse(textPartJSON).Value().(map[string]interface{})
consolidated = append(consolidated, textPart)
consolidated = append(consolidated, textPartJSON)
currentTextPart.Reset()
hasText = false
}
@@ -566,42 +519,42 @@ func consolidateParts(parts []interface{}) []interface{} {
if hasThought && currentThoughtPart.Len() > 0 {
thoughtPartJSON := `{"thought":true,"text":""}`
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
thoughtPart := gjson.Parse(thoughtPartJSON).Value().(map[string]interface{})
consolidated = append(consolidated, thoughtPart)
consolidated = append(consolidated, thoughtPartJSON)
currentThoughtPart.Reset()
hasThought = false
}
}
for _, part := range parts {
partMap, ok := part.(map[string]interface{})
if !ok {
for _, partJSON := range parts {
part := gjson.Parse(partJSON)
if !part.Exists() || !part.IsObject() {
// Flush any pending parts and add this non-text part
flushText()
flushThought()
consolidated = append(consolidated, part)
consolidated = append(consolidated, partJSON)
continue
}
if thought, isThought := partMap["thought"]; isThought && thought == true {
thought := part.Get("thought")
if thought.Exists() && thought.Type == gjson.True {
// This is a thinking part - flush any pending text first
flushText() // Flush any pending text first
if text, hasTextContent := partMap["text"].(string); hasTextContent {
currentThoughtPart.WriteString(text)
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
currentThoughtPart.WriteString(text.String())
hasThought = true
}
} else if text, hasTextContent := partMap["text"].(string); hasTextContent {
} else if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
// This is a regular text part - flush any pending thought first
flushThought() // Flush any pending thought first
currentTextPart.WriteString(text)
currentTextPart.WriteString(text.String())
hasText = true
} else {
// This is some other type of part (like function call) - flush both text and thought
flushText()
flushThought()
consolidated = append(consolidated, part)
consolidated = append(consolidated, partJSON)
}
}
@@ -611,20 +564,3 @@ func consolidateParts(parts []interface{}) []interface{} {
return consolidated
}
// convertToJSONString converts interface{} to JSON string using sjson/gjson.
// This function provides a consistent way to serialize different data types to JSON strings
// for inclusion in the Gemini API response structure.
func convertToJSONString(v interface{}) string {
switch val := v.(type) {
case []interface{}:
return convertArrayToJSON(val)
case map[string]interface{}:
return convertMapToJSON(val)
default:
// For simple types, create a temporary JSON and extract the value
temp := `{"temp":null}`
temp, _ = sjson.Set(temp, "temp", val)
return gjson.Get(temp, "temp").Raw
}
}

View File

@@ -10,7 +10,6 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"math/big"
"strings"
@@ -137,9 +136,6 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
out, _ = sjson.Set(out, "stream", stream)
// Process messages and transform them to Claude Code format
var anthropicMessages []interface{}
var toolCallIDs []string // Track tool call IDs for matching with tool results
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
messages.ForEach(func(_, message gjson.Result) bool {
role := message.Get("role").String()
@@ -152,33 +148,23 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
role = "user"
}
msg := map[string]interface{}{
"role": role,
"content": []interface{}{},
}
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
// Handle content based on its type (string or array)
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
// Simple text content conversion
msg["content"] = []interface{}{
map[string]interface{}{
"type": "text",
"text": contentResult.String(),
},
}
part := `{"type":"text","text":""}`
part, _ = sjson.Set(part, "text", contentResult.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
} else if contentResult.Exists() && contentResult.IsArray() {
// Array of content parts processing
var contentParts []interface{}
contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
// Text part conversion
contentParts = append(contentParts, map[string]interface{}{
"type": "text",
"text": part.Get("text").String(),
})
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
case "image_url":
// Convert OpenAI image format to Claude Code format
@@ -191,132 +177,95 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
data := parts[1]
contentParts = append(contentParts, map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": "base64",
"media_type": mediaType,
"data": data,
},
})
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
}
return true
})
if len(contentParts) > 0 {
msg["content"] = contentParts
}
} else {
// Initialize empty content array for tool calls
msg["content"] = []interface{}{}
}
// Handle tool calls (for assistant messages)
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() && role == "assistant" {
var contentParts []interface{}
// Add existing text content if any
if existingContent, ok := msg["content"].([]interface{}); ok {
contentParts = existingContent
}
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
if toolCall.Get("type").String() == "function" {
toolCallID := toolCall.Get("id").String()
if toolCallID == "" {
toolCallID = genToolCallID()
}
toolCallIDs = append(toolCallIDs, toolCallID)
function := toolCall.Get("function")
toolUse := map[string]interface{}{
"type": "tool_use",
"id": toolCallID,
"name": function.Get("name").String(),
}
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
// Parse arguments for the tool call
if args := function.Get("arguments"); args.Exists() {
argsStr := args.String()
if argsStr != "" {
var argsMap map[string]interface{}
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
toolUse["input"] = argsMap
if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
} else {
toolUse["input"] = map[string]interface{}{}
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
}
} else {
toolUse["input"] = map[string]interface{}{}
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
}
} else {
toolUse["input"] = map[string]interface{}{}
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
}
contentParts = append(contentParts, toolUse)
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
}
return true
})
msg["content"] = contentParts
}
anthropicMessages = append(anthropicMessages, msg)
out, _ = sjson.SetRaw(out, "messages.-1", msg)
case "tool":
// Handle tool result messages conversion
toolCallID := message.Get("tool_call_id").String()
content := message.Get("content").String()
// Create tool result message in Claude Code format
msg := map[string]interface{}{
"role": "user",
"content": []interface{}{
map[string]interface{}{
"type": "tool_result",
"tool_use_id": toolCallID,
"content": content,
},
},
}
anthropicMessages = append(anthropicMessages, msg)
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
msg, _ = sjson.Set(msg, "content.0.content", content)
out, _ = sjson.SetRaw(out, "messages.-1", msg)
}
return true
})
}
// Set messages in the output template
if len(anthropicMessages) > 0 {
messagesJSON, _ := json.Marshal(anthropicMessages)
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
}
// Tools mapping: OpenAI tools -> Claude Code tools
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
var anthropicTools []interface{}
hasAnthropicTools := false
tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "function" {
function := tool.Get("function")
anthropicTool := map[string]interface{}{
"name": function.Get("name").String(),
"description": function.Get("description").String(),
}
anthropicTool := `{"name":"","description":""}`
anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
// Convert parameters schema for the tool
if parameters := function.Get("parameters"); parameters.Exists() {
anthropicTool["input_schema"] = parameters.Value()
} else if parameters = function.Get("parametersJsonSchema"); parameters.Exists() {
anthropicTool["input_schema"] = parameters.Value()
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
} else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
}
anthropicTools = append(anthropicTools, anthropicTool)
out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
hasAnthropicTools = true
}
return true
})
if len(anthropicTools) > 0 {
toolsJSON, _ := json.Marshal(anthropicTools)
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
if !hasAnthropicTools {
out, _ = sjson.Delete(out, "tools")
}
}
@@ -329,18 +278,17 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "none":
// Don't set tool_choice, Claude Code will not use tools
case "auto":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"})
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "required":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"})
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
}
case gjson.JSON:
// Specific tool choice mapping
if toolChoice.Get("type").String() == "function" {
functionName := toolChoice.Get("function.name").String()
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{
"type": "tool",
"name": functionName,
})
toolChoiceJSON := `{"type":"tool","name":""}`
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
}
default:
}

View File

@@ -8,7 +8,7 @@ package chat_completions
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"time"
@@ -178,18 +178,11 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if arguments == "" {
arguments = "{}"
}
toolCall := map[string]interface{}{
"index": index,
"id": accumulator.ID,
"type": "function",
"function": map[string]interface{}{
"name": accumulator.Name,
"arguments": arguments,
},
}
template, _ = sjson.Set(template, "choices.0.delta.tool_calls", []interface{}{toolCall})
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
// Clean up the accumulator for this index
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
@@ -210,12 +203,14 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
// Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{
"prompt_tokens": usage.Get("input_tokens").Int(),
"completion_tokens": usage.Get("output_tokens").Int(),
"total_tokens": usage.Get("input_tokens").Int() + usage.Get("output_tokens").Int(),
}
template, _ = sjson.Set(template, "usage", usageObj)
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
return []string{template}
@@ -230,14 +225,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
case "error":
// Error event - format and return error response
if errorData := root.Get("error"); errorData.Exists() {
errorResponse := map[string]interface{}{
"error": map[string]interface{}{
"message": errorData.Get("message").String(),
"type": errorData.Get("type").String(),
},
}
errorJSON, _ := json.Marshal(errorResponse)
return []string{string(errorJSON)}
errorJSON := `{"error":{"message":"","type":""}}`
errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
return []string{errorJSON}
}
return []string{}
@@ -293,15 +284,10 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var messageID string
var model string
var createdAt int64
var inputTokens, outputTokens int64
var reasoningTokens int64
var stopReason string
var contentParts []string
var reasoningParts []string
// Use map to track tool calls by index for proper merging
toolCallsMap := make(map[int]map[string]interface{})
// Track tool call arguments accumulation
toolCallArgsMap := make(map[int]strings.Builder)
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
for _, chunk := range chunks {
root := gjson.ParseBytes(chunk)
@@ -314,9 +300,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String()
model = message.Get("model").String()
createdAt = time.Now().Unix()
if usage := message.Get("usage"); usage.Exists() {
inputTokens = usage.Get("input_tokens").Int()
}
}
case "content_block_start":
@@ -327,18 +310,12 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
// Start of thinking/reasoning content - skip for now as it's handled in delta
continue
} else if blockType == "tool_use" {
// Initialize tool call tracking for this index
// Initialize tool call accumulator for this index
index := int(root.Get("index").Int())
toolCallsMap[index] = map[string]interface{}{
"id": contentBlock.Get("id").String(),
"type": "function",
"function": map[string]interface{}{
"name": contentBlock.Get("name").String(),
"arguments": "",
},
toolCallsAccumulator[index] = &ToolCallAccumulator{
ID: contentBlock.Get("id").String(),
Name: contentBlock.Get("name").String(),
}
// Initialize arguments builder for this tool call
toolCallArgsMap[index] = strings.Builder{}
}
}
@@ -361,9 +338,8 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
// Accumulate tool call arguments
if partialJSON := delta.Get("partial_json"); partialJSON.Exists() {
index := int(root.Get("index").Int())
if builder, exists := toolCallArgsMap[index]; exists {
builder.WriteString(partialJSON.String())
toolCallArgsMap[index] = builder
if accumulator, exists := toolCallsAccumulator[index]; exists {
accumulator.Arguments.WriteString(partialJSON.String())
}
}
}
@@ -372,14 +348,9 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
case "content_block_stop":
// Finalize tool call arguments for this index when content block ends
index := int(root.Get("index").Int())
if toolCall, exists := toolCallsMap[index]; exists {
if builder, argsExists := toolCallArgsMap[index]; argsExists {
// Set the accumulated arguments for the tool call
arguments := builder.String()
if arguments == "" {
arguments = "{}"
}
toolCall["function"].(map[string]interface{})["arguments"] = arguments
if accumulator, exists := toolCallsAccumulator[index]; exists {
if accumulator.Arguments.Len() == 0 {
accumulator.Arguments.WriteString("{}")
}
}
@@ -391,11 +362,14 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
}
if usage := root.Get("usage"); usage.Exists() {
outputTokens = usage.Get("output_tokens").Int()
// Estimate reasoning tokens from accumulated thinking content
if len(reasoningParts) > 0 {
reasoningTokens = int64(len(strings.Join(reasoningParts, "")) / 4) // Rough estimation
}
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
}
}
@@ -417,24 +391,35 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
// Set tool calls if any were accumulated during processing
if len(toolCallsMap) > 0 {
// Convert tool calls map to array, preserving order by index
var toolCallsArray []interface{}
// Find the maximum index to determine the range
if len(toolCallsAccumulator) > 0 {
toolCallsCount := 0
maxIndex := -1
for index := range toolCallsMap {
for index := range toolCallsAccumulator {
if index > maxIndex {
maxIndex = index
}
}
// Iterate through all possible indices up to maxIndex
for i := 0; i <= maxIndex; i++ {
if toolCall, exists := toolCallsMap[i]; exists {
toolCallsArray = append(toolCallsArray, toolCall)
accumulator, exists := toolCallsAccumulator[i]
if !exists {
continue
}
arguments := accumulator.Arguments.String()
idPath := fmt.Sprintf("choices.0.message.tool_calls.%d.id", toolCallsCount)
typePath := fmt.Sprintf("choices.0.message.tool_calls.%d.type", toolCallsCount)
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
out, _ = sjson.Set(out, idPath, accumulator.ID)
out, _ = sjson.Set(out, typePath, "function")
out, _ = sjson.Set(out, namePath, accumulator.Name)
out, _ = sjson.Set(out, argumentsPath, arguments)
toolCallsCount++
}
if len(toolCallsArray) > 0 {
out, _ = sjson.Set(out, "choices.0.message.tool_calls", toolCallsArray)
if toolCallsCount > 0 {
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
} else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
@@ -443,16 +428,5 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
// Set usage information including prompt tokens, completion tokens, and total tokens
totalTokens := inputTokens + outputTokens
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", totalTokens)
// Add reasoning tokens to usage details if any reasoning content was processed
if reasoningTokens > 0 {
out, _ = sjson.Set(out, "usage.completion_tokens_details.reasoning_tokens", reasoningTokens)
}
return out
}

View File

@@ -114,13 +114,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var builder strings.Builder
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
text := part.Get("text").String()
textResult := part.Get("text")
text := textResult.String()
if builder.Len() > 0 && text != "" {
builder.WriteByte('\n')
}
builder.WriteString(text)
return true
})
} else if parts.Type == gjson.String {
builder.WriteString(parts.String())
}
instructionsText = builder.String()
if instructionsText != "" {
@@ -207,6 +210,8 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
return true
})
} else if parts.Type == gjson.String {
textAggregate.WriteString(parts.String())
}
// Fallback to given role if content types not decisive
@@ -254,7 +259,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
toolUse, _ = sjson.Set(toolUse, "id", callID)
toolUse, _ = sjson.Set(toolUse, "name", name)
if argsStr != "" && gjson.Valid(argsStr) {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsStr)
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
}
}
asst := `{"role":"assistant","content":[]}`
@@ -309,16 +317,18 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case gjson.String:
switch toolChoice.String() {
case "auto":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"})
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "none":
// Leave unset; implies no tools
case "required":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"})
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
}
case gjson.JSON:
if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String()
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "tool", "name": fn})
toolChoiceJSON := `{"name":"","type":"tool"}`
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
}
default:

View File

@@ -40,6 +40,16 @@ type claudeToResponsesState struct {
var dataTag = []byte("data:")
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
if len(originalRequestRawJSON) > 0 && gjson.ValidBytes(originalRequestRawJSON) {
return originalRequestRawJSON
}
if len(requestRawJSON) > 0 && gjson.ValidBytes(requestRawJSON) {
return requestRawJSON
}
return nil
}
func emitEvent(event string, payload string) string {
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
}
@@ -279,8 +289,9 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
// Inject original request fields into response as per docs/response.completed.json
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.Set(completed, "response.instructions", v.String())
}
@@ -344,31 +355,20 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
// Build response.output from aggregated state
var outputs []interface{}
outputsWrapper := `{"arr":[]}`
// reasoning item (if any)
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
r := map[string]interface{}{
"id": st.ReasoningItemID,
"type": "reasoning",
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": st.ReasoningBuf.String()}},
}
outputs = append(outputs, r)
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
// assistant message item (if any text)
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
m := map[string]interface{}{
"id": st.CurrentMsgID,
"type": "message",
"status": "completed",
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": st.TextBuf.String(),
}},
"role": "assistant",
}
outputs = append(outputs, m)
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
// function_call items (in ascending index order for determinism)
if len(st.FuncArgsBuf) > 0 {
@@ -395,19 +395,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if callID == "" && st.CurrentFCID != "" {
callID = st.CurrentFCID
}
item := map[string]interface{}{
"id": fmt.Sprintf("fc_%s", callID),
"type": "function_call",
"status": "completed",
"arguments": args,
"call_id": callID,
"name": name,
}
outputs = append(outputs, item)
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", callID)
item, _ = sjson.Set(item, "name", name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
}
if len(outputs) > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs)
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
}
reasoningTokens := int64(0)
@@ -563,8 +560,9 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
out, _ = sjson.Set(out, "created_at", createdAt)
// Inject request echo fields as top-level (similar to streaming variant)
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() {
out, _ = sjson.Set(out, "instructions", v.String())
}
@@ -628,27 +626,18 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
}
// Build output array
var outputs []interface{}
outputsWrapper := `{"arr":[]}`
if reasoningBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{
"id": reasoningItemID,
"type": "reasoning",
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": reasoningBuf.String()}},
})
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", reasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
if currentMsgID != "" || textBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{
"id": currentMsgID,
"type": "message",
"status": "completed",
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": textBuf.String(),
}},
"role": "assistant",
})
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", currentMsgID)
item, _ = sjson.Set(item, "content.0.text", textBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
if len(toolCalls) > 0 {
// Preserve index order
@@ -669,18 +658,16 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
if args == "" {
args = "{}"
}
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("fc_%s", st.id),
"type": "function_call",
"status": "completed",
"arguments": args,
"call_id": st.id,
"name": st.name,
})
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", st.id)
item, _ = sjson.Set(item, "name", st.name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
}
if len(outputs) > 0 {
out, _ = sjson.Set(out, "output", outputs)
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
}
// Usage

View File

@@ -37,10 +37,11 @@ import (
// - []byte: The transformed request data in internal client format
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
userAgent := misc.ExtractCodexUserAgent(rawJSON)
template := `{"model":"","instructions":"","input":[]}`
_, instructions := misc.CodexInstructionsForModel(modelName, "")
_, instructions := misc.CodexInstructionsForModel(modelName, "", userAgent)
template, _ = sjson.Set(template, "instructions", instructions)
rootResult := gjson.ParseBytes(rawJSON)

View File

@@ -9,7 +9,6 @@ package claude
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
@@ -21,6 +20,12 @@ var (
dataTag = []byte("data:")
)
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
type ConvertCodexResponseToClaudeParams struct {
HasToolCall bool
BlockIndex int
}
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates Codex API responses
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
@@ -39,8 +44,10 @@ var (
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
hasToolCall := false
*param = &hasToolCall
*param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false,
BlockIndex: 0,
}
}
// log.Debugf("rawJSON: %s", string(rawJSON))
@@ -63,46 +70,49 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.content_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.output_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.content_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.completed" {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
p := (*param).(*bool)
if *p {
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
if p {
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
} else {
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
@@ -119,10 +129,9 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
p := true
*param = &p
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
{
// Restore original tool name if shortened
@@ -138,7 +147,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output += fmt.Sprintf("data: %s\n\n", template)
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
@@ -148,14 +157,15 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
}
} else if typeStr == "response.function_call_arguments.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n"
@@ -191,21 +201,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
return ""
}
response := map[string]interface{}{
"id": responseData.Get("id").String(),
"type": "message",
"role": "assistant",
"model": responseData.Get("model").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": responseData.Get("usage.input_tokens").Int(),
"output_tokens": responseData.Get("usage.output_tokens").Int(),
},
}
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
var contentBlocks []interface{}
hasToolCall := false
if output := responseData.Get("output"); output.Exists() && output.IsArray() {
@@ -244,10 +245,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
}
if thinkingBuilder.Len() > 0 {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "thinking",
"thinking": thinkingBuilder.String(),
})
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
}
case "message":
if content := item.Get("content"); content.Exists() {
@@ -256,10 +256,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
if part.Get("type").String() == "output_text" {
text := part.Get("text").String()
if text != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": text,
})
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
return true
@@ -267,10 +266,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} else {
text := content.String()
if text != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": text,
})
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
}
@@ -281,54 +279,41 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
name = original
}
toolBlock := map[string]interface{}{
"type": "tool_use",
"id": item.Get("call_id").String(),
"name": name,
"input": map[string]interface{}{},
}
if argsStr := item.Get("arguments").String(); argsStr != "" {
var args interface{}
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
toolBlock["input"] = args
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
toolBlock, _ = sjson.Set(toolBlock, "name", name)
inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
inputRaw = argsJSON.Raw
}
}
contentBlocks = append(contentBlocks, toolBlock)
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
}
return true
})
}
if len(contentBlocks) > 0 {
response["content"] = contentBlocks
}
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
response["stop_reason"] = stopReason.String()
out, _ = sjson.Set(out, "stop_reason", stopReason.String())
} else if hasToolCall {
response["stop_reason"] = "tool_use"
out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else {
response["stop_reason"] = "end_turn"
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
response["stop_sequence"] = stopSequence.Value()
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
}
if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() {
response["usage"] = map[string]interface{}{
"input_tokens": responseData.Get("usage.input_tokens").Int(),
"output_tokens": responseData.Get("usage.output_tokens").Int(),
}
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
}
responseJSON, err := json.Marshal(response)
if err != nil {
return ""
}
return string(responseJSON)
return out
}
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.

View File

@@ -38,11 +38,12 @@ import (
// - []byte: The transformed request data in Codex API format
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
userAgent := misc.ExtractCodexUserAgent(rawJSON)
// Base template
out := `{"model":"","instructions":"","input":[]}`
// Inject standard Codex instructions
_, instructions := misc.CodexInstructionsForModel(modelName, "")
_, instructions := misc.CodexInstructionsForModel(modelName, "", userAgent)
out, _ = sjson.Set(out, "instructions", instructions)
root := gjson.ParseBytes(rawJSON)

View File

@@ -7,7 +7,6 @@ package gemini
import (
"bytes"
"context"
"encoding/json"
"fmt"
"time"
@@ -190,19 +189,19 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
}
// Process output content to build parts array
var parts []interface{}
hasToolCall := false
var pendingFunctionCalls []interface{}
var pendingFunctionCalls []string
flushPendingFunctionCalls := func() {
if len(pendingFunctionCalls) > 0 {
// Add all pending function calls as individual parts
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls {
parts = append(parts, fc)
}
pendingFunctionCalls = nil
if len(pendingFunctionCalls) == 0 {
return
}
// Add all pending function calls as individual parts
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc)
}
pendingFunctionCalls = nil
}
if output := responseData.Get("output"); output.Exists() && output.IsArray() {
@@ -216,11 +215,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add thinking content
if content := value.Get("content"); content.Exists() {
part := map[string]interface{}{
"thought": true,
"text": content.String(),
}
parts = append(parts, part)
part := `{"text":"","thought":true}`
part, _ = sjson.Set(part, "text", content.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
}
case "message":
@@ -232,10 +229,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
content.ForEach(func(_, contentItem gjson.Result) bool {
if contentItem.Get("type").String() == "output_text" {
if text := contentItem.Get("text"); text.Exists() {
part := map[string]interface{}{
"text": text.String(),
}
parts = append(parts, part)
part := `{"text":""}`
part, _ = sjson.Set(part, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
}
}
return true
@@ -245,28 +241,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
case "function_call":
// Collect function call for potential merging with consecutive ones
hasToolCall = true
functionCall := map[string]interface{}{
"functionCall": map[string]interface{}{
"name": func() string {
n := value.Get("name").String()
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
if orig, ok := rev[n]; ok {
return orig
}
return n
}(),
"args": map[string]interface{}{},
},
functionCall := `{"functionCall":{"args":{},"name":""}}`
{
n := value.Get("name").String()
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
if orig, ok := rev[n]; ok {
n = orig
}
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
}
// Parse and set arguments
if argsStr := value.Get("arguments").String(); argsStr != "" {
argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() {
var args map[string]interface{}
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
functionCall["functionCall"].(map[string]interface{})["args"] = args
}
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
}
}
@@ -279,11 +268,6 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
flushPendingFunctionCalls()
}
// Set the parts array
if len(parts) > 0 {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", mustMarshalJSON(parts))
}
// Set finish reason based on whether there were tool calls
if hasToolCall {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
@@ -323,15 +307,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
return rev
}
// mustMarshalJSON marshals a value to JSON, panicking on error.
func mustMarshalJSON(v interface{}) string {
data, err := json.Marshal(v)
if err != nil {
return ""
}
return string(data)
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}

View File

@@ -31,6 +31,7 @@ import (
// - []byte: The transformed request data in OpenAI Responses API format
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
userAgent := misc.ExtractCodexUserAgent(rawJSON)
// Start with empty JSON object
out := `{}`
@@ -96,7 +97,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// Extract system instructions from first system message (string or text object)
messages := gjson.GetBytes(rawJSON, "messages")
_, instructions := misc.CodexInstructionsForModel(modelName, "")
_, instructions := misc.CodexInstructionsForModel(modelName, "", userAgent)
out, _ = sjson.Set(out, "instructions", instructions)
// if messages.IsArray() {
// arr := messages.Array()
@@ -275,7 +276,15 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
arr := tools.Array()
for i := 0; i < len(arr); i++ {
t := arr[i]
if t.Get("type").String() == "function" {
toolType := t.Get("type").String()
// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
// Only "function" needs structural conversion because Chat Completions nests details under "function".
if toolType != "" && toolType != "function" && t.IsObject() {
out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
continue
}
if toolType == "function" {
item := `{}`
item, _ = sjson.Set(item, "type", "function")
fn := t.Get("function")
@@ -304,6 +313,37 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
}
}
// Map tool_choice when present.
// Chat Completions: "tool_choice" can be a string ("auto"/"none") or an object (e.g. {"type":"function","function":{"name":"..."}}).
// Responses API: keep built-in tool choices as-is; flatten function choice to {"type":"function","name":"..."}.
if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
switch {
case tc.Type == gjson.String:
out, _ = sjson.Set(out, "tool_choice", tc.String())
case tc.IsObject():
tcType := tc.Get("type").String()
if tcType == "function" {
name := tc.Get("function.name").String()
if name != "" {
if short, ok := originalToolNameMap[name]; ok {
name = short
} else {
name = shortenNameIfNeeded(name)
}
}
choice := `{}`
choice, _ = sjson.Set(choice, "type", "function")
if name != "" {
choice, _ = sjson.Set(choice, "name", name)
}
out, _ = sjson.SetRaw(out, "tool_choice", choice)
} else if tcType != "" {
// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
}
}
}
out, _ = sjson.Set(out, "store", false)
return []byte(out)
}

View File

@@ -12,6 +12,8 @@ import (
func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
userAgent := misc.ExtractCodexUserAgent(rawJSON)
rawJSON = misc.StripCodexUserAgent(rawJSON)
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
rawJSON, _ = sjson.SetBytes(rawJSON, "store", false)
@@ -32,7 +34,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
originalInstructionsText = originalInstructionsResult.String()
}
hasOfficialInstructions, instructions := misc.CodexInstructionsForModel(modelName, originalInstructionsResult.String())
hasOfficialInstructions, instructions := misc.CodexInstructionsForModel(modelName, originalInstructionsResult.String(), userAgent)
inputResult := gjson.GetBytes(rawJSON, "input")
var inputResults []gjson.Result

View File

@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -18,7 +19,10 @@ func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string
if typeResult := gjson.GetBytes(rawJSON, "type"); typeResult.Exists() {
typeStr := typeResult.String()
if typeStr == "response.created" || typeStr == "response.in_progress" || typeStr == "response.completed" {
rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", gjson.GetBytes(originalRequestRawJSON, "instructions").String())
if gjson.GetBytes(rawJSON, "response.instructions").Exists() {
instructions := selectInstructions(originalRequestRawJSON, requestRawJSON)
rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", instructions)
}
}
}
out := fmt.Sprintf("data: %s", string(rawJSON))
@@ -37,6 +41,16 @@ func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, modelName
}
responseResult := rootResult.Get("response")
template := responseResult.Raw
template, _ = sjson.Set(template, "instructions", gjson.GetBytes(originalRequestRawJSON, "instructions").String())
if responseResult.Get("instructions").Exists() {
template, _ = sjson.Set(template, "instructions", selectInstructions(originalRequestRawJSON, requestRawJSON))
}
return template
}
func selectInstructions(originalRequestRawJSON, requestRawJSON []byte) string {
userAgent := misc.ExtractCodexUserAgent(originalRequestRawJSON)
if misc.IsOpenCodeUserAgent(userAgent) {
return gjson.GetBytes(requestRawJSON, "instructions").String()
}
return gjson.GetBytes(originalRequestRawJSON, "instructions").String()
}

View File

@@ -7,10 +7,8 @@ package claude
import (
"bytes"
"encoding/json"
"strings"
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
@@ -41,92 +39,102 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
// system instruction
var systemInstruction *client.Content
systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
systemPart := client.Part{Text: systemPrompt}
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemInstruction := `{"role":"user","parts":[]}`
hasSystemParts := false
systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
if systemPromptResult.Get("type").String() == "text" {
textResult := systemPromptResult.Get("text")
if textResult.Type == gjson.String {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
hasSystemParts = true
}
}
return true
})
if hasSystemParts {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
}
if len(systemInstruction.Parts) == 0 {
systemInstruction = nil
}
} else if systemResult.Type == gjson.String {
out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String())
}
// contents
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
return true
}
role := roleResult.String()
if role == "assistant" {
role = "model"
}
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentJSON := `{"role":"","parts":[]}`
contentJSON, _ = sjson.Set(contentJSON, "role", role)
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
switch contentResult.Get("type").String() {
case "text":
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
case "tool_use":
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{Name: functionName, Args: args},
ThoughtSignature: geminiCLIClaudeThoughtSignature,
})
argsResult := gjson.Parse(functionArgs)
if argsResult.IsObject() && gjson.Valid(functionArgs) {
part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
part, _ = sjson.Set(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
part, _ = sjson.Set(part, "functionCall.name", functionName)
part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
case "tool_result":
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
if toolCallID == "" {
return true
}
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
part, _ = sjson.Set(part, "functionResponse.name", funcName)
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
}
}
contents = append(contents, clientContent)
return true
})
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
}
}
return true
})
}
// tools
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
hasTools := false
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
@@ -136,30 +144,19 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control")
var toolDeclaration any
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
if !hasTools {
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
hasTools = true
}
out, _ = sjson.SetRaw(out, "request.tools.0.functionDeclarations.-1", tool)
}
}
return true
})
if !hasTools {
out, _ = sjson.Delete(out, "request.tools")
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
if systemInstruction != nil {
b, _ := json.Marshal(systemInstruction)
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
}
if len(contents) > 0 {
b, _ := json.Marshal(contents)
out, _ = sjson.SetRaw(out, "request.contents", string(b))
}
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
b, _ := json.Marshal(tools)
out, _ = sjson.SetRaw(out, "request.tools", string(b))
}
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled

View File

@@ -9,7 +9,6 @@ package claude
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
@@ -276,22 +275,16 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
root := gjson.ParseBytes(rawJSON)
response := map[string]interface{}{
"id": root.Get("response.responseId").String(),
"type": "message",
"role": "assistant",
"model": root.Get("response.modelVersion").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": root.Get("response.usageMetadata.promptTokenCount").Int(),
"output_tokens": root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int(),
},
}
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("response.responseId").String())
out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String())
inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int()
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
parts := root.Get("response.candidates.0.content.parts")
var contentBlocks []interface{}
textBuilder := strings.Builder{}
thinkingBuilder := strings.Builder{}
toolIDCounter := 0
@@ -301,10 +294,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if textBuilder.Len() == 0 {
return
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": textBuilder.String(),
})
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", textBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
textBuilder.Reset()
}
@@ -312,10 +304,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if thinkingBuilder.Len() == 0 {
return
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "thinking",
"thinking": thinkingBuilder.String(),
})
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
thinkingBuilder.Reset()
}
@@ -339,21 +330,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
name := functionCall.Get("name").String()
toolIDCounter++
toolBlock := map[string]interface{}{
"type": "tool_use",
"id": fmt.Sprintf("tool_%d", toolIDCounter),
"name": name,
"input": map[string]interface{}{},
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw
}
if args := functionCall.Get("args"); args.Exists() {
var parsed interface{}
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
toolBlock["input"] = parsed
}
}
contentBlocks = append(contentBlocks, toolBlock)
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
continue
}
}
@@ -362,8 +347,6 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
flushThinking()
flushText()
response["content"] = contentBlocks
stopReason := "end_turn"
if hasToolCall {
stopReason = "tool_use"
@@ -379,19 +362,13 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
}
}
}
response["stop_reason"] = stopReason
out, _ = sjson.Set(out, "stop_reason", stopReason)
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) {
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
delete(response, "usage")
}
if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() {
out, _ = sjson.Delete(out, "usage")
}
encoded, err := json.Marshal(response)
if err != nil {
return ""
}
return string(encoded)
return out
}
func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -7,7 +7,6 @@ package gemini
import (
"bytes"
"encoding/json"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
@@ -117,8 +116,6 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}
@@ -146,7 +143,7 @@ func fixCLIToolResponse(input string) (string, error) {
}
// Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array
contentsWrapper := `{"contents":[]}`
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -178,23 +175,17 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
var responseParts []interface{}
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
}
responseParts = append(responseParts, responseMap)
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
// Remove this group as it's been satisfied
@@ -208,50 +199,42 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
functionCallsCount := 0
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
functionCallsCount++
}
return true
})
if len(functionCallsInThisModel) > 0 {
if functionCallsCount > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
if !value.IsObject() {
log.Warnf("failed to parse model content")
return true
}
newContents = append(newContents, contentMap)
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
ResponsesNeeded: functionCallsCount,
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
if !value.IsObject() {
log.Warnf("failed to parse content")
return true
}
newContents = append(newContents, contentMap)
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
if !value.IsObject() {
log.Warnf("failed to parse content")
return true
}
newContents = append(newContents, contentMap)
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
}
return true
@@ -263,31 +246,24 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{}
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
}
responseParts = append(responseParts, responseMap)
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
return result, nil
}

View File

@@ -152,7 +152,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> request.systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
@@ -160,8 +160,16 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -183,6 +191,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}
@@ -210,8 +219,30 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if content.Type == gjson.String {
// Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
}
}
}
}
}
// Tool calls -> single model content with functionCall parts
@@ -236,7 +267,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
// Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`)
toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {
@@ -252,6 +283,8 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
}
}
}
@@ -278,7 +311,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
@@ -293,7 +326,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue

View File

@@ -8,7 +8,6 @@ package chat_completions
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
@@ -171,21 +170,16 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload))
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}
}
}

View File

@@ -7,10 +7,8 @@ package claude
import (
"bytes"
"encoding/json"
"strings"
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
@@ -34,92 +32,102 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON
out := `{"contents":[]}`
out, _ = sjson.Set(out, "model", modelName)
// system instruction
var systemInstruction *client.Content
systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
systemPart := client.Part{Text: systemPrompt}
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemInstruction := `{"role":"user","parts":[]}`
hasSystemParts := false
systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
if systemPromptResult.Get("type").String() == "text" {
textResult := systemPromptResult.Get("text")
if textResult.Type == gjson.String {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
hasSystemParts = true
}
}
return true
})
if hasSystemParts {
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
}
if len(systemInstruction.Parts) == 0 {
systemInstruction = nil
}
} else if systemResult.Type == gjson.String {
out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
}
// contents
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
return true
}
role := roleResult.String()
if role == "assistant" {
role = "model"
}
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentJSON := `{"role":"","parts":[]}`
contentJSON, _ = sjson.Set(contentJSON, "role", role)
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
switch contentResult.Get("type").String() {
case "text":
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
case "tool_use":
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{Name: functionName, Args: args},
ThoughtSignature: geminiClaudeThoughtSignature,
})
argsResult := gjson.Parse(functionArgs)
if argsResult.IsObject() && gjson.Valid(functionArgs) {
part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
part, _ = sjson.Set(part, "thoughtSignature", geminiClaudeThoughtSignature)
part, _ = sjson.Set(part, "functionCall.name", functionName)
part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
case "tool_result":
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
if toolCallID == "" {
return true
}
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
part, _ = sjson.Set(part, "functionResponse.name", funcName)
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
}
}
contents = append(contents, clientContent)
return true
})
out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
}
}
return true
})
}
// tools
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
hasTools := false
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
@@ -129,30 +137,19 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control")
var toolDeclaration any
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
if !hasTools {
out, _ = sjson.SetRaw(out, "tools", `[{"functionDeclarations":[]}]`)
hasTools = true
}
out, _ = sjson.SetRaw(out, "tools.0.functionDeclarations.-1", tool)
}
}
return true
})
if !hasTools {
out, _ = sjson.Delete(out, "tools")
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
// Build output Gemini CLI request JSON
out := `{"contents":[]}`
out, _ = sjson.Set(out, "model", modelName)
if systemInstruction != nil {
b, _ := json.Marshal(systemInstruction)
out, _ = sjson.SetRaw(out, "system_instruction", string(b))
}
if len(contents) > 0 {
b, _ := json.Marshal(contents)
out, _ = sjson.SetRaw(out, "contents", string(b))
}
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
b, _ := json.Marshal(tools)
out, _ = sjson.SetRaw(out, "tools", string(b))
}
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled

View File

@@ -9,7 +9,6 @@ package claude
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
@@ -282,22 +281,16 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
root := gjson.ParseBytes(rawJSON)
response := map[string]interface{}{
"id": root.Get("responseId").String(),
"type": "message",
"role": "assistant",
"model": root.Get("modelVersion").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": root.Get("usageMetadata.promptTokenCount").Int(),
"output_tokens": root.Get("usageMetadata.candidatesTokenCount").Int() + root.Get("usageMetadata.thoughtsTokenCount").Int(),
},
}
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("responseId").String())
out, _ = sjson.Set(out, "model", root.Get("modelVersion").String())
inputTokens := root.Get("usageMetadata.promptTokenCount").Int()
outputTokens := root.Get("usageMetadata.candidatesTokenCount").Int() + root.Get("usageMetadata.thoughtsTokenCount").Int()
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
parts := root.Get("candidates.0.content.parts")
var contentBlocks []interface{}
textBuilder := strings.Builder{}
thinkingBuilder := strings.Builder{}
toolIDCounter := 0
@@ -307,10 +300,9 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
if textBuilder.Len() == 0 {
return
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": textBuilder.String(),
})
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", textBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
textBuilder.Reset()
}
@@ -318,10 +310,9 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
if thinkingBuilder.Len() == 0 {
return
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "thinking",
"thinking": thinkingBuilder.String(),
})
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
thinkingBuilder.Reset()
}
@@ -345,21 +336,15 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
name := functionCall.Get("name").String()
toolIDCounter++
toolBlock := map[string]interface{}{
"type": "tool_use",
"id": fmt.Sprintf("tool_%d", toolIDCounter),
"name": name,
"input": map[string]interface{}{},
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw
}
if args := functionCall.Get("args"); args.Exists() {
var parsed interface{}
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
toolBlock["input"] = parsed
}
}
contentBlocks = append(contentBlocks, toolBlock)
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
continue
}
}
@@ -368,8 +353,6 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
flushThinking()
flushText()
response["content"] = contentBlocks
stopReason := "end_turn"
if hasToolCall {
stopReason = "tool_use"
@@ -385,19 +368,13 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
}
}
}
response["stop_reason"] = stopReason
out, _ = sjson.Set(out, "stop_reason", stopReason)
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) {
if usageMeta := root.Get("usageMetadata"); !usageMeta.Exists() {
delete(response, "usage")
}
if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("usageMetadata").Exists() {
out, _ = sjson.Delete(out, "usage")
}
encoded, err := json.Marshal(response)
if err != nil {
return ""
}
return string(encoded)
return out
}
func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -170,7 +170,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> system_instruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
@@ -178,8 +178,16 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, "system_instruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
@@ -201,6 +209,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
}
}
@@ -225,18 +234,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`)
p := 0
if content.Type == gjson.String {
// Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
@@ -248,12 +254,12 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
}
}
}
}
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
}
// Tool calls -> single model content with functionCall parts
@@ -278,7 +284,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
// Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`)
toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {
@@ -294,6 +300,8 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
}
}
}
@@ -320,7 +328,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
@@ -335,7 +343,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue

View File

@@ -8,12 +8,12 @@ package chat_completions
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -89,18 +89,27 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
}
}
}
// Process the main content part of the response.
@@ -173,21 +182,16 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload))
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}
}
}
@@ -248,10 +252,19 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("gemini openai response: failed to set cached_tokens in non-streaming: %v", err)
}
}
}
// Process the main content part of the response.
@@ -305,21 +318,16 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.message.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", string(imagePayload))
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
}
}
}

View File

@@ -23,6 +23,7 @@ type geminiToResponsesState struct {
MsgIndex int
CurrentMsgID string
TextBuf strings.Builder
ItemTextBuf strings.Builder
// reasoning aggregation
ReasoningOpened bool
@@ -189,6 +190,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID)
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
out = append(out, emitEvent("response.content_part.added", partAdded))
st.ItemTextBuf.Reset()
st.ItemTextBuf.WriteString(t.String())
}
st.TextBuf.WriteString(t.String())
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
@@ -250,20 +253,24 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
finalizeReasoning()
// Close message output if opened
if st.MsgOpened {
fullText := st.ItemTextBuf.String()
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
done, _ = sjson.Set(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
partDone, _ = sjson.Set(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
final, _ = sjson.Set(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
final, _ = sjson.Set(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final))
}
@@ -377,27 +384,18 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
}
// Compose outputs in encountered order: reasoning, message, function_calls
var outputs []interface{}
outputsWrapper := `{"arr":[]}`
if st.ReasoningOpened {
outputs = append(outputs, map[string]interface{}{
"id": st.ReasoningItemID,
"type": "reasoning",
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": st.ReasoningBuf.String()}},
})
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
if st.MsgOpened {
outputs = append(outputs, map[string]interface{}{
"id": st.CurrentMsgID,
"type": "message",
"status": "completed",
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": st.TextBuf.String(),
}},
"role": "assistant",
})
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
if len(st.FuncArgsBuf) > 0 {
idxs := make([]int, 0, len(st.FuncArgsBuf))
@@ -416,18 +414,16 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
if b := st.FuncArgsBuf[idx]; b != nil {
args = b.String()
}
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]),
"type": "function_call",
"status": "completed",
"arguments": args,
"call_id": st.FuncCallIDs[idx],
"name": st.FuncNames[idx],
})
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", st.FuncCallIDs[idx])
item, _ = sjson.Set(item, "name", st.FuncNames[idx])
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
}
if len(outputs) > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs)
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
}
// usage mapping
@@ -558,11 +554,24 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
}
// Build outputs from candidates[0].content.parts
var outputs []interface{}
var reasoningText strings.Builder
var reasoningEncrypted string
var messageText strings.Builder
var haveMessage bool
haveOutput := false
ensureOutput := func() {
if haveOutput {
return
}
resp, _ = sjson.SetRaw(resp, "output", "[]")
haveOutput = true
}
appendOutput := func(itemJSON string) {
ensureOutput()
resp, _ = sjson.SetRaw(resp, "output.-1", itemJSON)
}
if parts := root.Get("candidates.0.content.parts"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, p gjson.Result) bool {
if p.Get("thought").Bool() {
@@ -583,19 +592,16 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
name := fc.Get("name").String()
args := fc.Get("args")
callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("fc_%s", callID),
"type": "function_call",
"status": "completed",
"arguments": func() string {
if args.Exists() {
return args.Raw
}
return ""
}(),
"call_id": callID,
"name": name,
})
itemJSON := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("fc_%s", callID))
itemJSON, _ = sjson.Set(itemJSON, "call_id", callID)
itemJSON, _ = sjson.Set(itemJSON, "name", name)
argsStr := ""
if args.Exists() {
argsStr = args.Raw
}
itemJSON, _ = sjson.Set(itemJSON, "arguments", argsStr)
appendOutput(itemJSON)
return true
}
return true
@@ -605,42 +611,24 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// Reasoning output item
if reasoningText.Len() > 0 || reasoningEncrypted != "" {
rid := strings.TrimPrefix(id, "resp_")
item := map[string]interface{}{
"id": fmt.Sprintf("rs_%s", rid),
"type": "reasoning",
"encrypted_content": reasoningEncrypted,
}
var summaries []interface{}
itemJSON := `{"id":"","type":"reasoning","encrypted_content":""}`
itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("rs_%s", rid))
itemJSON, _ = sjson.Set(itemJSON, "encrypted_content", reasoningEncrypted)
if reasoningText.Len() > 0 {
summaries = append(summaries, map[string]interface{}{
"type": "summary_text",
"text": reasoningText.String(),
})
summaryJSON := `{"type":"summary_text","text":""}`
summaryJSON, _ = sjson.Set(summaryJSON, "text", reasoningText.String())
itemJSON, _ = sjson.SetRaw(itemJSON, "summary", "[]")
itemJSON, _ = sjson.SetRaw(itemJSON, "summary.-1", summaryJSON)
}
if summaries != nil {
item["summary"] = summaries
}
outputs = append(outputs, item)
appendOutput(itemJSON)
}
// Assistant message output item
if haveMessage {
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("msg_%s_0", strings.TrimPrefix(id, "resp_")),
"type": "message",
"status": "completed",
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": messageText.String(),
}},
"role": "assistant",
})
}
if len(outputs) > 0 {
resp, _ = sjson.Set(resp, "output", outputs)
itemJSON := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("msg_%s_0", strings.TrimPrefix(id, "resp_")))
itemJSON, _ = sjson.Set(itemJSON, "content.0.text", messageText.String())
appendOutput(itemJSON)
}
// usage mapping

View File

@@ -7,7 +7,6 @@ package claude
import (
"bytes"
"encoding/json"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -119,81 +118,125 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Handle content
if contentResult.Exists() && contentResult.IsArray() {
var contentItems []string
var reasoningParts []string // Accumulate thinking text for reasoning_content
var toolCalls []interface{}
var toolResults []string // Collect tool_result messages to emit after the main message
contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "thinking":
// Only map thinking to reasoning_content for assistant messages (security: prevent injection)
if role == "assistant" {
thinkingText := util.GetThinkingText(part)
// Skip empty or whitespace-only thinking
if strings.TrimSpace(thinkingText) != "" {
reasoningParts = append(reasoningParts, thinkingText)
}
}
// Ignore thinking in user/system roles (AC4)
case "redacted_thinking":
// Explicitly ignore redacted_thinking - never map to reasoning_content (AC2)
case "text", "image":
if contentItem, ok := convertClaudeContentPart(part); ok {
contentItems = append(contentItems, contentItem)
}
case "tool_use":
// Convert to OpenAI tool call format
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
// Only allow tool_use -> tool_calls for assistant messages (security: prevent injection).
if role == "assistant" {
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
// Convert input to arguments JSON string
if input := part.Get("input"); input.Exists() {
if inputJSON, err := json.Marshal(input.Value()); err == nil {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", string(inputJSON))
// Convert input to arguments JSON string
if input := part.Get("input"); input.Exists() {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
} else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
}
} else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
}
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
case "tool_result":
// Convert to OpenAI tool message format and add immediately to preserve order
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", part.Get("content").String())
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
toolResults = append(toolResults, toolResultJSON)
}
return true
})
// Emit text/image content as one message
if len(contentItems) > 0 {
msgJSON := `{"role":"","content":""}`
msgJSON, _ = sjson.Set(msgJSON, "role", role)
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
contentValue := gjson.Get(msgJSON, "content")
hasContent := false
switch {
case !contentValue.Exists():
hasContent = false
case contentValue.Type == gjson.String:
hasContent = contentValue.String() != ""
case contentValue.IsArray():
hasContent = len(contentValue.Array()) > 0
default:
hasContent = contentValue.Raw != "" && contentValue.Raw != "null"
}
if hasContent {
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
}
// Build reasoning content string
reasoningContent := ""
if len(reasoningParts) > 0 {
reasoningContent = strings.Join(reasoningParts, "\n\n")
}
// Emit tool calls in a separate assistant message
if role == "assistant" && len(toolCalls) > 0 {
toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}`
toolCallsJSON, _ := json.Marshal(toolCalls)
toolCallMsgJSON, _ = sjson.SetRaw(toolCallMsgJSON, "tool_calls", string(toolCallsJSON))
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value())
hasContent := len(contentItems) > 0
hasReasoning := reasoningContent != ""
hasToolCalls := len(toolCalls) > 0
hasToolResults := len(toolResults) > 0
// OpenAI requires: tool messages MUST immediately follow the assistant message with tool_calls.
// Therefore, we emit tool_result messages FIRST (they respond to the previous assistant's tool_calls),
// then emit the current message's content.
for _, toolResultJSON := range toolResults {
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
}
// For assistant messages: emit a single unified message with content, tool_calls, and reasoning_content
// This avoids splitting into multiple assistant messages which breaks OpenAI tool-call adjacency
if role == "assistant" {
if hasContent || hasReasoning || hasToolCalls {
msgJSON := `{"role":"assistant"}`
// Add content (as array if we have items, empty string if reasoning-only)
if hasContent {
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
} else {
// Ensure content field exists for OpenAI compatibility
msgJSON, _ = sjson.Set(msgJSON, "content", "")
}
// Add reasoning_content if present
if hasReasoning {
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
}
// Add tool_calls if present (in same message as content)
if hasToolCalls {
msgJSON, _ = sjson.Set(msgJSON, "tool_calls", toolCalls)
}
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
}
} else {
// For non-assistant roles: emit content message if we have content
// If the message only contains tool_results (no text/image), we still processed them above
if hasContent {
msgJSON := `{"role":""}`
msgJSON, _ = sjson.Set(msgJSON, "role", role)
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
} else if hasToolResults && !hasContent {
// tool_results already emitted above, no additional user message needed
}
}
} else if contentResult.Exists() && contentResult.Type == gjson.String {
@@ -313,3 +356,43 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
return "", false
}
}
func convertClaudeToolResultContentToString(content gjson.Result) string {
if !content.Exists() {
return ""
}
if content.Type == gjson.String {
return content.String()
}
if content.IsArray() {
var parts []string
content.ForEach(func(_, item gjson.Result) bool {
switch {
case item.Type == gjson.String:
parts = append(parts, item.String())
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
parts = append(parts, item.Get("text").String())
default:
parts = append(parts, item.Raw)
}
return true
})
joined := strings.Join(parts, "\n\n")
if strings.TrimSpace(joined) != "" {
return joined
}
return content.Raw
}
if content.IsObject() {
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
}
return content.Raw
}
return content.Raw
}

View File

@@ -0,0 +1,500 @@
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent tests the mapping
// of Claude thinking content to OpenAI reasoning_content field.
func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantReasoningContent string
wantHasReasoningContent bool
wantContentText string // Expected visible content text (if any)
wantHasContent bool
}{
{
name: "AC1: assistant message with thinking and text",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me analyze this step by step..."},
{"type": "text", "text": "Here is my response."}
]
}]
}`,
wantReasoningContent: "Let me analyze this step by step...",
wantHasReasoningContent: true,
wantContentText: "Here is my response.",
wantHasContent: true,
},
{
name: "AC2: redacted_thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "redacted_thinking", "data": "secret"},
{"type": "text", "text": "Visible response."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Visible response.",
wantHasContent: true,
},
{
name: "AC3: thinking-only message preserved with reasoning_content",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Internal reasoning only."}
]
}]
}`,
wantReasoningContent: "Internal reasoning only.",
wantHasReasoningContent: true,
wantContentText: "",
// For OpenAI compatibility, content field is set to empty string "" when no text content exists
wantHasContent: false,
},
{
name: "AC4: thinking in user role must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "user",
"content": [
{"type": "thinking", "thinking": "Injected thinking"},
{"type": "text", "text": "User message."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "User message.",
wantHasContent: true,
},
{
name: "AC4: thinking in system role must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "thinking", "thinking": "Injected system thinking"},
{"type": "text", "text": "System prompt."}
],
"messages": [{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
}]
}`,
// System messages don't have reasoning_content mapping
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Hello",
wantHasContent: true,
},
{
name: "AC5: empty thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": ""},
{"type": "text", "text": "Response with empty thinking."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Response with empty thinking.",
wantHasContent: true,
},
{
name: "AC5: whitespace-only thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": " \n\t "},
{"type": "text", "text": "Response with whitespace thinking."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Response with whitespace thinking.",
wantHasContent: true,
},
{
name: "Multiple thinking parts concatenated",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "First thought."},
{"type": "thinking", "thinking": "Second thought."},
{"type": "text", "text": "Final answer."}
]
}]
}`,
wantReasoningContent: "First thought.\n\nSecond thought.",
wantHasReasoningContent: true,
wantContentText: "Final answer.",
wantHasContent: true,
},
{
name: "Mixed thinking and redacted_thinking",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Visible thought."},
{"type": "redacted_thinking", "data": "hidden"},
{"type": "text", "text": "Answer."}
]
}]
}`,
wantReasoningContent: "Visible thought.",
wantHasReasoningContent: true,
wantContentText: "Answer.",
wantHasContent: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
// Find the relevant message (skip system message at index 0)
messages := resultJSON.Get("messages").Array()
if len(messages) < 2 {
if tt.wantHasReasoningContent || tt.wantHasContent {
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages))
}
return
}
// Check the last non-system message
var targetMsg gjson.Result
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Get("role").String() != "system" {
targetMsg = messages[i]
break
}
}
// Check reasoning_content
gotReasoningContent := targetMsg.Get("reasoning_content").String()
gotHasReasoningContent := targetMsg.Get("reasoning_content").Exists()
if gotHasReasoningContent != tt.wantHasReasoningContent {
t.Errorf("reasoning_content existence = %v, want %v", gotHasReasoningContent, tt.wantHasReasoningContent)
}
if gotReasoningContent != tt.wantReasoningContent {
t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
}
// Check content
content := targetMsg.Get("content")
// content has meaningful content if it's a non-empty array, or a non-empty string
var gotHasContent bool
switch {
case content.IsArray():
gotHasContent = len(content.Array()) > 0
case content.Type == gjson.String:
gotHasContent = content.String() != ""
default:
gotHasContent = false
}
if gotHasContent != tt.wantHasContent {
t.Errorf("content existence = %v, want %v", gotHasContent, tt.wantHasContent)
}
if tt.wantHasContent && tt.wantContentText != "" {
// Find text content
var foundText string
content.ForEach(func(_, v gjson.Result) bool {
if v.Get("type").String() == "text" {
foundText = v.Get("text").String()
return false
}
return true
})
if foundText != tt.wantContentText {
t.Errorf("content text = %q, want %q", foundText, tt.wantContentText)
}
}
})
}
}
// TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved tests AC3:
// that a message with only thinking content is preserved (not dropped).
func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}]
},
{
"role": "assistant",
"content": [{"type": "thinking", "thinking": "Let me calculate: 2+2=4"}]
},
{
"role": "user",
"content": [{"type": "text", "text": "Thanks"}]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages
if len(messages) != 4 {
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
}
// Check the assistant message (index 2) has reasoning_content
assistantMsg := messages[2]
if assistantMsg.Get("role").String() != "assistant" {
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String())
}
if !assistantMsg.Get("reasoning_content").Exists() {
t.Error("Expected assistant message to have reasoning_content")
}
if assistantMsg.Get("reasoning_content").String() != "Let me calculate: 2+2=4" {
t.Errorf("Unexpected reasoning_content: %s", assistantMsg.Get("reasoning_content").String())
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "before"},
{"type": "tool_result", "tool_use_id": "call_1", "content": [{"type":"text","text":"tool ok"}]},
{"type": "text", "text": "after"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after)
if len(messages) != 4 {
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[0].Get("role").String() != "system" {
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
}
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
}
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
if messages[2].Get("role").String() != "tool" {
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String())
}
if got := messages[2].Get("tool_call_id").String(); got != "call_1" {
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
}
if got := messages[2].Get("content").String(); got != "tool ok" {
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
}
// User message comes after tool message
if messages[3].Get("role").String() != "user" {
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String())
}
// User message should contain both "before" and "after" text
if got := messages[3].Get("content.0.text").String(); got != "before" {
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
}
if got := messages[3].Get("content.1.text").String(); got != "after" {
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "call_1", "content": {"foo": "bar"}}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// system + assistant(tool_calls) + tool(result)
if len(messages) != 3 {
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[2].Get("role").String() != "tool" {
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String())
}
toolContent := messages[2].Get("content").String()
parsed := gjson.Parse(toolContent)
if parsed.Get("foo").String() != "bar" {
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "pre"},
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
{"type": "text", "text": "post"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// New behavior: content + tool_calls unified in single assistant message
// Expect: system + assistant(content[pre,post] + tool_calls)
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[0].Get("role").String() != "system" {
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
}
assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
}
// Should have both content and tool_calls in same message
if !assistantMsg.Get("tool_calls").Exists() {
t.Fatalf("Expected assistant message to have tool_calls")
}
if got := assistantMsg.Get("tool_calls.0.id").String(); got != "call_1" {
t.Fatalf("Expected tool_call id %q, got %q", "call_1", got)
}
if got := assistantMsg.Get("tool_calls.0.function.name").String(); got != "do_work" {
t.Fatalf("Expected tool_call name %q, got %q", "do_work", got)
}
// Content should have both pre and post text
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
}
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "t1"},
{"type": "text", "text": "pre"},
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
{"type": "thinking", "thinking": "t2"},
{"type": "text", "text": "post"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// New behavior: all content, thinking, and tool_calls unified in single assistant message
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
}
// Should have content with both pre and post
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
}
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
}
// Should have tool_calls
if !assistantMsg.Get("tool_calls").Exists() {
t.Fatalf("Expected assistant message to have tool_calls")
}
// Should have combined reasoning_content from both thinking blocks
if got := assistantMsg.Get("reasoning_content").String(); got != "t1\n\nt2" {
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
}
}

View File

@@ -8,7 +8,6 @@ package claude
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
@@ -133,24 +132,10 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
if delta := root.Get("choices.0.delta"); delta.Exists() {
if !param.MessageStarted {
// Send message_start event
messageStart := map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
"id": param.MessageID,
"type": "message",
"role": "assistant",
"model": param.Model,
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
},
}
messageStartJSON, _ := json.Marshal(messageStart)
results = append(results, "event: message_start\ndata: "+string(messageStartJSON)+"\n\n")
messageStartJSON := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
messageStartJSON, _ = sjson.Set(messageStartJSON, "message.id", param.MessageID)
messageStartJSON, _ = sjson.Set(messageStartJSON, "message.model", param.Model)
results = append(results, "event: message_start\ndata: "+messageStartJSON+"\n\n")
param.MessageStarted = true
// Don't send content_block_start for text here - wait for actual content
@@ -168,29 +153,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
param.ThinkingContentBlockIndex = param.NextContentBlockIndex
param.NextContentBlockIndex++
}
contentBlockStart := map[string]interface{}{
"type": "content_block_start",
"index": param.ThinkingContentBlockIndex,
"content_block": map[string]interface{}{
"type": "thinking",
"thinking": "",
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", param.ThinkingContentBlockIndex)
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
param.ThinkingContentBlockStarted = true
}
thinkingDelta := map[string]interface{}{
"type": "content_block_delta",
"index": param.ThinkingContentBlockIndex,
"delta": map[string]interface{}{
"type": "thinking_delta",
"thinking": reasoningText,
},
}
thinkingDeltaJSON, _ := json.Marshal(thinkingDelta)
results = append(results, "event: content_block_delta\ndata: "+string(thinkingDeltaJSON)+"\n\n")
thinkingDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
thinkingDeltaJSON, _ = sjson.Set(thinkingDeltaJSON, "index", param.ThinkingContentBlockIndex)
thinkingDeltaJSON, _ = sjson.Set(thinkingDeltaJSON, "delta.thinking", reasoningText)
results = append(results, "event: content_block_delta\ndata: "+thinkingDeltaJSON+"\n\n")
}
}
@@ -203,29 +175,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
param.TextContentBlockIndex = param.NextContentBlockIndex
param.NextContentBlockIndex++
}
contentBlockStart := map[string]interface{}{
"type": "content_block_start",
"index": param.TextContentBlockIndex,
"content_block": map[string]interface{}{
"type": "text",
"text": "",
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", param.TextContentBlockIndex)
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
param.TextContentBlockStarted = true
}
contentDelta := map[string]interface{}{
"type": "content_block_delta",
"index": param.TextContentBlockIndex,
"delta": map[string]interface{}{
"type": "text_delta",
"text": content.String(),
},
}
contentDeltaJSON, _ := json.Marshal(contentDelta)
results = append(results, "event: content_block_delta\ndata: "+string(contentDeltaJSON)+"\n\n")
contentDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
contentDeltaJSON, _ = sjson.Set(contentDeltaJSON, "index", param.TextContentBlockIndex)
contentDeltaJSON, _ = sjson.Set(contentDeltaJSON, "delta.text", content.String())
results = append(results, "event: content_block_delta\ndata: "+contentDeltaJSON+"\n\n")
// Accumulate content
param.ContentAccumulator.WriteString(content.String())
@@ -263,18 +222,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
stopTextContentBlock(param, &results)
// Send content_block_start for tool_use
contentBlockStart := map[string]interface{}{
"type": "content_block_start",
"index": blockIndex,
"content_block": map[string]interface{}{
"type": "tool_use",
"id": accumulator.ID,
"name": accumulator.Name,
"input": map[string]interface{}{},
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
}
// Handle function arguments
@@ -298,12 +250,9 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send content_block_stop for thinking content if needed
if param.ThinkingContentBlockStarted {
contentBlockStop := map[string]interface{}{
"type": "content_block_stop",
"index": param.ThinkingContentBlockIndex,
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1
}
@@ -319,24 +268,15 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send complete input_json_delta with all accumulated arguments
if accumulator.Arguments.Len() > 0 {
inputDelta := map[string]interface{}{
"type": "content_block_delta",
"index": blockIndex,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": util.FixJSON(accumulator.Arguments.String()),
},
}
inputDeltaJSON, _ := json.Marshal(inputDelta)
results = append(results, "event: content_block_delta\ndata: "+string(inputDeltaJSON)+"\n\n")
inputDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "index", blockIndex)
inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "delta.partial_json", util.FixJSON(accumulator.Arguments.String()))
results = append(results, "event: content_block_delta\ndata: "+inputDeltaJSON+"\n\n")
}
contentBlockStop := map[string]interface{}{
"type": "content_block_stop",
"index": blockIndex,
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", blockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
delete(param.ToolCallBlockIndexes, index)
}
param.ContentBlocksStopped = true
@@ -359,26 +299,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
inputTokens = promptTokens.Int()
outputTokens = completionTokens.Int()
}
// Send message_delta with usage
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
param.MessageDeltaSent = true
emitMessageStopIfNeeded(param, &results)
}
// Send message_delta with usage
messageDelta := map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{
"stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason),
"stop_sequence": nil,
},
"usage": map[string]interface{}{
"input_tokens": inputTokens,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n")
param.MessageDeltaSent = true
emitMessageStopIfNeeded(param, &results)
}
return results
@@ -390,12 +320,9 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
// Ensure all content blocks are stopped before final events
if param.ThinkingContentBlockStarted {
contentBlockStop := map[string]interface{}{
"type": "content_block_stop",
"index": param.ThinkingContentBlockIndex,
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1
}
@@ -408,24 +335,15 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
blockIndex := param.toolContentBlockIndex(index)
if accumulator.Arguments.Len() > 0 {
inputDelta := map[string]interface{}{
"type": "content_block_delta",
"index": blockIndex,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": util.FixJSON(accumulator.Arguments.String()),
},
}
inputDeltaJSON, _ := json.Marshal(inputDelta)
results = append(results, "event: content_block_delta\ndata: "+string(inputDeltaJSON)+"\n\n")
inputDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "index", blockIndex)
inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "delta.partial_json", util.FixJSON(accumulator.Arguments.String()))
results = append(results, "event: content_block_delta\ndata: "+inputDeltaJSON+"\n\n")
}
contentBlockStop := map[string]interface{}{
"type": "content_block_stop",
"index": blockIndex,
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", blockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
delete(param.ToolCallBlockIndexes, index)
}
param.ContentBlocksStopped = true
@@ -433,16 +351,9 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
// If we haven't sent message_delta yet (no usage info was received), send it now
if param.FinishReason != "" && !param.MessageDeltaSent {
messageDelta := map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{
"stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason),
"stop_sequence": nil,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n")
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null}}`
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
param.MessageDeltaSent = true
}
@@ -455,105 +366,73 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
root := gjson.ParseBytes(rawJSON)
// Build Anthropic response
response := map[string]interface{}{
"id": root.Get("id").String(),
"type": "message",
"role": "assistant",
"model": root.Get("model").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
}
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
out, _ = sjson.Set(out, "model", root.Get("model").String())
// Process message content and tool calls
var contentBlocks []interface{}
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() && len(choices.Array()) > 0 {
choice := choices.Array()[0] // Take first choice
reasoningNode := choice.Get("message.reasoning_content")
allReasoning := collectOpenAIReasoningTexts(reasoningNode)
for _, reasoningText := range allReasoning {
reasoningNode := choice.Get("message.reasoning_content")
for _, reasoningText := range collectOpenAIReasoningTexts(reasoningNode) {
if reasoningText == "" {
continue
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "thinking",
"thinking": reasoningText,
})
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", reasoningText)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
// Handle text content
if content := choice.Get("message.content"); content.Exists() && content.String() != "" {
textBlock := map[string]interface{}{
"type": "text",
"text": content.String(),
}
contentBlocks = append(contentBlocks, textBlock)
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", content.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
}
// Handle tool calls
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
toolUseBlock := map[string]interface{}{
"type": "tool_use",
"id": toolCall.Get("id").String(),
"name": toolCall.Get("function.name").String(),
}
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
// Parse arguments
argsStr := toolCall.Get("function.arguments").String()
argsStr = util.FixJSON(argsStr)
if argsStr != "" {
var args interface{}
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
toolUseBlock["input"] = args
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", argsJSON.Raw)
} else {
toolUseBlock["input"] = map[string]interface{}{}
toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
}
} else {
toolUseBlock["input"] = map[string]interface{}{}
toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
}
contentBlocks = append(contentBlocks, toolUseBlock)
out, _ = sjson.SetRaw(out, "content.-1", toolUseBlock)
return true
})
}
// Set stop reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
response["stop_reason"] = mapOpenAIFinishReasonToAnthropic(finishReason.String())
out, _ = sjson.Set(out, "stop_reason", mapOpenAIFinishReasonToAnthropic(finishReason.String()))
}
}
response["content"] = contentBlocks
// Set usage information
if usage := root.Get("usage"); usage.Exists() {
response["usage"] = map[string]interface{}{
"input_tokens": usage.Get("prompt_tokens").Int(),
"output_tokens": usage.Get("completion_tokens").Int(),
"reasoning_tokens": func() int64 {
if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
return v.Int()
}
return 0
}(),
}
} else {
response["usage"] = map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
out, _ = sjson.Set(out, "usage.input_tokens", usage.Get("prompt_tokens").Int())
out, _ = sjson.Set(out, "usage.output_tokens", usage.Get("completion_tokens").Int())
reasoningTokens := int64(0)
if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
reasoningTokens = v.Int()
}
out, _ = sjson.Set(out, "usage.reasoning_tokens", reasoningTokens)
}
responseJSON, _ := json.Marshal(response)
return []string{string(responseJSON)}
return []string{out}
}
// mapOpenAIFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic equivalents
@@ -600,15 +479,15 @@ func collectOpenAIReasoningTexts(node gjson.Result) []string {
switch node.Type {
case gjson.String:
if text := strings.TrimSpace(node.String()); text != "" {
if text := node.String(); text != "" {
texts = append(texts, text)
}
case gjson.JSON:
if text := node.Get("text"); text.Exists() {
if trimmed := strings.TrimSpace(text.String()); trimmed != "" {
texts = append(texts, trimmed)
if textStr := text.String(); textStr != "" {
texts = append(texts, textStr)
}
} else if raw := strings.TrimSpace(node.Raw); raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
} else if raw := node.Raw; raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
texts = append(texts, raw)
}
}
@@ -620,12 +499,9 @@ func stopThinkingContentBlock(param *ConvertOpenAIResponseToAnthropicParams, res
if !param.ThinkingContentBlockStarted {
return
}
contentBlockStop := map[string]interface{}{
"type": "content_block_stop",
"index": param.ThinkingContentBlockIndex,
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
*results = append(*results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
*results = append(*results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1
}
@@ -642,12 +518,9 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results
if !param.TextContentBlockStarted {
return
}
contentBlockStop := map[string]interface{}{
"type": "content_block_stop",
"index": param.TextContentBlockIndex,
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
*results = append(*results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.TextContentBlockIndex)
*results = append(*results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
param.TextContentBlockStarted = false
param.TextContentBlockIndex = -1
}
@@ -667,29 +540,19 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
_ = requestRawJSON
root := gjson.ParseBytes(rawJSON)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
out, _ = sjson.Set(out, "model", root.Get("model").String())
response := map[string]interface{}{
"id": root.Get("id").String(),
"type": "message",
"role": "assistant",
"model": root.Get("model").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
}
contentBlocks := make([]interface{}, 0)
hasToolCall := false
stopReasonSet := false
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() && len(choices.Array()) > 0 {
choice := choices.Array()[0]
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
response["stop_reason"] = mapOpenAIFinishReasonToAnthropic(finishReason.String())
out, _ = sjson.Set(out, "stop_reason", mapOpenAIFinishReasonToAnthropic(finishReason.String()))
stopReasonSet = true
}
if message := choice.Get("message"); message.Exists() {
@@ -702,10 +565,9 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if textBuilder.Len() == 0 {
return
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": textBuilder.String(),
})
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", textBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
textBuilder.Reset()
}
@@ -713,16 +575,14 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if thinkingBuilder.Len() == 0 {
return
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "thinking",
"thinking": thinkingBuilder.String(),
})
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
thinkingBuilder.Reset()
}
for _, item := range contentResult.Array() {
typeStr := item.Get("type").String()
switch typeStr {
switch item.Get("type").String() {
case "text":
flushThinking()
textBuilder.WriteString(item.Get("text").String())
@@ -733,25 +593,23 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if toolCalls.IsArray() {
toolCalls.ForEach(func(_, tc gjson.Result) bool {
hasToolCall = true
toolUse := map[string]interface{}{
"type": "tool_use",
"id": tc.Get("id").String(),
"name": tc.Get("function.name").String(),
}
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
toolUse, _ = sjson.Set(toolUse, "name", tc.Get("function.name").String())
argsStr := util.FixJSON(tc.Get("function.arguments").String())
if argsStr != "" {
var parsed interface{}
if err := json.Unmarshal([]byte(argsStr), &parsed); err == nil {
toolUse["input"] = parsed
if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
} else {
toolUse["input"] = map[string]interface{}{}
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
}
} else {
toolUse["input"] = map[string]interface{}{}
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
}
contentBlocks = append(contentBlocks, toolUse)
out, _ = sjson.SetRaw(out, "content.-1", toolUse)
return true
})
}
@@ -771,10 +629,9 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
} else if contentResult.Type == gjson.String {
textContent := contentResult.String()
if textContent != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": textContent,
})
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", textContent)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
}
@@ -784,81 +641,52 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if reasoningText == "" {
continue
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "thinking",
"thinking": reasoningText,
})
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", reasoningText)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
hasToolCall = true
toolUseBlock := map[string]interface{}{
"type": "tool_use",
"id": toolCall.Get("id").String(),
"name": toolCall.Get("function.name").String(),
}
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
argsStr := toolCall.Get("function.arguments").String()
argsStr = util.FixJSON(argsStr)
if argsStr != "" {
var args interface{}
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
toolUseBlock["input"] = args
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", argsJSON.Raw)
} else {
toolUseBlock["input"] = map[string]interface{}{}
toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
}
} else {
toolUseBlock["input"] = map[string]interface{}{}
toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
}
contentBlocks = append(contentBlocks, toolUseBlock)
out, _ = sjson.SetRaw(out, "content.-1", toolUseBlock)
return true
})
}
}
}
response["content"] = contentBlocks
if respUsage := root.Get("usage"); respUsage.Exists() {
usageJSON := `{}`
usageJSON, _ = sjson.Set(usageJSON, "input_tokens", respUsage.Get("prompt_tokens").Int())
usageJSON, _ = sjson.Set(usageJSON, "output_tokens", respUsage.Get("completion_tokens").Int())
parsedUsage := gjson.Parse(usageJSON).Value().(map[string]interface{})
response["usage"] = parsedUsage
} else {
response["usage"] = `{"input_tokens":0,"output_tokens":0}`
out, _ = sjson.Set(out, "usage.input_tokens", respUsage.Get("prompt_tokens").Int())
out, _ = sjson.Set(out, "usage.output_tokens", respUsage.Get("completion_tokens").Int())
}
if response["stop_reason"] == nil {
if !stopReasonSet {
if hasToolCall {
response["stop_reason"] = "tool_use"
out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else {
response["stop_reason"] = "end_turn"
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
}
if !hasToolCall {
if toolBlocks := response["content"].([]interface{}); len(toolBlocks) > 0 {
for _, block := range toolBlocks {
if m, ok := block.(map[string]interface{}); ok && m["type"] == "tool_use" {
hasToolCall = true
break
}
}
}
if hasToolCall {
response["stop_reason"] = "tool_use"
}
}
responseJSON, err := json.Marshal(response)
if err != nil {
return ""
}
return string(responseJSON)
return out
}
func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -8,7 +8,6 @@ package gemini
import (
"bytes"
"crypto/rand"
"encoding/json"
"fmt"
"math/big"
"strings"
@@ -94,7 +93,6 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
out, _ = sjson.Set(out, "stream", stream)
// Process contents (Gemini messages) -> OpenAI messages
var openAIMessages []interface{}
var toolCallIDs []string // Track tool call IDs for matching with tool results
// System instruction -> OpenAI system message
@@ -105,22 +103,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
}
if systemInstruction.Exists() {
parts := systemInstruction.Get("parts")
msg := map[string]interface{}{
"role": "system",
"content": []interface{}{},
}
var aggregatedParts []interface{}
msg := `{"role":"system","content":[]}`
hasContent := false
if parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
// Handle text parts
if text := part.Get("text"); text.Exists() {
formattedText := text.String()
aggregatedParts = append(aggregatedParts, map[string]interface{}{
"type": "text",
"text": formattedText,
})
contentPart := `{"type":"text","text":""}`
contentPart, _ = sjson.Set(contentPart, "text", text.String())
msg, _ = sjson.SetRaw(msg, "content.-1", contentPart)
hasContent = true
}
// Handle inline data (e.g., images)
@@ -132,20 +125,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
data := inlineData.Get("data").String()
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
aggregatedParts = append(aggregatedParts, map[string]interface{}{
"type": "image_url",
"image_url": map[string]interface{}{
"url": imageURL,
},
})
contentPart := `{"type":"image_url","image_url":{"url":""}}`
contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
msg, _ = sjson.SetRaw(msg, "content.-1", contentPart)
hasContent = true
}
return true
})
}
if len(aggregatedParts) > 0 {
msg["content"] = aggregatedParts
openAIMessages = append(openAIMessages, msg)
if hasContent {
out, _ = sjson.SetRaw(out, "messages.-1", msg)
}
}
@@ -159,16 +149,15 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
role = "assistant"
}
// Create OpenAI message
msg := map[string]interface{}{
"role": role,
"content": "",
}
msg := `{"role":"","content":""}`
msg, _ = sjson.Set(msg, "role", role)
var textBuilder strings.Builder
var aggregatedParts []interface{}
contentWrapper := `{"arr":[]}`
contentPartsCount := 0
onlyTextContent := true
var toolCalls []interface{}
toolCallsWrapper := `{"arr":[]}`
toolCallsCount := 0
if parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
@@ -176,10 +165,10 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
if text := part.Get("text"); text.Exists() {
formattedText := text.String()
textBuilder.WriteString(formattedText)
aggregatedParts = append(aggregatedParts, map[string]interface{}{
"type": "text",
"text": formattedText,
})
contentPart := `{"type":"text","text":""}`
contentPart, _ = sjson.Set(contentPart, "text", formattedText)
contentWrapper, _ = sjson.SetRaw(contentWrapper, "arr.-1", contentPart)
contentPartsCount++
}
// Handle inline data (e.g., images)
@@ -193,12 +182,10 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
data := inlineData.Get("data").String()
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
aggregatedParts = append(aggregatedParts, map[string]interface{}{
"type": "image_url",
"image_url": map[string]interface{}{
"url": imageURL,
},
})
contentPart := `{"type":"image_url","image_url":{"url":""}}`
contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
contentWrapper, _ = sjson.SetRaw(contentWrapper, "arr.-1", contentPart)
contentPartsCount++
}
// Handle function calls (Gemini) -> tool calls (OpenAI)
@@ -206,44 +193,32 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
toolCallID := genToolCallID()
toolCallIDs = append(toolCallIDs, toolCallID)
toolCall := map[string]interface{}{
"id": toolCallID,
"type": "function",
"function": map[string]interface{}{
"name": functionCall.Get("name").String(),
},
}
toolCall := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
toolCall, _ = sjson.Set(toolCall, "id", toolCallID)
toolCall, _ = sjson.Set(toolCall, "function.name", functionCall.Get("name").String())
// Convert args to arguments JSON string
if args := functionCall.Get("args"); args.Exists() {
argsJSON, _ := json.Marshal(args.Value())
toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON)
toolCall, _ = sjson.Set(toolCall, "function.arguments", args.Raw)
} else {
toolCall["function"].(map[string]interface{})["arguments"] = "{}"
toolCall, _ = sjson.Set(toolCall, "function.arguments", "{}")
}
toolCalls = append(toolCalls, toolCall)
toolCallsWrapper, _ = sjson.SetRaw(toolCallsWrapper, "arr.-1", toolCall)
toolCallsCount++
}
// Handle function responses (Gemini) -> tool role messages (OpenAI)
if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
// Create tool message for function response
toolMsg := map[string]interface{}{
"role": "tool",
"tool_call_id": "", // Will be set based on context
"content": "",
}
toolMsg := `{"role":"tool","tool_call_id":"","content":""}`
// Convert response.content to JSON string
if response := functionResponse.Get("response"); response.Exists() {
if content = response.Get("content"); content.Exists() {
// Use the content field from the response
contentJSON, _ := json.Marshal(content.Value())
toolMsg["content"] = string(contentJSON)
if contentField := response.Get("content"); contentField.Exists() {
toolMsg, _ = sjson.Set(toolMsg, "content", contentField.Raw)
} else {
// Fallback to entire response
responseJSON, _ := json.Marshal(response.Value())
toolMsg["content"] = string(responseJSON)
toolMsg, _ = sjson.Set(toolMsg, "content", response.Raw)
}
}
@@ -252,13 +227,13 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
if len(toolCallIDs) > 0 {
// Use the last tool call ID (simple matching by function name)
// In a real implementation, you might want more sophisticated matching
toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1]
toolMsg, _ = sjson.Set(toolMsg, "tool_call_id", toolCallIDs[len(toolCallIDs)-1])
} else {
// Generate a tool call ID if none available
toolMsg["tool_call_id"] = genToolCallID()
toolMsg, _ = sjson.Set(toolMsg, "tool_call_id", genToolCallID())
}
openAIMessages = append(openAIMessages, toolMsg)
out, _ = sjson.SetRaw(out, "messages.-1", toolMsg)
}
return true
@@ -266,170 +241,46 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
}
// Set content
if len(aggregatedParts) > 0 {
if contentPartsCount > 0 {
if onlyTextContent {
msg["content"] = textBuilder.String()
msg, _ = sjson.Set(msg, "content", textBuilder.String())
} else {
msg["content"] = aggregatedParts
msg, _ = sjson.SetRaw(msg, "content", gjson.Get(contentWrapper, "arr").Raw)
}
}
// Set tool calls if any
if len(toolCalls) > 0 {
msg["tool_calls"] = toolCalls
if toolCallsCount > 0 {
msg, _ = sjson.SetRaw(msg, "tool_calls", gjson.Get(toolCallsWrapper, "arr").Raw)
}
openAIMessages = append(openAIMessages, msg)
// switch role {
// case "user", "model":
// // Convert role: model -> assistant
// if role == "model" {
// role = "assistant"
// }
//
// // Create OpenAI message
// msg := map[string]interface{}{
// "role": role,
// "content": "",
// }
//
// var contentParts []string
// var toolCalls []interface{}
//
// if parts.Exists() && parts.IsArray() {
// parts.ForEach(func(_, part gjson.Result) bool {
// // Handle text parts
// if text := part.Get("text"); text.Exists() {
// contentParts = append(contentParts, text.String())
// }
//
// // Handle function calls (Gemini) -> tool calls (OpenAI)
// if functionCall := part.Get("functionCall"); functionCall.Exists() {
// toolCallID := genToolCallID()
// toolCallIDs = append(toolCallIDs, toolCallID)
//
// toolCall := map[string]interface{}{
// "id": toolCallID,
// "type": "function",
// "function": map[string]interface{}{
// "name": functionCall.Get("name").String(),
// },
// }
//
// // Convert args to arguments JSON string
// if args := functionCall.Get("args"); args.Exists() {
// argsJSON, _ := json.Marshal(args.Value())
// toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON)
// } else {
// toolCall["function"].(map[string]interface{})["arguments"] = "{}"
// }
//
// toolCalls = append(toolCalls, toolCall)
// }
//
// return true
// })
// }
//
// // Set content
// if len(contentParts) > 0 {
// msg["content"] = strings.Join(contentParts, "")
// }
//
// // Set tool calls if any
// if len(toolCalls) > 0 {
// msg["tool_calls"] = toolCalls
// }
//
// openAIMessages = append(openAIMessages, msg)
//
// case "function":
// // Handle Gemini function role -> OpenAI tool role
// if parts.Exists() && parts.IsArray() {
// parts.ForEach(func(_, part gjson.Result) bool {
// // Handle function responses (Gemini) -> tool role messages (OpenAI)
// if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
// // Create tool message for function response
// toolMsg := map[string]interface{}{
// "role": "tool",
// "tool_call_id": "", // Will be set based on context
// "content": "",
// }
//
// // Convert response.content to JSON string
// if response := functionResponse.Get("response"); response.Exists() {
// if content = response.Get("content"); content.Exists() {
// // Use the content field from the response
// contentJSON, _ := json.Marshal(content.Value())
// toolMsg["content"] = string(contentJSON)
// } else {
// // Fallback to entire response
// responseJSON, _ := json.Marshal(response.Value())
// toolMsg["content"] = string(responseJSON)
// }
// }
//
// // Try to match with previous tool call ID
// _ = functionResponse.Get("name").String() // functionName not used for now
// if len(toolCallIDs) > 0 {
// // Use the last tool call ID (simple matching by function name)
// // In a real implementation, you might want more sophisticated matching
// toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1]
// } else {
// // Generate a tool call ID if none available
// toolMsg["tool_call_id"] = genToolCallID()
// }
//
// openAIMessages = append(openAIMessages, toolMsg)
// }
//
// return true
// })
// }
// }
out, _ = sjson.SetRaw(out, "messages.-1", msg)
return true
})
}
// Set messages
if len(openAIMessages) > 0 {
messagesJSON, _ := json.Marshal(openAIMessages)
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
}
// Tools mapping: Gemini tools -> OpenAI tools
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
var openAITools []interface{}
tools.ForEach(func(_, tool gjson.Result) bool {
if functionDeclarations := tool.Get("functionDeclarations"); functionDeclarations.Exists() && functionDeclarations.IsArray() {
functionDeclarations.ForEach(func(_, funcDecl gjson.Result) bool {
openAITool := map[string]interface{}{
"type": "function",
"function": map[string]interface{}{
"name": funcDecl.Get("name").String(),
"description": funcDecl.Get("description").String(),
},
}
openAITool := `{"type":"function","function":{"name":"","description":""}}`
openAITool, _ = sjson.Set(openAITool, "function.name", funcDecl.Get("name").String())
openAITool, _ = sjson.Set(openAITool, "function.description", funcDecl.Get("description").String())
// Convert parameters schema
if parameters := funcDecl.Get("parameters"); parameters.Exists() {
openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value()
} else if parameters = funcDecl.Get("parametersJsonSchema"); parameters.Exists() {
openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value()
openAITool, _ = sjson.SetRaw(openAITool, "function.parameters", parameters.Raw)
} else if parameters := funcDecl.Get("parametersJsonSchema"); parameters.Exists() {
openAITool, _ = sjson.SetRaw(openAITool, "function.parameters", parameters.Raw)
}
openAITools = append(openAITools, openAITool)
out, _ = sjson.SetRaw(out, "tools.-1", openAITool)
return true
})
}
return true
})
if len(openAITools) > 0 {
toolsJSON, _ := json.Marshal(openAITools)
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
}
}
// Tool choice mapping (Gemini doesn't have direct equivalent, but we can handle it)

View File

@@ -8,7 +8,6 @@ package gemini
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
@@ -84,15 +83,12 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
template, _ = sjson.Set(template, "model", model.String())
}
usageObj := map[string]interface{}{
"promptTokenCount": usage.Get("prompt_tokens").Int(),
"candidatesTokenCount": usage.Get("completion_tokens").Int(),
"totalTokenCount": usage.Get("total_tokens").Int(),
}
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", reasoningTokens)
}
template, _ = sjson.Set(template, "usageMetadata", usageObj)
return []string{template}
}
return []string{}
@@ -133,13 +129,8 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
continue
}
reasoningTemplate := baseTemplate
parts := []interface{}{
map[string]interface{}{
"thought": true,
"text": reasoningText,
},
}
reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts", parts)
reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts.0.thought", true)
reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts.0.text", reasoningText)
chunkOutputs = append(chunkOutputs, reasoningTemplate)
}
}
@@ -150,13 +141,8 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
(*param).(*ConvertOpenAIResponseToGeminiParams).ContentAccumulator.WriteString(contentText)
// Create text part for this delta
parts := []interface{}{
map[string]interface{}{
"text": contentText,
},
}
contentTemplate := baseTemplate
contentTemplate, _ = sjson.Set(contentTemplate, "candidates.0.content.parts", parts)
contentTemplate, _ = sjson.Set(contentTemplate, "candidates.0.content.parts.0.text", contentText)
chunkOutputs = append(chunkOutputs, contentTemplate)
}
@@ -225,24 +211,13 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
// If we have accumulated tool calls, output them now
if len((*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator) > 0 {
var parts []interface{}
partIndex := 0
for _, accumulator := range (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator {
argsStr := accumulator.Arguments.String()
var argsMap map[string]interface{}
argsMap = parseArgsToMap(argsStr)
functionCallPart := map[string]interface{}{
"functionCall": map[string]interface{}{
"name": accumulator.Name,
"args": argsMap,
},
}
parts = append(parts, functionCallPart)
}
if len(parts) > 0 {
template, _ = sjson.Set(template, "candidates.0.content.parts", parts)
namePath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.name", partIndex)
argsPath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.args", partIndex)
template, _ = sjson.Set(template, namePath, accumulator.Name)
template, _ = sjson.SetRaw(template, argsPath, parseArgsToObjectRaw(accumulator.Arguments.String()))
partIndex++
}
// Clear accumulators
@@ -255,15 +230,12 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
// Handle usage information
if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{
"promptTokenCount": usage.Get("prompt_tokens").Int(),
"candidatesTokenCount": usage.Get("completion_tokens").Int(),
"totalTokenCount": usage.Get("total_tokens").Int(),
}
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", reasoningTokens)
}
template, _ = sjson.Set(template, "usageMetadata", usageObj)
results = append(results, template)
return true
}
@@ -291,46 +263,54 @@ func mapOpenAIFinishReasonToGemini(openAIReason string) string {
}
}
// parseArgsToMap safely parses a JSON string of function arguments into a map.
// It returns an empty map if the input is empty or cannot be parsed as a JSON object.
func parseArgsToMap(argsStr string) map[string]interface{} {
// parseArgsToObjectRaw safely parses a JSON string of function arguments into an object JSON string.
// It returns "{}" if the input is empty or cannot be parsed as a JSON object.
func parseArgsToObjectRaw(argsStr string) string {
trimmed := strings.TrimSpace(argsStr)
if trimmed == "" || trimmed == "{}" {
return map[string]interface{}{}
return "{}"
}
// First try strict JSON
var out map[string]interface{}
if errUnmarshal := json.Unmarshal([]byte(trimmed), &out); errUnmarshal == nil {
return out
if gjson.Valid(trimmed) {
strict := gjson.Parse(trimmed)
if strict.IsObject() {
return strict.Raw
}
}
// Tolerant parse: handle streams where values are barewords (e.g., 北京, celsius)
tolerant := tolerantParseJSONMap(trimmed)
if len(tolerant) > 0 {
tolerant := tolerantParseJSONObjectRaw(trimmed)
if tolerant != "{}" {
return tolerant
}
// Fallback: return empty object when parsing fails
return map[string]interface{}{}
return "{}"
}
// tolerantParseJSONMap attempts to parse a JSON-like object string into a map, tolerating
func escapeSjsonPathKey(key string) string {
key = strings.ReplaceAll(key, `\`, `\\`)
key = strings.ReplaceAll(key, `.`, `\.`)
return key
}
// tolerantParseJSONObjectRaw attempts to parse a JSON-like object string into a JSON object string, tolerating
// bareword values (unquoted strings) commonly seen during streamed tool calls.
// Example input: {"location": 北京, "unit": celsius}
func tolerantParseJSONMap(s string) map[string]interface{} {
func tolerantParseJSONObjectRaw(s string) string {
// Ensure we operate within the outermost braces if present
start := strings.Index(s, "{")
end := strings.LastIndex(s, "}")
if start == -1 || end == -1 || start >= end {
return map[string]interface{}{}
return "{}"
}
content := s[start+1 : end]
runes := []rune(content)
n := len(runes)
i := 0
result := make(map[string]interface{})
result := "{}"
for i < n {
// Skip whitespace and commas
@@ -356,6 +336,7 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
break
}
keyName := jsonStringTokenToRawString(keyToken)
sjsonKey := escapeSjsonPathKey(keyName)
i = nextIdx
// Skip whitespace
@@ -375,17 +356,16 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
}
// Parse value (string, number, object/array, bareword)
var value interface{}
switch runes[i] {
case '"':
// JSON string
valToken, ni := parseJSONStringRunes(runes, i)
if ni == -1 {
// Malformed; treat as empty string
value = ""
result, _ = sjson.Set(result, sjsonKey, "")
i = n
} else {
value = jsonStringTokenToRawString(valToken)
result, _ = sjson.Set(result, sjsonKey, jsonStringTokenToRawString(valToken))
i = ni
}
case '{', '[':
@@ -394,11 +374,10 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
if ni == -1 {
i = n
} else {
var anyVal interface{}
if errUnmarshal := json.Unmarshal([]byte(seg), &anyVal); errUnmarshal == nil {
value = anyVal
if gjson.Valid(seg) {
result, _ = sjson.SetRaw(result, sjsonKey, seg)
} else {
value = seg
result, _ = sjson.Set(result, sjsonKey, seg)
}
i = ni
}
@@ -411,21 +390,19 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
token := strings.TrimSpace(string(runes[i:j]))
// Interpret common JSON atoms and numbers; otherwise treat as string
if token == "true" {
value = true
result, _ = sjson.Set(result, sjsonKey, true)
} else if token == "false" {
value = false
result, _ = sjson.Set(result, sjsonKey, false)
} else if token == "null" {
value = nil
result, _ = sjson.Set(result, sjsonKey, nil)
} else if numVal, ok := tryParseNumber(token); ok {
value = numVal
result, _ = sjson.Set(result, sjsonKey, numVal)
} else {
value = token
result, _ = sjson.Set(result, sjsonKey, token)
}
i = j
}
result[keyName] = value
// Skip trailing whitespace and optional comma before next pair
for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t') {
i++
@@ -463,9 +440,9 @@ func parseJSONStringRunes(runes []rune, start int) (string, int) {
// jsonStringTokenToRawString converts a JSON string token (including quotes) to a raw Go string value.
func jsonStringTokenToRawString(token string) string {
var s string
if errUnmarshal := json.Unmarshal([]byte(token), &s); errUnmarshal == nil {
return s
r := gjson.Parse(token)
if r.Type == gjson.String {
return r.String()
}
// Fallback: strip surrounding quotes if present
if len(token) >= 2 && token[0] == '"' && token[len(token)-1] == '"' {
@@ -579,7 +556,7 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
}
}
var parts []interface{}
partIndex := 0
// Handle reasoning content before visible text
if reasoning := message.Get("reasoning_content"); reasoning.Exists() {
@@ -587,18 +564,16 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
if reasoningText == "" {
continue
}
parts = append(parts, map[string]interface{}{
"thought": true,
"text": reasoningText,
})
out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.thought", partIndex), true)
out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.text", partIndex), reasoningText)
partIndex++
}
}
// Handle content first
if content := message.Get("content"); content.Exists() && content.String() != "" {
parts = append(parts, map[string]interface{}{
"text": content.String(),
})
out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.text", partIndex), content.String())
partIndex++
}
// Handle tool calls
@@ -609,27 +584,16 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
functionName := function.Get("name").String()
functionArgs := function.Get("arguments").String()
// Parse arguments
var argsMap map[string]interface{}
argsMap = parseArgsToMap(functionArgs)
functionCallPart := map[string]interface{}{
"functionCall": map[string]interface{}{
"name": functionName,
"args": argsMap,
},
}
parts = append(parts, functionCallPart)
namePath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.name", partIndex)
argsPath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.args", partIndex)
out, _ = sjson.Set(out, namePath, functionName)
out, _ = sjson.SetRaw(out, argsPath, parseArgsToObjectRaw(functionArgs))
partIndex++
}
return true
})
}
// Set parts
if len(parts) > 0 {
out, _ = sjson.Set(out, "candidates.0.content.parts", parts)
}
// Handle finish reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String())
@@ -645,15 +609,12 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
// Handle usage information
if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{
"promptTokenCount": usage.Get("prompt_tokens").Int(),
"candidatesTokenCount": usage.Get("completion_tokens").Int(),
"totalTokenCount": usage.Get("total_tokens").Int(),
}
out, _ = sjson.Set(out, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
out, _ = sjson.Set(out, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
out, _ = sjson.Set(out, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens
out, _ = sjson.Set(out, "usageMetadata.thoughtsTokenCount", reasoningTokens)
}
out, _ = sjson.Set(out, "usageMetadata", usageObj)
}
return out

View File

@@ -163,6 +163,14 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
var chatCompletionsTools []interface{}
tools.ForEach(func(_, tool gjson.Result) bool {
// Built-in tools (e.g. {"type":"web_search"}) are already compatible with the Chat Completions schema.
// Only function tools need structural conversion because Chat Completions nests details under "function".
toolType := tool.Get("type").String()
if toolType != "" && toolType != "function" && tool.IsObject() {
chatCompletionsTools = append(chatCompletionsTools, tool.Value())
return true
}
chatTool := `{"type":"function","function":{}}`
// Convert tool structure from responses format to chat completions format

View File

@@ -484,16 +484,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
}
}
// Build response.output using aggregated buffers
var outputs []interface{}
outputsWrapper := `{"arr":[]}`
if st.ReasoningBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{
"id": st.ReasoningID,
"type": "reasoning",
"summary": []interface{}{map[string]interface{}{
"type": "summary_text",
"text": st.ReasoningBuf.String(),
}},
})
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", st.ReasoningID)
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
// Append message items in ascending index order
if len(st.MsgItemAdded) > 0 {
@@ -513,18 +509,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
if b := st.MsgTextBuf[i]; b != nil {
txt = b.String()
}
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("msg_%s_%d", st.ResponseID, i),
"type": "message",
"status": "completed",
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": txt,
}},
"role": "assistant",
})
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
item, _ = sjson.Set(item, "content.0.text", txt)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
}
if len(st.FuncArgsBuf) > 0 {
@@ -547,18 +535,16 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
}
callID := st.FuncCallIDs[i]
name := st.FuncNames[i]
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("fc_%s", callID),
"type": "function_call",
"status": "completed",
"arguments": args,
"call_id": callID,
"name": name,
})
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", callID)
item, _ = sjson.Set(item, "name", name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
}
if len(outputs) > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs)
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
}
if st.UsageSeen {
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.PromptTokens)
@@ -681,7 +667,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
}
// Build output list from choices[...]
var outputs []interface{}
outputsWrapper := `{"arr":[]}`
// Detect and capture reasoning content if present
rcText := gjson.GetBytes(rawJSON, "choices.0.message.reasoning_content").String()
includeReasoning := rcText != ""
@@ -693,21 +679,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
if strings.HasPrefix(rid, "resp_") {
rid = strings.TrimPrefix(rid, "resp_")
}
reasoningItem := map[string]interface{}{
"id": fmt.Sprintf("rs_%s", rid),
"type": "reasoning",
"encrypted_content": "",
}
// Prefer summary_text from reasoning_content; encrypted_content is optional
var summaries []interface{}
reasoningItem := `{"id":"","type":"reasoning","encrypted_content":"","summary":[]}`
reasoningItem, _ = sjson.Set(reasoningItem, "id", fmt.Sprintf("rs_%s", rid))
if rcText != "" {
summaries = append(summaries, map[string]interface{}{
"type": "summary_text",
"text": rcText,
})
reasoningItem, _ = sjson.Set(reasoningItem, "summary.0.type", "summary_text")
reasoningItem, _ = sjson.Set(reasoningItem, "summary.0.text", rcText)
}
reasoningItem["summary"] = summaries
outputs = append(outputs, reasoningItem)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", reasoningItem)
}
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
@@ -716,18 +695,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
if msg.Exists() {
// Text message part
if c := msg.Get("content"); c.Exists() && c.String() != "" {
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("msg_%s_%d", id, int(choice.Get("index").Int())),
"type": "message",
"status": "completed",
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": c.String(),
}},
"role": "assistant",
})
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", id, int(choice.Get("index").Int())))
item, _ = sjson.Set(item, "content.0.text", c.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
// Function/tool calls
@@ -736,14 +707,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
callID := tc.Get("id").String()
name := tc.Get("function.name").String()
args := tc.Get("function.arguments").String()
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("fc_%s", callID),
"type": "function_call",
"status": "completed",
"arguments": args,
"call_id": callID,
"name": name,
})
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", callID)
item, _ = sjson.Set(item, "name", name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
return true
})
}
@@ -751,8 +720,8 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
return true
})
}
if len(outputs) > 0 {
resp, _ = sjson.Set(resp, "output", outputs)
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
resp, _ = sjson.SetRaw(resp, "output", gjson.Get(outputsWrapper, "arr").Raw)
}
// usage mapping

View File

@@ -6,6 +6,7 @@ package usage
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
@@ -90,7 +91,7 @@ type modelStats struct {
type RequestDetail struct {
Timestamp time.Time `json:"timestamp"`
Source string `json:"source"`
AuthIndex uint64 `json:"auth_index"`
AuthIndex string `json:"auth_index"`
Tokens TokenStats `json:"tokens"`
Failed bool `json:"failed"`
}
@@ -281,6 +282,118 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
return result
}
type MergeResult struct {
Added int64 `json:"added"`
Skipped int64 `json:"skipped"`
}
// MergeSnapshot merges an exported statistics snapshot into the current store.
// Existing data is preserved and duplicate request details are skipped.
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
result := MergeResult{}
if s == nil {
return result
}
s.mu.Lock()
defer s.mu.Unlock()
seen := make(map[string]struct{})
for apiName, stats := range s.apis {
if stats == nil {
continue
}
for modelName, modelStatsValue := range stats.Models {
if modelStatsValue == nil {
continue
}
for _, detail := range modelStatsValue.Details {
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
}
}
}
for apiName, apiSnapshot := range snapshot.APIs {
apiName = strings.TrimSpace(apiName)
if apiName == "" {
continue
}
stats, ok := s.apis[apiName]
if !ok || stats == nil {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[apiName] = stats
} else if stats.Models == nil {
stats.Models = make(map[string]*modelStats)
}
for modelName, modelSnapshot := range apiSnapshot.Models {
modelName = strings.TrimSpace(modelName)
if modelName == "" {
modelName = "unknown"
}
for _, detail := range modelSnapshot.Details {
detail.Tokens = normaliseTokenStats(detail.Tokens)
if detail.Timestamp.IsZero() {
detail.Timestamp = time.Now()
}
key := dedupKey(apiName, modelName, detail)
if _, exists := seen[key]; exists {
result.Skipped++
continue
}
seen[key] = struct{}{}
s.recordImported(apiName, modelName, stats, detail)
result.Added++
}
}
}
return result
}
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
totalTokens := detail.Tokens.TotalTokens
if totalTokens < 0 {
totalTokens = 0
}
s.totalRequests++
if detail.Failed {
s.failureCount++
} else {
s.successCount++
}
s.totalTokens += totalTokens
s.updateAPIStats(stats, modelName, detail)
dayKey := detail.Timestamp.Format("2006-01-02")
hourKey := detail.Timestamp.Hour()
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func dedupKey(apiName, modelName string, detail RequestDetail) string {
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
tokens := normaliseTokenStats(detail.Tokens)
return fmt.Sprintf(
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
apiName,
modelName,
timestamp,
detail.Source,
detail.AuthIndex,
detail.Failed,
tokens.InputTokens,
tokens.OutputTokens,
tokens.ReasoningTokens,
tokens.CachedTokens,
tokens.TotalTokens,
)
}
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
@@ -340,6 +453,16 @@ func normaliseDetail(detail coreusage.Detail) TokenStats {
return tokens
}
func normaliseTokenStats(tokens TokenStats) TokenStats {
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
}
return tokens
}
func formatHour(hour int) string {
if hour < 0 {
hour = 0

View File

@@ -344,7 +344,7 @@ func cleanupRequiredFields(jsonStr string) string {
}
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
// Claude VALIDATED mode requires at least one property in tool schemas.
// Claude VALIDATED mode requires at least one required property in tool schemas.
func addEmptySchemaPlaceholder(jsonStr string) string {
// Find all "type" fields
paths := findPaths(jsonStr, "type")
@@ -364,6 +364,9 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
// Check if properties exists and is empty or missing
propsPath := joinPath(parentPath, "properties")
propsVal := gjson.Get(jsonStr, propsPath)
reqPath := joinPath(parentPath, "required")
reqVal := gjson.Get(jsonStr, reqPath)
hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
needsPlaceholder := false
if !propsVal.Exists() {
@@ -381,8 +384,22 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
// Add to required array
reqPath := joinPath(parentPath, "required")
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
continue
}
// If schema has properties but none are required, add a minimal placeholder.
if propsVal.IsObject() && !hasRequiredProperties {
// DO NOT add placeholder if it's a top-level schema (parentPath is empty)
// or if we've already added a placeholder reason above.
if parentPath == "" {
continue
}
placeholderPath := joinPath(propsPath, "_")
if !gjson.Get(jsonStr, placeholderPath).Exists() {
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
}
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
}
}

Some files were not shown because too many files have changed in this diff Show More