Compare commits

..

94 Commits

Author SHA1 Message Date
Luis Pater
35fdc4cfd3 fix some bugs (#399)
* feat(config): add pruning of stale YAML mapping keys during config save

* Revert watcher.go in "fix: enable hot reload for amp-model-mappings config"
2025-12-02 22:28:30 +08:00
hkfires
3ebbab0a9a Revert watcher.go in "fix: enable hot reload for amp-model-mappings config" 2025-12-02 22:17:54 +08:00
hkfires
480cd714b2 feat(config): add pruning of stale YAML mapping keys during config save 2025-12-02 21:38:54 +08:00
Luis Pater
41ee44432d **fix(translator): rename responseSchema key for generationConfig**
- Renamed `generationConfig.responseSchema` to `generationConfig.responseJsonSchema` in Gemini request transformation to align with updated schema expectations.
2025-12-02 18:32:23 +08:00
Luis Pater
1434bc38e5 **refactor(registry): remove Qwen3-Coder from model definitions** 2025-12-02 11:34:38 +08:00
Luis Pater
0fd2abbc3b **refactor(cliproxy, config): remove vertex-compat flow, streamline Vertex API key handling**
- Removed `vertex-compat` executor and related configuration.
- Consolidated Vertex compatibility checks into `vertex` handling with `apikey`-based model resolution.
- Streamlined model generation logic for Vertex API key entries.
2025-12-02 09:18:24 +08:00
Aero
0ebb654019 feat: Add support for VertexAI compatible service (#375)
feat: consolidate Vertex AI compatibility with API key support in Gemini
2025-12-02 08:14:22 +08:00
Luis Pater
08a1d2edf9 Merge pull request #390 from NguyenSiTrung/main
feat(amp): add model mapping support for routing unavailable models to alternatives
2025-12-02 08:07:56 +08:00
NguyenSiTrung
3409f4e336 fix: enable hot reload for amp-model-mappings config
- Store ampModule in Server struct to access it during config updates
- Call ampModule.OnConfigUpdated() in UpdateClients() for hot reload
- Watch config directory instead of file to handle atomic saves (vim, VSCode, etc.)
- Improve config file event detection with basename matching
- Add diagnostic logging for config reload tracing
2025-12-01 13:34:49 +07:00
NguyenSiTrung
9354b87e54 Merge branch 'router-for-me:main' into main 2025-12-01 08:12:29 +07:00
Luis Pater
54e24110ec Merge pull request #386 from auroraflux/feat/dedupe-thinking-metadata-helpers
refactor(executor): dedupe thinking metadata helpers across Gemini executors
2025-12-01 09:00:27 +08:00
Luis Pater
717c703bff docs(readme): add CCS (Claude Code Switch) to projects list 2025-12-01 07:22:42 +08:00
auroraflux
1c6f4be8ae refactor(executor): dedupe thinking metadata helpers across Gemini executors
Extract applyThinkingMetadata and applyThinkingMetadataCLI helpers to
payload_helpers.go and use them across all four Gemini-based executors:
- gemini_executor.go (Execute, ExecuteStream, CountTokens)
- gemini_cli_executor.go (Execute, ExecuteStream, CountTokens)
- aistudio_executor.go (translateRequest)
- antigravity_executor.go (Execute, ExecuteStream)

This eliminates code duplication introduced in the -reasoning suffix PR
and centralizes the thinking config application logic.

Net reduction: 28 lines of code.
2025-11-30 15:20:15 -08:00
Luis Pater
0de2560cee Merge pull request #379 from kaitranntt/docs/add-ccs-project
docs: add CCS (Claude Code Switch) to projects list
2025-12-01 07:20:04 +08:00
Kai (Tam Nhu) Tran
85eb926482 fix: change AGY to Antigravity 2025-11-30 12:43:12 -05:00
Kai (Tam Nhu) Tran
c52ef08e67 docs: add CCS to projects list 2025-11-30 12:40:35 -05:00
Luis Pater
cb580cd083 Merge pull request #377 from router-for-me/gemini
feat(registry): add thinking support to gemini models
2025-11-30 21:27:54 +08:00
hkfires
75e278c7a5 feat(registry): add thinking support to gemini models 2025-11-30 20:56:29 +08:00
Luis Pater
73208c4e55 Merge pull request #376 from auroraflux/feat/reasoning-suffix-support
feat(util): add -reasoning suffix support for Gemini models
2025-11-30 20:55:38 +08:00
auroraflux
32d3809f8c **feat(util): add -reasoning suffix support for Gemini models**
Adds support for the `-reasoning` model name suffix which enables
thinking/reasoning mode with dynamic budget. This allows clients to
request reasoning-enabled inference using model names like
`gemini-2.5-flash-reasoning` without explicit configuration.

The suffix is normalized to the base model (e.g., gemini-2.5-flash)
with thinkingBudget=-1 (dynamic) and include_thoughts=true.

Follows the existing pattern established by -nothinking and
-thinking-N suffixes.
2025-11-30 01:18:57 -08:00
Luis Pater
a748e93fd9 **fix(executor, auth): ensure index assignment consistency for auth objects**
- Updated `usage_helpers.go` to call `EnsureIndex()` for proper index assignment in reporter initialization.
- Adjusted `auth/manager.go` to assign auth indices inside a locked section when they are unassigned, ensuring thread safety and consistency.
2025-11-30 16:56:29 +08:00
Luis Pater
54a9c4c3c7 Merge pull request #371 from ben-vargas/test-amp-tools
fix(amp): add /threads.rss root-level route for AMP CLI
2025-11-30 15:18:23 +08:00
Luis Pater
18b5c35dea Merge pull request #366 from router-for-me/blacklist
Add Model Blacklist
2025-11-30 15:17:46 +08:00
hkfires
7b7871ede2 feat(api): add oauth excluded model management 2025-11-30 13:38:23 +08:00
hkfires
c4e3646b75 docs(config): expand model exclusion examples 2025-11-30 11:55:47 +08:00
hkfires
022aa81be1 feat(cliproxy): support wildcard exclusions for models 2025-11-30 08:02:00 +08:00
hkfires
c43f0ea7b1 refactor(config): rename model blacklist fields to excluded models 2025-11-29 21:23:47 +08:00
hkfires
6a191358af fix(auth): fix runtime auth reload on oauth blacklist change 2025-11-29 20:30:11 +08:00
Ben Vargas
db1119dd78 fix(amp): add /threads.rss root-level route for AMP CLI
AMP CLI requests /threads.rss at the root level, but the AMP module
only registered routes under /api/*. This caused a 404 error during
AMP CLI startup.

Add the missing root-level route with the same security middleware
(noCORS, optional localhost restriction) as other management routes.
2025-11-29 05:01:19 -07:00
Trung Nguyen
33a5656235 docs: add model mapping documentation for Amp CLI integration
- Add model mapping feature to README.md Amp CLI section
- Add detailed Model Mapping Configuration section to amp-cli-integration.md
- Update architecture diagram to show model mapping flow
- Update Model Fallback Behavior to include mapping step
- Add Table of Contents entry for model mapping
2025-11-29 12:51:03 +07:00
Trung Nguyen
2cd59806e2 feat(amp): add model mapping support for routing unavailable models to alternatives
- Add AmpModelMapping config to route models like 'claude-opus-4.5' to 'claude-sonnet-4'
- Add ModelMapper interface and DefaultModelMapper implementation with hot-reload support
- Enhance FallbackHandler to apply model mappings before falling back to ampcode.com
- Add structured logging for routing decisions (local provider, mapping, amp credits)
- Update config.example.yaml with amp-model-mappings documentation
2025-11-29 12:44:09 +07:00
hkfires
5983e3ec87 feat(auth): add oauth provider model blacklist 2025-11-28 10:37:10 +08:00
hkfires
f8cebb9343 feat(config): add per-key model blacklist for providers 2025-11-27 21:57:07 +08:00
Luis Pater
72c7ef7647 **fix(translator): handle non-JSON output parsing for OpenAI function responses**
- Updated `antigravity_openai_request.go` to process non-JSON outputs gracefully by verifying and distinguishing between JSON and plain string formats.
- Ensured proper assignment of parsed or raw response to `functionResponse`.
2025-11-27 16:18:49 +08:00
Luis Pater
d2e4639b2a **feat(registry): add context length and update max tokens for Claude model configurations**
- Added `ContextLength` field with a value of 200,000 to all applicable Claude model definitions.
- Standardized `MaxCompletionTokens` values across models for consistency and alignment.
2025-11-27 16:13:25 +08:00
Luis Pater
08321223c4 Merge pull request #340 from nestharus/fix/339-thinking-openai-gemini-compat
fix(thinking): resolve OpenAI/Gemini compatibility for thinking model…
2025-11-27 16:03:24 +08:00
Luis Pater
7e30157590 Fixed: #354
**fix(translator): add support for "xhigh" reasoning effort in OpenAI responses**

- Updated handling in `openai_openai-responses_request.go` to include the new "xhigh" reasoning effort level.
2025-11-27 15:59:15 +08:00
nestharus
e73cdf5cff fix(claude): ensure max_tokens exceeds thinking budget for thinking models
Fixes an issue where Claude thinking models would return 400 errors when
the thinking.budget_tokens was greater than or equal to max_tokens.

Changes:
- Add MaxCompletionTokens: 128000 to all Claude thinking model definitions
- Add ensureMaxTokensForThinking() function in claude_executor.go that:
  - Checks if thinking is enabled with a budget_tokens value
  - Looks up the model's MaxCompletionTokens from the registry
  - Ensures max_tokens is set to at least the model's MaxCompletionTokens
  - Falls back to budget_tokens + 4000 buffer if registry lookup fails

This ensures Anthropic API constraint (max_tokens > thinking.budget_tokens)
is always satisfied when using extended thinking features.

Fixes: #339

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-26 22:31:05 -08:00
Luis Pater
39621a0340 **fix(translator): normalize function calls and outputs for consistent input processing**
- Implemented logic to pair consecutive function calls and their outputs, ensuring proper sequencing for processing.
- Adjusted `gemini_openai-responses_request.go` to normalize message structures and maintain expected flow.
2025-11-27 10:25:45 +08:00
Luis Pater
346b663079 **fix(translator): handle non-JSON output gracefully in function call outputs**
- Updated handling of `output` in `gemini_openai-responses_request.go` to use `.Str` instead of `.Raw` when parsing non-JSON string outputs.
- Added checks to distinguish between JSON and non-JSON `output` types for accurate `functionResponse` construction.
2025-11-27 09:40:00 +08:00
Luis Pater
0bcae68c6c **fix(translator): preserve raw JSON encoding in function call outputs**
- Updated handling of `output` in `gemini_openai-responses_request.go` to use `.Raw` instead of `.String` for preserving original JSON encoding.
- Ensured proper setting of raw JSON output when constructing `functionResponse`.
2025-11-27 08:26:53 +08:00
Luis Pater
c8cee547fd **fix(translator): ensure partial content is retained while skipping encrypted thoughtSignature**
- Updated handling of `thoughtSignature` across all translator modules to retain other content payloads if present.
- Adjusted logic for `thought_signature` and `inline_data` keys for consistent processing.
2025-11-27 00:52:17 +08:00
Luis Pater
36755421fe Merge pull request #343 from router-for-me/misc
style(amp): tidy whitespace in proxy module and tests
2025-11-26 19:03:07 +08:00
hkfires
6c17dbc4da style(amp): tidy whitespace in proxy module and tests 2025-11-26 18:57:26 +08:00
Luis Pater
ee6429cc75 **feat(registry): add Gemini 3 Pro Image Preview model and remove Claude Sonnet 4.5 Thinking**
- Added new `Gemini 3 Pro Image Preview` model with detailed metadata and configuration.
- Removed outdated `Claude Sonnet 4.5 Thinking` model definition for cleanup and relevance.
2025-11-26 18:22:40 +08:00
Luis Pater
a4a26d978e Fixed: #339
**feat(handlers, executor): add Gemini 3 Pro Preview support and refine Claude system instructions**

- Added support for the new "Gemini 3 Pro Preview" action in Gemini handlers, including detailed metadata and configuration.
- Removed redundant `cache_control` field from Claude system instructions for cleaner payload structure.
2025-11-26 11:42:57 +08:00
Luis Pater
ed9f6e897e Fixed: #337
**fix(executor): replace redundant commented code with `checkSystemInstructions` helper**

- Replaced commented-out `sjson.SetRawBytes` lines with the new `checkSystemInstructions` function.
- Centralized system instruction handling for better code clarity and reuse.
- Ensured consistent logic for managing `system` field across Claude executor flows.
2025-11-26 08:27:48 +08:00
Luis Pater
9c1e3c0687 Merge pull request #334 from nestharus/feat/claude-thinking-and-beta-headers
feat(claude): add thinking model variants and beta headers support
2025-11-26 02:17:02 +08:00
Luis Pater
2e5681ea32 Merge branch 'dev' into feat/claude-thinking-and-beta-headers 2025-11-26 02:16:40 +08:00
Luis Pater
52c17f03a5 **fix(executor): comment out redundant code for setting Claude system instructions**
- Commented out multiple instances of `sjson.SetRawBytes` for setting `system` key to Claude instructions as they are redundant.
- Code cleanup to improve clarity and maintainability without affecting functionality.
2025-11-26 02:06:16 +08:00
nestharus
d0e694d4ed feat(claude): add thinking model variants and beta headers support
- Add Claude thinking model definitions (sonnet-4-5-thinking, opus-4-5-thinking variants)
- Add Thinking support for antigravity models with -thinking suffix
- Add injectThinkingConfig() for automatic thinking budget based on model suffix
- Add resolveUpstreamModel() mappings for thinking variants to actual Claude models
- Add extractAndRemoveBetas() to convert betas array to anthropic-beta header
- Update applyClaudeHeaders() to merge custom betas from request body

Closes #324
2025-11-25 03:33:05 -08:00
Luis Pater
506f1117dd **fix(handlers): refactor API response capture to append data safely**
- Introduced `appendAPIResponse` helper to preserve and append data to existing API responses.
- Ensured newline inclusion when appending, if necessary.
- Improved `nil` and data type checks for response handling.
- Updated middleware to skip request logging for `GET` requests.
2025-11-25 11:37:02 +08:00
Luis Pater
113db3c5bf **fix(executor): update antigravity executor to enhance model metadata handling**
- Added additional metadata fields (`Name`, `Description`, `DisplayName`, `Version`) to `ModelInfo` struct initialization for better model representation.
- Removed unnecessary whitespace in the code.
2025-11-25 09:19:01 +08:00
Luis Pater
1aa0b6cd11 Merge pull request #322 from ben-vargas/feat-claude-opus-4-5
feat(registry): add Claude Opus 4.5 model definition
2025-11-25 08:38:06 +08:00
Ben Vargas
0895533400 fix(registry): correct Claude Opus 4.5 created timestamp
Update epoch from 1730419200 (2024-11-01) to 1761955200 (2025-11-01).
2025-11-24 12:27:23 -07:00
Ben Vargas
43f007c234 feat(registry): add Claude Opus 4.5 model definition
Add support for claude-opus-4-5-20251101 with 200K context window
and 64K max output tokens.
2025-11-24 12:26:39 -07:00
Luis Pater
0ceee56d99 Merge pull request #318 from router-for-me/log
feat(logs): add limit query param to cap returned logs
2025-11-24 20:35:28 +08:00
hkfires
943a8c74df feat(logs): add limit query param to cap returned logs 2025-11-24 19:59:24 +08:00
Luis Pater
0a47b452e9 **fix(translator): add conditional check for key renaming in Gemini tools**
- Ensured `functionDeclarations` key renaming only occurs if the key exists in Gemini tools processing.
- Prevented unnecessary JSON reassignment when the target key is absent.
2025-11-24 17:15:43 +08:00
Luis Pater
261f08a82a **fix(translator): adjust key renaming logic in Gemini request processing**
- Fixed parameter key renaming to correctly handle `functionDeclarations` and `parametersJsonSchema` in Gemini tools.
- Resolved potential overwriting issue by reassigning JSON strings after each key rename.
2025-11-24 17:12:04 +08:00
Luis Pater
d114d8d0bd **feat(config): add TLS support for HTTPS server configuration**
- Introduced `TLSConfig` to support HTTPS configurations, including enabling TLS, specifying certificate and key files.
- Updated HTTP server logic to handle HTTPS mode when TLS is enabled.
- Enhanced `config.example.yaml` with TLS settings example.
- Adjusted internal URL generation to respect protocol based on TLS state.
2025-11-24 10:41:29 +08:00
Luis Pater
bb9955e461 **fix(auth): resolve index reassignment issue during auth management**
- Fixed improper handling of `indexAssigned` and `Index` during auth reassignment.
- Ensured `EnsureIndex` is invoked after validating existing auth entries.
2025-11-24 10:10:09 +08:00
Luis Pater
7063a176f4 #293
**feat(retry): add configurable retry logic with cooldown support**

- Introduced `max-retry-interval` configuration for cooldown durations between retries.
- Added `SetRetryConfig` in `Manager` to handle retry attempts and cooldown intervals.
- Enhanced provider execution logic to include retry attempts, cooldown management, and dynamic wait periods.
- Updated API endpoints and YAML configuration to support `max-retry-interval`.
2025-11-24 09:55:15 +08:00
Luis Pater
e3082887a6 **feat(logging, middleware): add error-based logging support and error log management**
- Introduced `logOnErrorOnly` mode to enable logging only for error responses when request logging is disabled.
- Added endpoints to list and download error logs (`/request-error-logs`).
- Implemented error log file cleanup to retain only the newest 10 logs.
- Refactored `ResponseWriterWrapper` to support forced logging for error responses.
- Enhanced middleware to capture data for upstream error persistence.
- Improved log file naming and error log filename generation.
2025-11-23 22:41:57 +08:00
Luis Pater
ddb0c0ec1c **fix(translator): reintroduce thoughtSignature bypass logic for model parts**
- Restored `thoughtSignature` validator bypass for model-specific parts in Gemini content processing.
- Removed redundant logic from the `executor` for cleaner handling.
2025-11-23 20:52:23 +08:00
Luis Pater
d1736cb29c Merge pull request #315 from router-for-me/aistudio
fix(aistudio): strip Gemini generation config overrides
2025-11-23 20:25:59 +08:00
hkfires
62bfd62871 fix(aistudio): strip Gemini generation config overrides
Remove generationConfig.maxOutputTokens, generationConfig.responseMimeType and generationConfig.responseJsonSchema from the Gemini payload in translateRequest so we no longer send unsupported or conflicting response configuration fields. This lets the backend or caller control response formatting and output limits and helps prevent potential API errors caused by these keys.
2025-11-23 19:44:03 +08:00
Luis Pater
257621c5ed **chore(executor): update default agent version and simplify const formatting**
- Updated `defaultAntigravityAgent` to version `1.11.5`.
- Adjusted const value formatting for improved readability.

**feat(executor): introduce fallback mechanism for Antigravity base URLs**

- Added retry logic with fallback order for Antigravity base URLs to handle request errors and rate limits.
- Refactored base URL handling with `antigravityBaseURLFallbackOrder` and related utilities.
- Enhanced error handling in non-streaming and streaming requests with retry support and improved metadata reporting.
- Updated `buildRequest` to support dynamic base URL assignment.
2025-11-23 17:53:07 +08:00
Luis Pater
ac064389ca **feat(executor, translator): enhance token handling and payload processing**
- Improved Antigravity executor to handle `thinkingConfig` adjustments and default `thinkingBudget` when `thinkingLevel` is removed.
- Updated translator response handling to set default values for output token counts when specific token data is missing.
2025-11-23 11:32:37 +08:00
Luis Pater
8d23ffc873 **feat(executor): add model alias mapping and improve Antigravity payload handling**
- Introduced `modelName2Alias` and `alias2ModelName` functions for mapping between model names and aliases.
- Improved Antigravity payload transformation to include alias-to-model name conversion.
- Enhanced processing for Claude Sonnet models to adjust template parameters based on schema presence.
2025-11-23 03:16:14 +08:00
Luis Pater
4307f08bbc **feat(watcher): optimize auth file handling with hash-based change detection**
- Added `authFileUnchanged` to skip reloads for unchanged files based on SHA256 hash comparisons.
- Introduced `isKnownAuthFile` to verify known files before handling removal events.
- Improved event processing in `handleEvent` to reduce unnecessary reloads and enhance performance.
2025-11-23 01:22:16 +08:00
Luis Pater
9d50a68768 **feat(translator): improve content processing and Antigravity request conversion**
- Refactored response translation logic to support mixed content types (`input_text`, `output_text`, `input_image`) with better role assignments and part handling.
- Added image processing logic for embedding inline data with MIME type and base64 encoded content.
- Updated Antigravity request conversion to replace Gemini CLI references for consistency.
2025-11-22 21:34:34 +08:00
Luis Pater
7c3c24addc Merge pull request #306 from router-for-me/usage
fix some bugs
2025-11-22 17:45:49 +08:00
hkfires
166fa9e2e6 fix(gemini): parse stream usage from JSON, skip thoughtSignature 2025-11-22 16:07:12 +08:00
hkfires
88e566281e fix(gemini): filter SSE usage metadata in streams 2025-11-22 15:53:36 +08:00
hkfires
d32bb9db6b fix(runtime): treat non-empty finishReason as terminal 2025-11-22 15:39:46 +08:00
hkfires
8356b35320 fix(executor): expire stop chunks without usage metadata 2025-11-22 15:27:47 +08:00
hkfires
19a048879c feat(runtime): track antigravity usage and token counts 2025-11-22 14:04:28 +08:00
hkfires
1061354b2f fix: handle empty and non-JSON SSE chunks safely 2025-11-22 13:49:23 +08:00
hkfires
46b4110ff3 fix: preserve SSE usage metadata-only trailing chunks 2025-11-22 13:25:25 +08:00
hkfires
c29931e093 fix(translator): ignore empty JSON chunks in OpenAI responses 2025-11-22 13:09:16 +08:00
hkfires
b05cfd9f84 fix(translator): include empty text chunks in responses 2025-11-22 13:03:50 +08:00
hkfires
8ce22b8403 fix(sse): preserve usage metadata for stop chunks 2025-11-22 12:50:23 +08:00
Luis Pater
d1cdedc4d1 Merge pull request #303 from router-for-me/image
feat(translator): support image size and googleSearch tools
2025-11-22 11:20:58 +08:00
Luis Pater
d291eb9489 Fixed: #302
**feat(executor): enhance WebSocket error handling and metadata logging**

- Added handling for stream closure before start with appropriate error recording.
- Improved metadata logging for non-OK HTTP status codes in WebSocket responses.
- Consolidated event processing logic with `processEvent` for better error handling and payload management.
- Refactored stream initialization to include the first event handling for smoother execution flow.
2025-11-22 11:18:13 +08:00
hkfires
dc8d3201e1 feat(translator): support image size and googleSearch tools 2025-11-22 10:36:52 +08:00
Luis Pater
7757210af6 **feat(auth): implement Antigravity OAuth authentication flow**
- Added new endpoint `/antigravity-auth-url` to initiate Antigravity authentication.
- Implemented `RequestAntigravityToken` to manage the OAuth flow, including token exchange and user info retrieval.
- Introduced `.oauth-antigravity` temporary file handling for state and code management.
- Added `sanitizeAntigravityFileName` utility for safe token file names based on user email.
- Registered `/antigravity/callback` endpoint for OAuth redirects.
2025-11-22 01:45:06 +08:00
Luis Pater
cbf9a57135 **build(goreleaser): set CGO_ENABLED=0 for cli-proxy-api binaries**
- Disabled CGO to produce statically linked binaries.
- Minor formatting adjustment for newline at EOF.
2025-11-21 23:59:02 +08:00
Luis Pater
c1031e2d3f **feat(translator): add Antigravity translation logic**
- Introduced request and response translation functions to enable compatibility between OpenAI Chat Completions API and Antigravity.
- Registered translation utilities for both streaming and non-streaming scenarios.
- Added support for reasoning content, tool calls, and metadata handling.
- Established request normalization and embedding for Antigravity-compatible payloads.
- Added new fields to `Params` struct for better tracking of finish reasons, usage metadata, and tool usage.
- Refactored handling of response transitions, final events, and state-driven logic in `ConvertAntigravityResponseToClaude`.
- Introduced `appendFinalEvents` and `resolveStopReason` helper functions for cleaner separation of concerns.
- Added `TotalTokenCount` field to `Params` struct for enhanced token tracking.
- Updated token count calculations to fallback on `TotalTokenCount` when specific counts are missing.
- Introduced `hasNonZeroUsageMetadata` function to validate presence of token data in `usage_metadata`.
2025-11-21 23:40:59 +08:00
Luis Pater
327cc7039e **refactor(auth): use customizable HTTP client for Antigravity requests**
- Replaced `http.DefaultClient` with a configurable `http.Client` instance for Antigravity OAuth flow methods.
- Updated `exchangeAntigravityCode` and `fetchAntigravityUserInfo` to accept `httpClient` as a parameter.
- Added `util.SetProxy` usage to initialize the `httpClient` with proxy support.
2025-11-21 20:54:56 +08:00
Luis Pater
b4d15ace91 Merge pull request #296 from router-for-me/antigravity
Antigravity bugfix
2025-11-21 17:32:36 +08:00
hkfires
abc2465b29 fix(gemini-cli): ignore thoughtSignature and empty parts 2025-11-21 17:12:56 +08:00
hkfires
4ba5b43d82 feat(executor): share SSE usage filtering across streams 2025-11-21 16:51:05 +08:00
hkfires
27faf718a3 fix(auth): use fixed antigravity callback port 51121 2025-11-21 13:56:33 +08:00
69 changed files with 6068 additions and 614 deletions

5
.gitignore vendored
View File

@@ -15,6 +15,7 @@ pgstore/*
gitstore/*
objectstore/*
static/*
refs/*
# Authentication data
auths/*
@@ -30,3 +31,7 @@ GEMINI.md
.vscode/*
.claude/*
.serena/*
# macOS
.DS_Store
._*

View File

@@ -1,5 +1,7 @@
builds:
- id: "cli-proxy-api"
env:
- CGO_ENABLED=0
goos:
- linux
- windows
@@ -34,4 +36,4 @@ changelog:
filters:
exclude:
- '^docs:'
- '^test:'
- '^test:'

View File

@@ -56,6 +56,7 @@ CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and A
- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`)
- Management proxy for OAuth authentication and account features
- Smart model fallback with automatic routing
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5``claude-sonnet-4`)
- Security-first design with localhost-only management endpoints
**→ [Complete Amp CLI Integration Guide](docs/amp-cli-integration.md)**
@@ -90,6 +91,10 @@ Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with A
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.

View File

@@ -89,6 +89,10 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型Gemini, Codex, Antigravity无需 API 密钥。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR拉取请求将其添加到此列表中。

View File

@@ -1,6 +1,12 @@
# Server port
port: 8317
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
tls:
enable: false
cert: ""
key: ""
# Management API settings
remote-management:
# Whether to allow remote (non-localhost) management access.
@@ -38,6 +44,9 @@ proxy-url: ""
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
request-retry: 3
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
max-retry-interval: 30
# Quota exceeded behavior
quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
@@ -46,6 +55,28 @@ quota-exceeded:
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
# Amp CLI Integration
# Configure upstream URL for Amp CLI OAuth and management features
#amp-upstream-url: "https://ampcode.com"
# Optional: Override API key for Amp upstream (otherwise uses env or file)
#amp-upstream-api-key: ""
# Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
#amp-restrict-management-to-localhost: true
# Amp Model Mappings
# Route unavailable Amp models to alternative models available in your local proxy.
# Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
# but you have a similar model available (e.g., Claude Sonnet 4).
#amp-model-mappings:
# - from: "claude-opus-4.5" # Model requested by Amp CLI
# to: "claude-sonnet-4" # Route to this available model instead
# - from: "gpt-5"
# to: "gemini-2.5-pro"
# - from: "claude-3-opus-20240229"
# to: "claude-3-5-sonnet-20241022"
# Gemini API keys (preferred)
#gemini-api-key:
# - api-key: "AIzaSy...01"
@@ -53,6 +84,11 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# excluded-models:
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
# - api-key: "AIzaSy...02"
# API keys for official Generative Language API (legacy compatibility)
@@ -67,6 +103,11 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# excluded-models:
# - "gpt-5.1" # exclude specific models (exact match)
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
# - "*-mini" # wildcard matching suffix (e.g. gpt-5-codex-mini)
# - "*codex*" # wildcard matching substring (e.g. gpt-5-codex-low)
# Claude API keys
#claude-api-key:
@@ -79,6 +120,11 @@ ws-auth: false
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# excluded-models:
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
# - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
# OpenAI compatibility providers
#openai-compatibility:
@@ -99,6 +145,19 @@ ws-auth: false
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
#vertex-api-key:
# - api-key: "vk-123..." # x-goog-api-key header
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names
# - name: "gemini-2.0-flash" # upstream model name
# alias: "vertex-flash" # client-visible alias
# - name: "gemini-1.5-pro"
# alias: "vertex-pro"
#payload: # Optional payload configuration
# default: # Default rules only set parameters when they are missing in the payload.
# - models:
@@ -112,3 +171,25 @@ ws-auth: false
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
# params: # JSON path (gjson/sjson syntax) -> value
# "reasoning.effort": "high"
# OAuth provider excluded models
#oauth-excluded-models:
# gemini-cli:
# - "gemini-2.5-pro" # exclude specific models (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
# vertex:
# - "gemini-3-pro-preview"
# aistudio:
# - "gemini-3-pro-preview"
# antigravity:
# - "gemini-3-pro-preview"
# claude:
# - "claude-3-5-haiku-20241022"
# codex:
# - "gpt-5-codex-mini"
# qwen:
# - "vision-model"
# iflow:
# - "tstars2.0"

View File

@@ -19,6 +19,7 @@ services:
- "8085:8085"
- "1455:1455"
- "54545:54545"
- "51121:51121"
- "11451:11451"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml

View File

@@ -8,6 +8,7 @@ This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions,
- [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate)
- [Architecture](#architecture)
- [Configuration](#configuration)
- [Model Mapping Configuration](#model-mapping-configuration)
- [Setup](#setup)
- [Usage](#usage)
- [Troubleshooting](#troubleshooting)
@@ -21,6 +22,7 @@ The Amp CLI integration adds specialized routing to support Amp's API patterns w
- **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers
- **Management proxy**: Forwards OAuth and account management requests to Amp's control plane
- **Smart fallback**: Automatically routes unconfigured models to ampcode.com
- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5``claude-sonnet-4`)
- **Secret management**: Configurable precedence (config > env > file) with 5-minute caching
- **Security-first**: Management routes restricted to localhost by default
- **Automatic gzip handling**: Decompresses responses from Amp upstream
@@ -75,7 +77,10 @@ Amp CLI/IDE
│ ↓
│ ├─ Model configured locally?
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
│ │ NO → Forward to ampcode.com (reverse proxy)
│ │ NO
│ │ ├─ Model mapping configured?
│ │ │ YES → Rewrite model → Use local handler (free)
│ │ │ NO → Forward to ampcode.com (uses Amp credits)
│ ↓
│ Response
@@ -115,6 +120,49 @@ amp-upstream-url: "https://ampcode.com"
amp-restrict-management-to-localhost: true
```
### Model Mapping Configuration
When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally.
```yaml
# Route unavailable models to alternatives
amp-model-mappings:
# Example: Route Claude Opus 4.5 requests to Claude Sonnet 4
- from: "claude-opus-4.5"
to: "claude-sonnet-4"
# Example: Route GPT-5 requests to Gemini 2.5 Pro
- from: "gpt-5"
to: "gemini-2.5-pro"
# Example: Map older model names to newer versions
- from: "claude-3-opus-20240229"
to: "claude-3-5-sonnet-20241022"
```
**How it works:**
1. Amp CLI requests a model (e.g., `claude-opus-4.5`)
2. CLIProxyAPI checks if a local provider is available for that model
3. If not available, it checks the model mappings
4. If a mapping exists, the request is rewritten to use the target model
5. The request is then handled locally (free, using your OAuth subscription)
**Benefits:**
- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com
- **Hot-reload**: Mappings can be updated without restarting the proxy
- **Structured logging**: Clear logs show when mappings are applied
**Routing Decision Logs:**
The proxy logs each routing decision with structured fields:
```
[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free)
[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free)
[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits)
```
### Secret Resolution Precedence
The Amp module resolves API keys using this precedence order:
@@ -301,11 +349,14 @@ When Amp requests a model:
1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider?
2. **If YES**: Route to local handler (use your OAuth subscription)
3. **If NO**: Forward to ampcode.com (use Amp's default routing)
3. **If NO**: Check if a model mapping exists
4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free)
5. **If no mapping**: Forward to ampcode.com (uses Amp credits)
This enables seamless mixed usage:
- Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions
- Models you haven't configured → Amp's default providers
- Models with mappings configured → Routed to alternative local models (free)
- Models you haven't configured and have no mapping → Amp's default providers (uses credits)
### Example API Calls

View File

@@ -220,6 +220,14 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
log.Infof("callback forwarder on port %d stopped", port)
}
func sanitizeAntigravityFileName(email string) string {
if strings.TrimSpace(email) == "" {
return "antigravity.json"
}
replacer := strings.NewReplacer("@", "_", ".", "_")
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
}
func (h *Handler) managementCallbackURL(path string) (string, error) {
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
return "", fmt.Errorf("server port is not configured")
@@ -227,7 +235,11 @@ func (h *Handler) managementCallbackURL(path string) (string, error) {
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("http://127.0.0.1:%d%s", h.cfg.Port, path), nil
scheme := "http"
if h.cfg.TLS.Enable {
scheme = "https"
}
return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, h.cfg.Port, path), nil
}
func (h *Handler) ListAuthFiles(c *gin.Context) {
@@ -1284,6 +1296,222 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
const (
antigravityCallbackPort = 51121
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
)
var antigravityScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
ctx := context.Background()
fmt.Println("Initializing Antigravity authentication...")
state, errState := misc.GenerateRandomState()
if errState != nil {
log.Fatalf("Failed to generate state parameter: %v", errState)
return
}
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", antigravityClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(antigravityScopes, " "))
params.Set("state", state)
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
isWebUI := isWebUIRequest(c)
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute antigravity callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarder(antigravityCallbackPort)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var authCode string
for {
if time.Now().After(deadline) {
log.Error("oauth flow timed out")
oauthStatus[state] = "OAuth flow timed out"
return
}
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
var payload map[string]string
_ = json.Unmarshal(data, &payload)
_ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
log.Errorf("Authentication failed: %s", errStr)
oauthStatus[state] = "Authentication failed"
return
}
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
log.Errorf("Authentication failed: state mismatch")
oauthStatus[state] = "Authentication failed: state mismatch"
return
}
authCode = strings.TrimSpace(payload["code"])
if authCode == "" {
log.Error("Authentication failed: code not found")
oauthStatus[state] = "Authentication failed: code not found"
return
}
break
}
time.Sleep(500 * time.Millisecond)
}
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
form := url.Values{}
form.Set("code", authCode)
form.Set("client_id", antigravityClientID)
form.Set("client_secret", antigravityClientSecret)
form.Set("redirect_uri", redirectURI)
form.Set("grant_type", "authorization_code")
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
if errNewRequest != nil {
log.Errorf("Failed to build token request: %v", errNewRequest)
oauthStatus[state] = "Failed to build token request"
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.Errorf("Failed to execute token request: %v", errDo)
oauthStatus[state] = "Failed to exchange token"
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange close error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, _ := io.ReadAll(resp.Body)
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
return
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
log.Errorf("Failed to parse token response: %v", errDecode)
oauthStatus[state] = "Failed to parse token response"
return
}
email := ""
if strings.TrimSpace(tokenResp.AccessToken) != "" {
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errInfoReq != nil {
log.Errorf("Failed to build user info request: %v", errInfoReq)
oauthStatus[state] = "Failed to build user info request"
return
}
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
infoResp, errInfo := httpClient.Do(infoReq)
if errInfo != nil {
log.Errorf("Failed to execute user info request: %v", errInfo)
oauthStatus[state] = "Failed to execute user info request"
return
}
defer func() {
if errClose := infoResp.Body.Close(); errClose != nil {
log.Errorf("antigravity user info close error: %v", errClose)
}
}()
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
var infoPayload struct {
Email string `json:"email"`
}
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
email = strings.TrimSpace(infoPayload.Email)
}
} else {
bodyBytes, _ := io.ReadAll(infoResp.Body)
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
return
}
}
now := time.Now()
metadata := map[string]any{
"type": "antigravity",
"access_token": tokenResp.AccessToken,
"refresh_token": tokenResp.RefreshToken,
"expires_in": tokenResp.ExpiresIn,
"timestamp": now.UnixMilli(),
"expired": now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
}
if email != "" {
metadata["email"] = email
}
fileName := sanitizeAntigravityFileName(email)
label := strings.TrimSpace(email)
if label == "" {
label = "antigravity"
}
record := &coreauth.Auth{
ID: fileName,
Provider: "antigravity",
FileName: fileName,
Label: label,
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save token to file: %v", errSave)
oauthStatus[state] = "Failed to save token to file"
return
}
delete(oauthStatus, state)
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use Antigravity services through this CLI")
}()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background()

View File

@@ -172,6 +172,14 @@ func (h *Handler) PutRequestRetry(c *gin.Context) {
h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v })
}
// Max retry interval
func (h *Handler) GetMaxRetryInterval(c *gin.Context) {
c.JSON(200, gin.H{"max-retry-interval": h.cfg.MaxRetryInterval})
}
func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
}
// Proxy URL
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
func (h *Handler) PutProxyURL(c *gin.Context) {

View File

@@ -223,6 +223,7 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) {
value.APIKey = strings.TrimSpace(value.APIKey)
value.BaseURL = strings.TrimSpace(value.BaseURL)
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
if value.APIKey == "" {
// Treat empty API key as delete.
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
@@ -504,6 +505,91 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
c.JSON(400, gin.H{"error": "missing name or index"})
}
// oauth-excluded-models: map[string][]string
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
}
func (h *Handler) PutOAuthExcludedModels(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var entries map[string][]string
if err = json.Unmarshal(data, &entries); err != nil {
var wrapper struct {
Items map[string][]string `json:"items"`
}
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
entries = wrapper.Items
}
h.cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(entries)
h.persist(c)
}
func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) {
var body struct {
Provider *string `json:"provider"`
Models []string `json:"models"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Provider == nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
provider := strings.ToLower(strings.TrimSpace(*body.Provider))
if provider == "" {
c.JSON(400, gin.H{"error": "invalid provider"})
return
}
normalized := config.NormalizeExcludedModels(body.Models)
if len(normalized) == 0 {
if h.cfg.OAuthExcludedModels == nil {
c.JSON(404, gin.H{"error": "provider not found"})
return
}
if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok {
c.JSON(404, gin.H{"error": "provider not found"})
return
}
delete(h.cfg.OAuthExcludedModels, provider)
if len(h.cfg.OAuthExcludedModels) == 0 {
h.cfg.OAuthExcludedModels = nil
}
h.persist(c)
return
}
if h.cfg.OAuthExcludedModels == nil {
h.cfg.OAuthExcludedModels = make(map[string][]string)
}
h.cfg.OAuthExcludedModels[provider] = normalized
h.persist(c)
}
func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
provider := strings.ToLower(strings.TrimSpace(c.Query("provider")))
if provider == "" {
c.JSON(400, gin.H{"error": "missing provider"})
return
}
if h.cfg.OAuthExcludedModels == nil {
c.JSON(404, gin.H{"error": "provider not found"})
return
}
if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok {
c.JSON(404, gin.H{"error": "provider not found"})
return
}
delete(h.cfg.OAuthExcludedModels, provider)
if len(h.cfg.OAuthExcludedModels) == 0 {
h.cfg.OAuthExcludedModels = nil
}
h.persist(c)
}
// codex-api-key: []CodexKey
func (h *Handler) GetCodexKeys(c *gin.Context) {
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
@@ -533,6 +619,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if entry.BaseURL == "" {
continue
}
@@ -557,6 +644,7 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
value.BaseURL = strings.TrimSpace(value.BaseURL)
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
value.Headers = config.NormalizeHeaders(value.Headers)
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
// If base-url becomes empty, delete instead of update
if value.BaseURL == "" {
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
@@ -694,6 +782,7 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if len(entry.Models) == 0 {
return
}

View File

@@ -58,8 +58,14 @@ func (h *Handler) GetLogs(c *gin.Context) {
return
}
limit, errLimit := parseLimit(c.Query("limit"))
if errLimit != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid limit: %v", errLimit)})
return
}
cutoff := parseCutoff(c.Query("after"))
acc := newLogAccumulator(cutoff)
acc := newLogAccumulator(cutoff, limit)
for i := range files {
if errProcess := acc.consumeFile(files[i]); errProcess != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file %s: %v", files[i], errProcess)})
@@ -139,6 +145,126 @@ func (h *Handler) DeleteLogs(c *gin.Context) {
})
}
// GetRequestErrorLogs lists error request log files when RequestLog is disabled.
// It returns an empty list when RequestLog is enabled.
func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
if h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
return
}
if h.cfg.RequestLog {
c.JSON(http.StatusOK, gin.H{"files": []any{}})
return
}
dir := h.logDirectory()
if strings.TrimSpace(dir) == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
return
}
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusOK, gin.H{"files": []any{}})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list request error logs: %v", err)})
return
}
type errorLog struct {
Name string `json:"name"`
Size int64 `json:"size"`
Modified int64 `json:"modified"`
}
files := make([]errorLog, 0, len(entries))
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
continue
}
info, errInfo := entry.Info()
if errInfo != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log info for %s: %v", name, errInfo)})
return
}
files = append(files, errorLog{
Name: name,
Size: info.Size(),
Modified: info.ModTime().Unix(),
})
}
sort.Slice(files, func(i, j int) bool { return files[i].Modified > files[j].Modified })
c.JSON(http.StatusOK, gin.H{"files": files})
}
// DownloadRequestErrorLog downloads a specific error request log file by name.
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
if h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
return
}
dir := h.logDirectory()
if strings.TrimSpace(dir) == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
return
}
name := strings.TrimSpace(c.Param("name"))
if name == "" || strings.Contains(name, "/") || strings.Contains(name, "\\") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file name"})
return
}
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
return
}
dirAbs, errAbs := filepath.Abs(dir)
if errAbs != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
return
}
fullPath := filepath.Clean(filepath.Join(dirAbs, name))
prefix := dirAbs + string(os.PathSeparator)
if !strings.HasPrefix(fullPath, prefix) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
return
}
info, errStat := os.Stat(fullPath)
if errStat != nil {
if os.IsNotExist(errStat) {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
return
}
if info.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
return
}
c.FileAttachment(fullPath, name)
}
func (h *Handler) logDirectory() string {
if h == nil {
return ""
@@ -194,16 +320,22 @@ func (h *Handler) collectLogFiles(dir string) ([]string, error) {
type logAccumulator struct {
cutoff int64
limit int
lines []string
total int
latest int64
include bool
}
func newLogAccumulator(cutoff int64) *logAccumulator {
func newLogAccumulator(cutoff int64, limit int) *logAccumulator {
capacity := 256
if limit > 0 && limit < capacity {
capacity = limit
}
return &logAccumulator{
cutoff: cutoff,
lines: make([]string, 0, 256),
limit: limit,
lines: make([]string, 0, capacity),
}
}
@@ -215,7 +347,9 @@ func (acc *logAccumulator) consumeFile(path string) error {
}
return err
}
defer file.Close()
defer func() {
_ = file.Close()
}()
scanner := bufio.NewScanner(file)
buf := make([]byte, 0, logScannerInitialBuffer)
@@ -239,12 +373,19 @@ func (acc *logAccumulator) addLine(raw string) {
if ts > 0 {
acc.include = acc.cutoff == 0 || ts > acc.cutoff
if acc.cutoff == 0 || acc.include {
acc.lines = append(acc.lines, line)
acc.append(line)
}
return
}
if acc.cutoff == 0 || acc.include {
acc.lines = append(acc.lines, line)
acc.append(line)
}
}
func (acc *logAccumulator) append(line string) {
acc.lines = append(acc.lines, line)
if acc.limit > 0 && len(acc.lines) > acc.limit {
acc.lines = acc.lines[len(acc.lines)-acc.limit:]
}
}
@@ -267,6 +408,21 @@ func parseCutoff(raw string) int64 {
return ts
}
func parseLimit(raw string) (int, error) {
value := strings.TrimSpace(raw)
if value == "" {
return 0, nil
}
limit, err := strconv.Atoi(value)
if err != nil {
return 0, fmt.Errorf("must be a positive integer")
}
if limit <= 0 {
return 0, fmt.Errorf("must be greater than zero")
}
return limit, nil
}
func parseTimestamp(line string) int64 {
if strings.HasPrefix(line, "[") {
line = line[1:]

View File

@@ -6,6 +6,7 @@ package middleware
import (
"bytes"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
@@ -15,8 +16,8 @@ import (
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
// It captures detailed information about the request and response, including headers and body,
// and uses the provided RequestLogger to record this data. If logging is disabled in the
// logger, the middleware has minimal overhead.
// and uses the provided RequestLogger to record this data. When logging is disabled in the
// logger, it still captures data so that upstream errors can be persisted.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return func(c *gin.Context) {
if logger == nil {
@@ -24,14 +25,13 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
path := c.Request.URL.Path
if !shouldLogRequest(path) {
if c.Request.Method == http.MethodGet {
c.Next()
return
}
// Early return if logging is disabled (zero overhead)
if !logger.IsEnabled() {
path := c.Request.URL.Path
if !shouldLogRequest(path) {
c.Next()
return
}
@@ -47,6 +47,9 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
// Create response writer wrapper
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
if !logger.IsEnabled() {
wrapper.logOnErrorOnly = true
}
c.Writer = wrapper
// Process the request

View File

@@ -5,6 +5,7 @@ package middleware
import (
"bytes"
"net/http"
"strings"
"github.com/gin-gonic/gin"
@@ -24,15 +25,16 @@ type RequestInfo struct {
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
type ResponseWriterWrapper struct {
gin.ResponseWriter
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
logger logging.RequestLogger // logger is the instance of the request logger service.
requestInfo *RequestInfo // requestInfo holds the details of the original request.
statusCode int // statusCode stores the HTTP status code of the response.
headers map[string][]string // headers stores the response headers.
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
logger logging.RequestLogger // logger is the instance of the request logger service.
requestInfo *RequestInfo // requestInfo holds the details of the original request.
statusCode int // statusCode stores the HTTP status code of the response.
headers map[string][]string // headers stores the response headers.
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
}
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
@@ -192,12 +194,34 @@ func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) {
// For non-streaming responses, it logs the complete request and response details,
// including any API-specific request/response data stored in the Gin context.
func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
if !w.logger.IsEnabled() {
if w.logger == nil {
return nil
}
finalStatusCode := w.statusCode
if finalStatusCode == 0 {
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
finalStatusCode = statusWriter.Status()
} else {
finalStatusCode = 200
}
}
var slicesAPIResponseError []*interfaces.ErrorMessage
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
if isExist {
if apiErrors, ok := apiResponseError.([]*interfaces.ErrorMessage); ok {
slicesAPIResponseError = apiErrors
}
}
hasAPIError := len(slicesAPIResponseError) > 0 || finalStatusCode >= http.StatusBadRequest
forceLog := w.logOnErrorOnly && hasAPIError && !w.logger.IsEnabled()
if !w.logger.IsEnabled() && !forceLog {
return nil
}
if w.isStreaming {
// Close streaming channel and writer
if w.chunkChannel != nil {
close(w.chunkChannel)
w.chunkChannel = nil
@@ -209,80 +233,98 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
}
if w.streamWriter != nil {
err := w.streamWriter.Close()
if err := w.streamWriter.Close(); err != nil {
w.streamWriter = nil
return err
}
w.streamWriter = nil
return err
}
} else {
// Capture final status code and headers if not already captured
finalStatusCode := w.statusCode
if finalStatusCode == 0 {
// Get status from underlying ResponseWriter if available
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
finalStatusCode = statusWriter.Status()
} else {
finalStatusCode = 200 // Default
}
if forceLog {
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
}
return nil
}
// Ensure we have the latest headers before finalizing
w.ensureHeadersCaptured()
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
}
// Use the captured headers as the final headers
finalHeaders := make(map[string][]string)
for key, values := range w.headers {
// Make a copy of the values slice to avoid reference issues
headerValues := make([]string, len(values))
copy(headerValues, values)
finalHeaders[key] = headerValues
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
w.ensureHeadersCaptured()
var apiRequestBody []byte
apiRequest, isExist := c.Get("API_REQUEST")
if isExist {
var ok bool
apiRequestBody, ok = apiRequest.([]byte)
if !ok {
apiRequestBody = nil
}
}
finalHeaders := make(map[string][]string, len(w.headers))
for key, values := range w.headers {
headerValues := make([]string, len(values))
copy(headerValues, values)
finalHeaders[key] = headerValues
}
var apiResponseBody []byte
apiResponse, isExist := c.Get("API_RESPONSE")
if isExist {
var ok bool
apiResponseBody, ok = apiResponse.([]byte)
if !ok {
apiResponseBody = nil
}
}
return finalHeaders
}
var slicesAPIResponseError []*interfaces.ErrorMessage
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
if isExist {
var ok bool
slicesAPIResponseError, ok = apiResponseError.([]*interfaces.ErrorMessage)
if !ok {
slicesAPIResponseError = nil
}
}
func (w *ResponseWriterWrapper) extractAPIRequest(c *gin.Context) []byte {
apiRequest, isExist := c.Get("API_REQUEST")
if !isExist {
return nil
}
data, ok := apiRequest.([]byte)
if !ok || len(data) == 0 {
return nil
}
return data
}
// Log complete non-streaming response
return w.logger.LogRequest(
func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
apiResponse, isExist := c.Get("API_RESPONSE")
if !isExist {
return nil
}
data, ok := apiResponse.([]byte)
if !ok || len(data) == 0 {
return nil
}
return data
}
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
var requestBody []byte
if len(w.requestInfo.Body) > 0 {
requestBody = w.requestInfo.Body
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
w.requestInfo.Method,
w.requestInfo.Headers,
w.requestInfo.Body,
finalStatusCode,
finalHeaders,
w.body.Bytes(),
requestBody,
statusCode,
headers,
body,
apiRequestBody,
apiResponseBody,
slicesAPIResponseError,
apiResponseErrors,
forceLog,
)
}
return nil
return w.logger.LogRequest(
w.requestInfo.URL,
w.requestInfo.Method,
w.requestInfo.Headers,
requestBody,
statusCode,
headers,
body,
apiRequestBody,
apiResponseBody,
apiResponseErrors,
)
}
// Status returns the HTTP response status code captured by the wrapper.

View File

@@ -23,11 +23,13 @@ type Option func(*AmpModule)
// - Reverse proxy to Amp control plane for OAuth/management
// - Provider-specific route aliases (/api/provider/{provider}/...)
// - Automatic gzip decompression for misconfigured upstreams
// - Model mapping for routing unavailable models to alternatives
type AmpModule struct {
secretSource SecretSource
proxy *httputil.ReverseProxy
accessManager *sdkaccess.Manager
authMiddleware_ gin.HandlerFunc
modelMapper *DefaultModelMapper
enabled bool
registerOnce sync.Once
}
@@ -101,6 +103,9 @@ func (m *AmpModule) Register(ctx modules.Context) error {
// Use registerOnce to ensure routes are only registered once
var regErr error
m.registerOnce.Do(func() {
// Initialize model mapper from config (for routing unavailable models to alternatives)
m.modelMapper = NewModelMapper(ctx.Config.AmpModelMappings)
// Always register provider aliases - these work without an upstream
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
@@ -159,8 +164,16 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
// OnConfigUpdated handles configuration updates.
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
// Update model mappings (hot-reload supported)
if m.modelMapper != nil {
log.Infof("amp config updated: reloading %d model mapping(s)", len(cfg.AmpModelMappings))
m.modelMapper.UpdateMappings(cfg.AmpModelMappings)
} else {
log.Warnf("amp model mapper not initialized, skipping model mapping update")
}
if !m.enabled {
log.Debug("Amp routing not enabled, skipping config update")
log.Debug("Amp routing not enabled, skipping other config updates")
return nil
}
@@ -182,4 +195,7 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
return nil
}
// GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper
}

View File

@@ -6,16 +6,75 @@ import (
"io"
"net/http/httputil"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// AmpRouteType represents the type of routing decision made for an Amp request
type AmpRouteType string
const (
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free)
RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER"
// RouteTypeModelMapping indicates the request was remapped to another available model (free)
RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING"
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits)
RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS"
// RouteTypeNoProvider indicates no provider or fallback available
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
)
// logAmpRouting logs the routing decision for an Amp request with structured fields
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
fields := log.Fields{
"component": "amp-routing",
"route_type": string(routeType),
"requested_model": requestedModel,
"path": path,
"timestamp": time.Now().Format(time.RFC3339),
}
if resolvedModel != "" && resolvedModel != requestedModel {
fields["resolved_model"] = resolvedModel
}
if provider != "" {
fields["provider"] = provider
}
switch routeType {
case RouteTypeLocalProvider:
fields["cost"] = "free"
fields["source"] = "local_oauth"
log.WithFields(fields).Infof("[AMP] Using local provider for model: %s", requestedModel)
case RouteTypeModelMapping:
fields["cost"] = "free"
fields["source"] = "local_oauth"
fields["mapping"] = requestedModel + " -> " + resolvedModel
log.WithFields(fields).Infof("[AMP] Model mapped: %s -> %s", requestedModel, resolvedModel)
case RouteTypeAmpCredits:
fields["cost"] = "amp_credits"
fields["source"] = "ampcode.com"
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
log.WithFields(fields).Warnf("[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
case RouteTypeNoProvider:
fields["cost"] = "none"
fields["source"] = "error"
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
log.WithFields(fields).Warnf("[AMP] No provider available for model_id: %s", requestedModel)
}
}
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
}
// NewFallbackHandler creates a new fallback handler wrapper
@@ -26,10 +85,25 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
}
}
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler {
return &FallbackHandler{
getProxy: getProxy,
modelMapper: mapper,
}
}
// SetModelMapper sets the model mapper for this handler (allows late binding)
func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
fh.modelMapper = mapper
}
// WrapHandler wraps a gin.HandlerFunc with fallback logic
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
requestPath := c.Request.URL.Path
// Read the request body to extract the model name
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
@@ -55,12 +129,33 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Check if we have providers for this model
providers := util.GetProviderName(normalizedModel)
// Track resolved model for logging (may change if mapping is applied)
resolvedModel := normalizedModel
usedMapping := false
if len(providers) == 0 {
// No providers configured - check if we have a proxy for fallback
// No providers configured - check if we have a model mapping
if fh.modelMapper != nil {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
// Mapping found - rewrite the model in request body
bodyBytes = rewriteModelInBody(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
resolvedModel = mappedModel
usedMapping = true
// Get providers for the mapped model
providers = util.GetProviderName(mappedModel)
// Continue to handler with remapped model
goto handleRequest
}
}
// No mapping found - check if we have a proxy for fallback
proxy := fh.getProxy()
if proxy != nil {
// Fallback to ampcode.com
log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName)
// Log: Forwarding to ampcode.com (uses Amp credits)
logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath)
// Restore body again for the proxy
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -71,7 +166,23 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
}
// No proxy available, let the normal handler return the error
log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName)
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
}
handleRequest:
// Log the routing decision
providerName := ""
if len(providers) > 0 {
providerName = providers[0]
}
if usedMapping {
// Log: Model was mapped to another model
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
} else if len(providers) > 0 {
// Log: Using local provider (free)
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
}
// Providers available or no proxy for fallback, restore body and use normal handler
@@ -91,6 +202,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
}
}
// rewriteModelInBody replaces the model name in a JSON request body
func rewriteModelInBody(body []byte, newModel string) []byte {
var payload map[string]interface{}
if err := json.Unmarshal(body, &payload); err != nil {
log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err)
return body
}
if _, exists := payload["model"]; exists {
payload["model"] = newModel
newBody, err := json.Marshal(payload)
if err != nil {
log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err)
return body
}
return newBody
}
return body
}
// extractModelFromRequest attempts to extract the model name from various request formats
func extractModelFromRequest(body []byte, c *gin.Context) string {
// First try to parse from JSON body (OpenAI, Claude, etc.)

View File

@@ -0,0 +1,113 @@
// Package amp provides model mapping functionality for routing Amp CLI requests
// to alternative models when the requested model is not available locally.
package amp
import (
"strings"
"sync"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// ModelMapper provides model name mapping/aliasing for Amp CLI requests.
// When an Amp request comes in for a model that isn't available locally,
// this mapper can redirect it to an alternative model that IS available.
type ModelMapper interface {
// MapModel returns the target model name if a mapping exists and the target
// model has available providers. Returns empty string if no mapping applies.
MapModel(requestedModel string) string
// UpdateMappings refreshes the mapping configuration (for hot-reload).
UpdateMappings(mappings []config.AmpModelMapping)
}
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
type DefaultModelMapper struct {
mu sync.RWMutex
mappings map[string]string // from -> to (normalized lowercase keys)
}
// NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{
mappings: make(map[string]string),
}
m.UpdateMappings(mappings)
return m
}
// MapModel checks if a mapping exists for the requested model and if the
// target model has available local providers. Returns the mapped model name
// or empty string if no valid mapping exists.
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
if requestedModel == "" {
return ""
}
m.mu.RLock()
defer m.mu.RUnlock()
// Normalize the requested model for lookup
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
// Check for direct mapping
targetModel, exists := m.mappings[normalizedRequest]
if !exists {
return ""
}
// Verify target model has available providers
providers := util.GetProviderName(targetModel)
if len(providers) == 0 {
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
return ""
}
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel)
return targetModel
}
// UpdateMappings refreshes the mapping configuration from config.
// This is called during initialization and on config hot-reload.
func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
m.mu.Lock()
defer m.mu.Unlock()
// Clear and rebuild mappings
m.mappings = make(map[string]string, len(mappings))
for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if from == "" || to == "" {
log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to)
continue
}
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
}
if len(m.mappings) > 0 {
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
}
}
// GetMappings returns a copy of current mappings (for debugging/status).
func (m *DefaultModelMapper) GetMappings() map[string]string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]string, len(m.mappings))
for k, v := range m.mappings {
result[k] = v
}
return result
}

View File

@@ -0,0 +1,186 @@
package amp
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func TestNewModelMapper(t *testing.T) {
mappings := []config.AmpModelMapping{
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
{From: "gpt-5", To: "gemini-2.5-pro"},
}
mapper := NewModelMapper(mappings)
if mapper == nil {
t.Fatal("Expected non-nil mapper")
}
result := mapper.GetMappings()
if len(result) != 2 {
t.Errorf("Expected 2 mappings, got %d", len(result))
}
}
func TestNewModelMapper_Empty(t *testing.T) {
mapper := NewModelMapper(nil)
if mapper == nil {
t.Fatal("Expected non-nil mapper")
}
result := mapper.GetMappings()
if len(result) != 0 {
t.Errorf("Expected 0 mappings, got %d", len(result))
}
}
func TestModelMapper_MapModel_NoProvider(t *testing.T) {
mappings := []config.AmpModelMapping{
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
}
mapper := NewModelMapper(mappings)
// Without a registered provider for the target, mapping should return empty
result := mapper.MapModel("claude-opus-4.5")
if result != "" {
t.Errorf("Expected empty result when target has no provider, got %s", result)
}
}
func TestModelMapper_MapModel_WithProvider(t *testing.T) {
// Register a mock provider for the target model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client")
mappings := []config.AmpModelMapping{
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
}
mapper := NewModelMapper(mappings)
// With a registered provider, mapping should work
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client2")
mappings := []config.AmpModelMapping{
{From: "Claude-Opus-4.5", To: "claude-sonnet-4"},
}
mapper := NewModelMapper(mappings)
// Should match case-insensitively
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_MapModel_NotFound(t *testing.T) {
mappings := []config.AmpModelMapping{
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
}
mapper := NewModelMapper(mappings)
// Unknown model should return empty
result := mapper.MapModel("unknown-model")
if result != "" {
t.Errorf("Expected empty for unknown model, got %s", result)
}
}
func TestModelMapper_MapModel_EmptyInput(t *testing.T) {
mappings := []config.AmpModelMapping{
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("")
if result != "" {
t.Errorf("Expected empty for empty input, got %s", result)
}
}
func TestModelMapper_UpdateMappings(t *testing.T) {
mapper := NewModelMapper(nil)
// Initially empty
if len(mapper.GetMappings()) != 0 {
t.Error("Expected 0 initial mappings")
}
// Update with new mappings
mapper.UpdateMappings([]config.AmpModelMapping{
{From: "model-a", To: "model-b"},
{From: "model-c", To: "model-d"},
})
result := mapper.GetMappings()
if len(result) != 2 {
t.Errorf("Expected 2 mappings after update, got %d", len(result))
}
// Update again should replace, not append
mapper.UpdateMappings([]config.AmpModelMapping{
{From: "model-x", To: "model-y"},
})
result = mapper.GetMappings()
if len(result) != 1 {
t.Errorf("Expected 1 mapping after second update, got %d", len(result))
}
}
func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
mapper := NewModelMapper(nil)
mapper.UpdateMappings([]config.AmpModelMapping{
{From: "", To: "model-b"}, // Invalid: empty from
{From: "model-a", To: ""}, // Invalid: empty to
{From: " ", To: "model-b"}, // Invalid: whitespace from
{From: "model-c", To: "model-d"}, // Valid
})
result := mapper.GetMappings()
if len(result) != 1 {
t.Errorf("Expected 1 valid mapping, got %d", len(result))
}
}
func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
mappings := []config.AmpModelMapping{
{From: "model-a", To: "model-b"},
}
mapper := NewModelMapper(mappings)
// Get mappings and modify the returned map
result := mapper.GetMappings()
result["new-key"] = "new-value"
// Original should be unchanged
original := mapper.GetMappings()
if len(original) != 1 {
t.Errorf("Expected original to have 1 mapping, got %d", len(original))
}
if _, exists := original["new-key"]; exists {
t.Error("Original map was modified")
}
}

View File

@@ -83,7 +83,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// Peek at first 2 bytes to detect gzip magic bytes
header := make([]byte, 2)
n, _ := io.ReadFull(originalBody, header)
// Check for gzip magic bytes (0x1f 0x8b)
// If n < 2, we didn't get enough bytes, so it's not gzip
if n >= 2 && header[0] == 0x1f && header[1] == 0x8b {
@@ -97,7 +97,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
}
return nil
}
// Reconstruct complete gzipped data
gzippedData := append(header[:n], rest...)
@@ -129,8 +129,8 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
resp.ContentLength = int64(len(decompressed))
// Update headers to reflect decompressed state
resp.Header.Del("Content-Encoding") // No longer compressed
resp.Header.Del("Content-Length") // Remove stale compressed length
resp.Header.Del("Content-Encoding") // No longer compressed
resp.Header.Del("Content-Length") // Remove stale compressed length
resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length
log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed))

View File

@@ -440,52 +440,52 @@ func TestIsStreamingResponse(t *testing.T) {
func TestFilterBetaFeatures(t *testing.T) {
tests := []struct {
name string
header string
name string
header string
featureToRemove string
expected string
expected string
}{
{
name: "Remove context-1m from middle",
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
name: "Remove context-1m from middle",
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
},
{
name: "Remove context-1m from start",
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
name: "Remove context-1m from start",
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14",
expected: "fine-grained-tool-streaming-2025-05-14",
},
{
name: "Remove context-1m from end",
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
name: "Remove context-1m from end",
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14",
expected: "fine-grained-tool-streaming-2025-05-14",
},
{
name: "Feature not present",
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
name: "Feature not present",
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
},
{
name: "Only feature to remove",
header: "context-1m-2025-08-07",
name: "Only feature to remove",
header: "context-1m-2025-08-07",
featureToRemove: "context-1m-2025-08-07",
expected: "",
expected: "",
},
{
name: "Empty header",
header: "",
name: "Empty header",
header: "",
featureToRemove: "context-1m-2025-08-07",
expected: "",
expected: "",
},
{
name: "Header with spaces",
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
name: "Header with spaces",
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
},
}

View File

@@ -6,11 +6,11 @@ import (
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
@@ -111,6 +111,14 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
ampAPI.Any("/otel", proxyHandler)
ampAPI.Any("/otel/*path", proxyHandler)
// Root-level routes that AMP CLI expects without /api prefix
// These need the same security middleware as the /api/* routes
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()}
if restrictToLocalhost {
rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware())
}
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
// Google v1beta1 passthrough with OAuth fallback
// AMP CLI uses non-standard paths like /publishers/google/models/...
// We bridge these to our standard Gemini handler to enable local OAuth.
@@ -162,9 +170,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
// Uses lazy evaluation to access proxy (which is created after routes are registered)
fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy {
// Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.proxy
})
}, m.modelMapper)
// Provider-specific routes under /api/provider/:provider
ampProviders := engine.Group("/api/provider")

View File

@@ -37,6 +37,7 @@ func TestRegisterManagementRoutes(t *testing.T) {
{"/api/meta", http.MethodGet},
{"/api/telemetry", http.MethodGet},
{"/api/threads", http.MethodGet},
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
{"/api/otel", http.MethodGet},
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
{"/api/provider/google/v1beta1/models", http.MethodGet},

View File

@@ -150,6 +150,9 @@ type Server struct {
// management handler
mgmt *managementHandlers.Handler
// ampModule is the Amp routing module for model mapping hot-reload
ampModule *ampmodule.AmpModule
// managementRoutesRegistered tracks whether the management routes have been attached to the engine.
managementRoutesRegistered atomic.Bool
// managementRoutesEnabled controls whether management endpoints serve real handlers.
@@ -247,6 +250,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Save initial YAML snapshot
s.oldConfigYaml, _ = yaml.Marshal(cfg)
s.applyAccessConfig(nil, cfg)
if authManager != nil {
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
}
managementasset.SetCurrentConfig(cfg)
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
// Initialize management handler
@@ -265,14 +271,14 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
s.setupRoutes()
// Register Amp module using V2 interface with Context
ampModule := ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager))
s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager))
ctx := modules.Context{
Engine: engine,
BaseHandler: s.handlers,
Config: cfg,
AuthMiddleware: AuthMiddleware(accessManager),
}
if err := modules.RegisterModule(ctx, ampModule); err != nil {
if err := modules.RegisterModule(ctx, s.ampModule); err != nil {
log.Errorf("Failed to register Amp module: %v", err)
}
@@ -397,6 +403,18 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/antigravity/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if state != "" {
file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
}
@@ -497,6 +515,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/logs", s.mgmt.GetLogs)
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
@@ -507,6 +527,9 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval)
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
@@ -523,6 +546,11 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
@@ -532,6 +560,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
@@ -673,17 +702,33 @@ func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, cl
}
}
// Start begins listening for and serving HTTP requests.
// Start begins listening for and serving HTTP or HTTPS requests.
// It's a blocking call and will only return on an unrecoverable error.
//
// Returns:
// - error: An error if the server fails to start
func (s *Server) Start() error {
log.Debugf("Starting API server on %s", s.server.Addr)
if s == nil || s.server == nil {
return fmt.Errorf("failed to start HTTP server: server not initialized")
}
// Start the HTTP server.
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTP server: %v", err)
useTLS := s.cfg != nil && s.cfg.TLS.Enable
if useTLS {
cert := strings.TrimSpace(s.cfg.TLS.Cert)
key := strings.TrimSpace(s.cfg.TLS.Key)
if cert == "" || key == "" {
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
}
log.Debugf("Starting API server on %s with TLS", s.server.Addr)
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
}
return nil
}
log.Debugf("Starting API server on %s", s.server.Addr)
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTP server: %v", errServe)
}
return nil
@@ -801,6 +846,9 @@ func (s *Server) UpdateClients(cfg *config.Config) {
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
}
}
if s.handlers != nil && s.handlers.AuthManager != nil {
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
}
// Update log level dynamically when debug flag changes
if oldCfg == nil || oldCfg.Debug != cfg.Debug {
@@ -871,11 +919,22 @@ func (s *Server) UpdateClients(cfg *config.Config) {
s.mgmt.SetAuthManager(s.handlers.AuthManager)
}
// Notify Amp module of config changes (for model mapping hot-reload)
if s.ampModule != nil {
log.Debugf("triggering amp module config update")
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
log.Errorf("failed to update Amp module config: %v", err)
}
} else {
log.Warnf("amp module is nil, skipping config update")
}
// Count client sources from configuration and auth directory
authFiles := util.CountAuthFiles(cfg.AuthDir)
geminiAPIKeyCount := len(cfg.GeminiKey)
claudeAPIKeyCount := len(cfg.ClaudeKey)
codexAPIKeyCount := len(cfg.CodexKey)
vertexAICompatCount := len(cfg.VertexCompatAPIKey)
openAICompatCount := 0
for i := range cfg.OpenAICompatibility {
entry := cfg.OpenAICompatibility[i]
@@ -886,13 +945,14 @@ func (s *Server) UpdateClients(cfg *config.Config) {
openAICompatCount += len(entry.APIKeys)
}
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)\n",
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total,
authFiles,
geminiAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,
vertexAICompatCount,
openAICompatCount,
)
}

View File

@@ -23,6 +23,9 @@ type Config struct {
// Port is the network port on which the API server will listen.
Port int `yaml:"port" json:"-"`
// TLS config controls HTTPS server settings.
TLS TLSConfig `yaml:"tls" json:"tls"`
// AmpUpstreamURL defines the upstream Amp control plane used for non-provider calls.
AmpUpstreamURL string `yaml:"amp-upstream-url" json:"amp-upstream-url"`
@@ -34,6 +37,12 @@ type Config struct {
// browser attacks and remote access to management endpoints. Default: true (recommended).
AmpRestrictManagementToLocalhost bool `yaml:"amp-restrict-management-to-localhost" json:"amp-restrict-management-to-localhost"`
// AmpModelMappings defines model name mappings for Amp CLI requests.
// When Amp requests a model that isn't available locally, these mappings
// allow routing to an alternative model that IS available.
// Example: Map "claude-opus-4.5" -> "claude-sonnet-4" when opus isn't available.
AmpModelMappings []AmpModelMapping `yaml:"amp-model-mappings" json:"amp-model-mappings"`
// AuthDir is the directory where authentication token files are stored.
AuthDir string `yaml:"auth-dir" json:"-"`
@@ -61,8 +70,14 @@ type Config struct {
// GeminiKey defines Gemini API key configurations with optional routing overrides.
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
// VertexCompatAPIKey defines Vertex AI-compatible API key configurations for third-party providers.
// Used for services that use Vertex AI-style paths but with simple API key authentication.
VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"`
// RequestRetry defines the retry times when the request failed.
RequestRetry int `yaml:"request-retry" json:"request-retry"`
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
@@ -78,6 +93,19 @@ type Config struct {
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
Enable bool `yaml:"enable" json:"enable"`
// Cert is the path to the TLS certificate file.
Cert string `yaml:"cert" json:"cert"`
// Key is the path to the TLS private key file.
Key string `yaml:"key" json:"key"`
}
// RemoteManagement holds management API configuration under 'remote-management'.
@@ -100,6 +128,18 @@ type QuotaExceeded struct {
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available.
type AmpModelMapping struct {
// From is the model name that Amp CLI requests (e.g., "claude-opus-4.5").
From string `yaml:"from" json:"from"`
// To is the target model name to route to (e.g., "claude-sonnet-4").
// The target model must have available providers in the registry.
To string `yaml:"to" json:"to"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads.
type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload.
@@ -142,6 +182,9 @@ type ClaudeKey struct {
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
// ExcludedModels lists model IDs that should be excluded for this provider.
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
@@ -168,6 +211,9 @@ type CodexKey struct {
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
// ExcludedModels lists model IDs that should be excluded for this provider.
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// GeminiKey represents the configuration for a Gemini API key,
@@ -184,6 +230,9 @@ type GeminiKey struct {
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
// ExcludedModels lists model IDs that should be excluded for this provider.
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// OpenAICompatibility represents the configuration for OpenAI API compatibility
@@ -298,6 +347,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize Gemini API key configuration and migrate legacy entries.
cfg.SanitizeGeminiKeys()
// Sanitize Vertex-compatible API keys: drop entries without base-url
cfg.SanitizeVertexCompatKeys()
// Sanitize Codex keys: drop entries without base-url
cfg.SanitizeCodexKeys()
@@ -307,6 +359,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize OpenAI compatibility providers: drop entries without base-url
cfg.SanitizeOpenAICompatibility()
// Normalize OAuth provider model exclusion map.
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
// Return the populated configuration struct.
return &cfg, nil
}
@@ -344,6 +399,7 @@ func (cfg *Config) SanitizeCodexKeys() {
e := cfg.CodexKey[i]
e.BaseURL = strings.TrimSpace(e.BaseURL)
e.Headers = NormalizeHeaders(e.Headers)
e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels)
if e.BaseURL == "" {
continue
}
@@ -360,6 +416,7 @@ func (cfg *Config) SanitizeClaudeKeys() {
for i := range cfg.ClaudeKey {
entry := &cfg.ClaudeKey[i]
entry.Headers = NormalizeHeaders(entry.Headers)
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
}
}
@@ -380,6 +437,7 @@ func (cfg *Config) SanitizeGeminiKeys() {
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = NormalizeHeaders(entry.Headers)
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
if _, exists := seen[entry.APIKey]; exists {
continue
}
@@ -442,6 +500,55 @@ func NormalizeHeaders(headers map[string]string) map[string]string {
return clean
}
// NormalizeExcludedModels trims, lowercases, and deduplicates model exclusion patterns.
// It preserves the order of first occurrences and drops empty entries.
func NormalizeExcludedModels(models []string) []string {
if len(models) == 0 {
return nil
}
seen := make(map[string]struct{}, len(models))
out := make([]string, 0, len(models))
for _, raw := range models {
trimmed := strings.ToLower(strings.TrimSpace(raw))
if trimmed == "" {
continue
}
if _, exists := seen[trimmed]; exists {
continue
}
seen[trimmed] = struct{}{}
out = append(out, trimmed)
}
if len(out) == 0 {
return nil
}
return out
}
// NormalizeOAuthExcludedModels cleans provider -> excluded models mappings by normalizing provider keys
// and applying model exclusion normalization to each entry.
func NormalizeOAuthExcludedModels(entries map[string][]string) map[string][]string {
if len(entries) == 0 {
return nil
}
out := make(map[string][]string, len(entries))
for provider, models := range entries {
key := strings.ToLower(strings.TrimSpace(provider))
if key == "" {
continue
}
normalized := NormalizeExcludedModels(models)
if len(normalized) == 0 {
continue
}
out[key] = normalized
}
if len(out) == 0 {
return nil
}
return out
}
// hashSecret hashes the given secret using bcrypt.
func hashSecret(secret string) (string, error) {
// Use default cost for simplicity.
@@ -492,6 +599,7 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
// Remove deprecated auth block before merging to avoid persisting it again.
removeMapKey(original.Content[0], "auth")
removeLegacyOpenAICompatAPIKeys(original.Content[0])
pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-excluded-models")
// Merge generated into original in-place, preserving comments/order of existing nodes.
mergeMappingPreserve(original.Content[0], generated.Content[0])
@@ -690,6 +798,10 @@ func mergeNodePreserve(dst, src *yaml.Node) {
continue
}
mergeNodePreserve(dst.Content[i], src.Content[i])
if dst.Content[i] != nil && src.Content[i] != nil &&
dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
pruneMissingMapKeys(dst.Content[i], src.Content[i])
}
}
// Append any extra items from src
for i := len(dst.Content); i < len(src.Content); i++ {
@@ -731,6 +843,7 @@ func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
switch key {
case "generative-language-api-key",
"gemini-api-key",
"vertex-api-key",
"claude-api-key",
"codex-api-key",
"openai-compatibility":
@@ -967,6 +1080,73 @@ func removeLegacyOpenAICompatAPIKeys(root *yaml.Node) {
}
}
func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
if key == "" || dstRoot == nil || srcRoot == nil {
return
}
if dstRoot.Kind != yaml.MappingNode || srcRoot.Kind != yaml.MappingNode {
return
}
dstIdx := findMapKeyIndex(dstRoot, key)
if dstIdx < 0 || dstIdx+1 >= len(dstRoot.Content) {
return
}
srcIdx := findMapKeyIndex(srcRoot, key)
if srcIdx < 0 {
removeMapKey(dstRoot, key)
return
}
if srcIdx+1 >= len(srcRoot.Content) {
return
}
srcVal := srcRoot.Content[srcIdx+1]
dstVal := dstRoot.Content[dstIdx+1]
if srcVal == nil {
dstRoot.Content[dstIdx+1] = nil
return
}
if srcVal.Kind != yaml.MappingNode {
dstRoot.Content[dstIdx+1] = deepCopyNode(srcVal)
return
}
if dstVal == nil || dstVal.Kind != yaml.MappingNode {
dstRoot.Content[dstIdx+1] = deepCopyNode(srcVal)
return
}
pruneMissingMapKeys(dstVal, srcVal)
}
func pruneMissingMapKeys(dstMap, srcMap *yaml.Node) {
if dstMap == nil || srcMap == nil || dstMap.Kind != yaml.MappingNode || srcMap.Kind != yaml.MappingNode {
return
}
keep := make(map[string]struct{}, len(srcMap.Content)/2)
for i := 0; i+1 < len(srcMap.Content); i += 2 {
keyNode := srcMap.Content[i]
if keyNode == nil {
continue
}
key := strings.TrimSpace(keyNode.Value)
if key == "" {
continue
}
keep[key] = struct{}{}
}
for i := 0; i+1 < len(dstMap.Content); {
keyNode := dstMap.Content[i]
if keyNode == nil {
i += 2
continue
}
key := strings.TrimSpace(keyNode.Value)
if _, ok := keep[key]; !ok {
dstMap.Content = append(dstMap.Content[:i], dstMap.Content[i+2:]...)
continue
}
i += 2
}
}
// normalizeCollectionNodeStyles forces YAML collections to use block notation, keeping
// lists and maps readable. Empty sequences retain flow style ([]) so empty list markers
// remain compact.

View File

@@ -0,0 +1,84 @@
package config
import "strings"
// VertexCompatKey represents the configuration for Vertex AI-compatible API keys.
// This supports third-party services that use Vertex AI-style endpoint paths
// (/publishers/google/models/{model}:streamGenerateContent) but authenticate
// with simple API keys instead of Google Cloud service account credentials.
//
// Example services: zenmux.ai and similar Vertex-compatible providers.
type VertexCompatKey struct {
// APIKey is the authentication key for accessing the Vertex-compatible API.
// Maps to the x-goog-api-key header.
APIKey string `yaml:"api-key" json:"api-key"`
// BaseURL is the base URL for the Vertex-compatible API endpoint.
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
// ProxyURL optionally overrides the global proxy for this API key.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
// Commonly used for cookies, user-agent, and other authentication headers.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
// Models defines the model configurations including aliases for routing.
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
}
// VertexCompatModel represents a model configuration for Vertex compatibility,
// including the actual model name and its alias for API routing.
type VertexCompatModel struct {
// Name is the actual model name used by the external provider.
Name string `yaml:"name" json:"name"`
// Alias is the model name alias that clients will use to reference this model.
Alias string `yaml:"alias" json:"alias"`
}
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
func (cfg *Config) SanitizeVertexCompatKeys() {
if cfg == nil {
return
}
seen := make(map[string]struct{}, len(cfg.VertexCompatAPIKey))
out := cfg.VertexCompatAPIKey[:0]
for i := range cfg.VertexCompatAPIKey {
entry := cfg.VertexCompatAPIKey[i]
entry.APIKey = strings.TrimSpace(entry.APIKey)
if entry.APIKey == "" {
continue
}
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
if entry.BaseURL == "" {
// BaseURL is required for Vertex API key entries
continue
}
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = NormalizeHeaders(entry.Headers)
// Sanitize models: remove entries without valid alias
sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models))
for _, model := range entry.Models {
model.Alias = strings.TrimSpace(model.Alias)
model.Name = strings.TrimSpace(model.Name)
if model.Alias != "" && model.Name != "" {
sanitizedModels = append(sanitizedModels, model)
}
}
entry.Models = sanitizedModels
// Use API key + base URL as uniqueness key
uniqueKey := entry.APIKey + "|" + entry.BaseURL
if _, exists := seen[uniqueKey]; exists {
continue
}
seen[uniqueKey] = struct{}{}
out = append(out, entry)
}
cfg.VertexCompatAPIKey = out
}

View File

@@ -21,4 +21,7 @@ const (
// OpenaiResponse represents the OpenAI response format identifier.
OpenaiResponse = "openai-response"
// Antigravity represents the Antigravity response format identifier.
Antigravity = "antigravity"
)

View File

@@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
@@ -156,17 +157,30 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
// Returns:
// - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
if !l.enabled {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
}
// LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
}
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
if !l.enabled && !force {
return nil
}
// Ensure logs directory exists
if err := l.ensureLogsDir(); err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
if errEnsure := l.ensureLogsDir(); errEnsure != nil {
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
}
// Generate filename
filename := l.generateFilename(url)
if force && !l.enabled {
filename = l.generateErrorFilename(url)
}
filePath := filepath.Join(l.logsDir, filename)
// Decompress response if needed
@@ -184,6 +198,12 @@ func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[st
return fmt.Errorf("failed to write log file: %w", err)
}
if force && !l.enabled {
if errCleanup := l.cleanupOldErrorLogs(); errCleanup != nil {
log.WithError(errCleanup).Warn("failed to clean up old error logs")
}
}
return nil
}
@@ -239,6 +259,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
return writer, nil
}
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
func (l *FileRequestLogger) generateErrorFilename(url string) string {
return fmt.Sprintf("error-%s", l.generateFilename(url))
}
// ensureLogsDir creates the logs directory if it doesn't exist.
//
// Returns:
@@ -312,6 +337,52 @@ func (l *FileRequestLogger) sanitizeForFilename(path string) string {
return sanitized
}
// cleanupOldErrorLogs keeps only the newest 10 forced error log files.
func (l *FileRequestLogger) cleanupOldErrorLogs() error {
entries, errRead := os.ReadDir(l.logsDir)
if errRead != nil {
return errRead
}
type logFile struct {
name string
modTime time.Time
}
var files []logFile
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
continue
}
info, errInfo := entry.Info()
if errInfo != nil {
log.WithError(errInfo).Warn("failed to read error log info")
continue
}
files = append(files, logFile{name: name, modTime: info.ModTime()})
}
if len(files) <= 10 {
return nil
}
sort.Slice(files, func(i, j int) bool {
return files[i].modTime.After(files[j].modTime)
})
for _, file := range files[10:] {
if errRemove := os.Remove(filepath.Join(l.logsDir, file.name)); errRemove != nil {
log.WithError(errRemove).Warnf("failed to remove old error log: %s", file.name)
}
}
return nil
}
// formatLogContent creates the complete log content for non-streaming requests.
//
// Parameters:

View File

@@ -8,60 +8,140 @@ func GetClaudeModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "claude-haiku-4-5-20251001",
Object: "model",
Created: 1759276800, // 2025-10-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Haiku",
ID: "claude-haiku-4-5-20251001",
Object: "model",
Created: 1759276800, // 2025-10-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Haiku",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "claude-sonnet-4-5-20250929",
Object: "model",
Created: 1759104000, // 2025-09-29
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Sonnet",
ID: "claude-sonnet-4-5-20250929",
Object: "model",
Created: 1759104000, // 2025-09-29
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "claude-opus-4-1-20250805",
Object: "model",
Created: 1722945600, // 2025-08-05
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.1 Opus",
ID: "claude-sonnet-4-5-thinking",
Object: "model",
Created: 1759104000, // 2025-09-29
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Sonnet Thinking",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "claude-opus-4-20250514",
Object: "model",
Created: 1715644800, // 2025-05-14
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4 Opus",
ID: "claude-opus-4-5-thinking",
Object: "model",
Created: 1761955200, // 2025-11-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Opus Thinking",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "claude-sonnet-4-20250514",
Object: "model",
Created: 1715644800, // 2025-05-14
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4 Sonnet",
ID: "claude-opus-4-5-thinking-low",
Object: "model",
Created: 1761955200, // 2025-11-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Opus Thinking Low",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "claude-3-7-sonnet-20250219",
Object: "model",
Created: 1708300800, // 2025-02-19
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 3.7 Sonnet",
ID: "claude-opus-4-5-thinking-medium",
Object: "model",
Created: 1761955200, // 2025-11-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Opus Thinking Medium",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "claude-3-5-haiku-20241022",
Object: "model",
Created: 1729555200, // 2024-10-22
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 3.5 Haiku",
ID: "claude-opus-4-5-thinking-high",
Object: "model",
Created: 1761955200, // 2025-11-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Opus Thinking High",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "claude-opus-4-5-20251101",
Object: "model",
Created: 1761955200, // 2025-11-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Opus",
Description: "Premium model combining maximum intelligence with practical performance",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "claude-opus-4-1-20250805",
Object: "model",
Created: 1722945600, // 2025-08-05
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.1 Opus",
ContextLength: 200000,
MaxCompletionTokens: 32000,
},
{
ID: "claude-opus-4-20250514",
Object: "model",
Created: 1715644800, // 2025-05-14
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4 Opus",
ContextLength: 200000,
MaxCompletionTokens: 32000,
},
{
ID: "claude-sonnet-4-20250514",
Object: "model",
Created: 1715644800, // 2025-05-14
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "claude-3-7-sonnet-20250219",
Object: "model",
Created: 1708300800, // 2025-02-19
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 3.7 Sonnet",
ContextLength: 128000,
MaxCompletionTokens: 8192,
},
{
ID: "claude-3-5-haiku-20241022",
Object: "model",
Created: 1729555200, // 2024-10-22
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 3.5 Haiku",
ContextLength: 128000,
MaxCompletionTokens: 8192,
},
}
}
@@ -129,6 +209,21 @@ func GetGeminiModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
Created: 1737158400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-pro-image-preview",
Version: "3.0",
DisplayName: "Gemini 3 Pro Image Preview",
Description: "Gemini 3 Pro Image Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
}
}
@@ -207,6 +302,7 @@ func GetGeminiVertexModels() []*ModelInfo {
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
}
}
@@ -858,7 +954,6 @@ func GetIFlowModels() []*ModelInfo {
}{
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
{ID: "qwen3-coder", DisplayName: "Qwen3-Coder-480B-A35B", Description: "Qwen3 Coder 480B A35B", Created: 1753228800},
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},

View File

@@ -826,7 +826,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
}
}
// GetFirstAvailableModel returns the first available model for the given handler type.
// It prioritizes models by their creation timestamp (newest first) and checks if they have
// available clients that are not suspended or over quota.

View File

@@ -129,18 +129,60 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
firstEvent, ok := <-wsStream
if !ok {
err = fmt.Errorf("wsrelay: stream closed before start")
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK {
metadataLogged := false
if firstEvent.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
metadataLogged = true
}
var body bytes.Buffer
if len(firstEvent.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(firstEvent.Payload))
body.Write(firstEvent.Payload)
}
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
}
for event := range wsStream {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
if body.Len() == 0 {
body.WriteString(event.Err.Error())
}
break
}
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
body.Write(event.Payload)
}
if event.Type == wsrelay.MessageTypeStreamEnd {
break
}
}
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
go func(first wsrelay.StreamEvent) {
defer close(out)
var param any
metadataLogged := false
for event := range wsStream {
processEvent := func(event wsrelay.StreamEvent) bool {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return
return false
}
switch event.Type {
case wsrelay.MessageTypeStreamStart:
@@ -151,7 +193,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
filtered := filterAIStudioUsageMetadata(event.Payload)
filtered := FilterSSEUsageMetadata(event.Payload)
if detail, ok := parseGeminiStreamUsage(filtered); ok {
reporter.publish(ctx, detail)
}
@@ -162,7 +204,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
break
}
case wsrelay.MessageTypeStreamEnd:
return
return false
case wsrelay.MessageTypeHTTPResp:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
@@ -176,15 +218,24 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
}
reporter.publish(ctx, parseGeminiUsage(event.Payload))
return
return false
case wsrelay.MessageTypeError:
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false
}
return true
}
if !processEvent(first) {
return
}
for event := range wsStream {
if !processEvent(event) {
return
}
}
}()
}(firstEvent)
return stream, nil
}
@@ -257,17 +308,14 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
}
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
payload = util.ConvertThinkingLevelToBudget(payload)
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiImageAspectRatio(req.Model, payload)
payload = applyPayloadConfig(e.cfg, req.Model, payload)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
metadataAction := "generateContent"
if req.Metadata != nil {
if action, _ := req.Metadata["action"].(string); action == "countTokens" {
@@ -296,65 +344,6 @@ func (e *AIStudioExecutor) buildEndpoint(model, action, alt string) string {
return base
}
// filterAIStudioUsageMetadata removes usageMetadata from intermediate SSE events so that
// only the terminal chunk retains token statistics.
func filterAIStudioUsageMetadata(payload []byte) []byte {
if len(payload) == 0 {
return payload
}
lines := bytes.Split(payload, []byte("\n"))
modified := false
for idx, line := range lines {
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
dataIdx := bytes.Index(line, []byte("data:"))
if dataIdx < 0 {
continue
}
rawJSON := bytes.TrimSpace(line[dataIdx+5:])
cleaned, changed := stripUsageMetadataFromJSON(rawJSON)
if !changed {
continue
}
var rebuilt []byte
rebuilt = append(rebuilt, line[:dataIdx]...)
rebuilt = append(rebuilt, []byte("data:")...)
if len(cleaned) > 0 {
rebuilt = append(rebuilt, ' ')
rebuilt = append(rebuilt, cleaned...)
}
lines[idx] = rebuilt
modified = true
}
if !modified {
return payload
}
return bytes.Join(lines, []byte("\n"))
}
// stripUsageMetadataFromJSON drops usageMetadata when no finishReason is present.
func stripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
jsonBytes := bytes.TrimSpace(rawJSON)
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
return rawJSON, false
}
finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
if finishReason.Exists() && finishReason.String() != "" {
return rawJSON, false
}
if !gjson.GetBytes(jsonBytes, "usageMetadata").Exists() {
return rawJSON, false
}
cleaned, err := sjson.DeleteBytes(jsonBytes, "usageMetadata")
if err != nil {
return rawJSON, false
}
return cleaned, true
}
// ensureColonSpacedJSON normalizes JSON objects so that colons are followed by a single space while
// keeping the payload otherwise compact. Non-JSON inputs are returned unchanged.
func ensureColonSpacedJSON(payload []byte) []byte {

View File

@@ -26,16 +26,18 @@ import (
)
const (
antigravityBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.3 windows/amd64"
antigravityAuthType = "antigravity"
refreshSkew = 5 * time.Minute
streamScannerBuffer int = 20_971_520
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
streamScannerBuffer int = 20_971_520
)
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
@@ -70,45 +72,81 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt)
if errReq != nil {
return resp, errReq
}
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
var lastStatus int
var lastBody []byte
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
lastStatus = 0
lastBody = nil
lastErr = errDo
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errDo
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
return resp, err
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
return resp, err
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
}
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
switch {
case lastStatus != 0:
err = statusErr{code: lastStatus, msg: string(lastBody)}
case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
}
return resp, err
}
// ExecuteStream handles streaming requests via the antigravity upstream.
@@ -127,64 +165,126 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt)
if errReq != nil {
return nil, errReq
}
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
var lastStatus int
var lastBody []byte
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return nil, err
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
lastStatus = 0
lastBody = nil
lastErr = errDo
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errDo
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
lastStatus = 0
lastBody = nil
lastErr = errRead
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errRead
return nil, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
return nil, err
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), &param)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}()
return stream, nil
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), &param)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}(httpResp)
return stream, nil
}
switch {
case lastStatus != 0:
err = statusErr{code: lastStatus, msg: string(lastBody)}
case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
}
return nil, err
}
// Refresh refreshes the OAuth token using the refresh token.
@@ -215,54 +315,86 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
auth = updatedAuth
}
modelsURL := buildBaseURL(auth) + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
return nil
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
if host := resolveHost(auth); host != "" {
httpReq.Host = host
}
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
return nil
}
defer func() {
for idx, baseURL := range baseURLs {
modelsURL := baseURL + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
return nil
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
if host := resolveHost(baseURL); host != "" {
httpReq.Host = host
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
if errRead != nil {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
return nil
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
return nil
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
return nil
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
return nil
now := time.Now().Unix()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for id := range result.Map() {
id = modelName2Alias(id)
if id != "" {
modelInfo := &registry.ModelInfo{
ID: id,
Name: id,
Description: id,
DisplayName: id,
Version: id,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Add Thinking support for thinking models
if strings.HasSuffix(id, "-thinking") || strings.Contains(id, "-thinking-") {
modelInfo.Thinking = &registry.ThinkingSupport{
Min: 1024,
Max: 100000,
ZeroAllowed: false,
DynamicAllowed: true,
}
}
models = append(models, modelInfo)
}
}
return models
}
now := time.Now().Unix()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for id := range result.Map() {
models = append(models, &registry.ModelInfo{
ID: id,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
})
}
return models
return nil
}
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
@@ -348,12 +480,15 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
return auth, nil
}
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt string) (*http.Request, error) {
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
if token == "" {
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
}
base := buildBaseURL(auth)
base := strings.TrimSuffix(baseURL, "/")
if base == "" {
base = buildBaseURL(auth)
}
path := antigravityGeneratePath
if stream {
path = antigravityStreamPath
@@ -374,6 +509,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
}
payload = geminiToAntigravity(modelName, payload)
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
if errReq != nil {
return nil, errReq
@@ -386,7 +522,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
} else {
httpReq.Header.Set("Accept", "application/json")
}
if host := resolveHost(auth); host != "" {
if host := resolveHost(base); host != "" {
httpReq.Host = host
}
@@ -470,26 +606,13 @@ func int64Value(value any) (int64, bool) {
}
func buildBaseURL(auth *cliproxyauth.Auth) string {
if auth != nil {
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["base_url"]); v != "" {
return strings.TrimSuffix(v, "/")
}
}
if auth.Metadata != nil {
if v, ok := auth.Metadata["base_url"].(string); ok {
v = strings.TrimSpace(v)
if v != "" {
return strings.TrimSuffix(v, "/")
}
}
}
if baseURLs := antigravityBaseURLFallbackOrder(auth); len(baseURLs) > 0 {
return baseURLs[0]
}
return antigravityBaseURL
return antigravityBaseURLAutopush
}
func resolveHost(auth *cliproxyauth.Auth) string {
base := buildBaseURL(auth)
func resolveHost(base string) string {
parsed, errParse := url.Parse(base)
if errParse != nil {
return ""
@@ -516,6 +639,37 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
return defaultAntigravityAgent
}
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
return []string{base}
}
return []string{
antigravityBaseURLDaily,
antigravityBaseURLAutopush,
// antigravityBaseURLProd,
}
}
func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
if auth == nil {
return ""
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["base_url"]); v != "" {
return strings.TrimSuffix(v, "/")
}
}
if auth.Metadata != nil {
if v, ok := auth.Metadata["base_url"].(string); ok {
v = strings.TrimSpace(v)
if v != "" {
return strings.TrimSuffix(v, "/")
}
}
}
return ""
}
func geminiToAntigravity(modelName string, payload []byte) []byte {
template, _ := sjson.Set(string(payload), "model", modelName)
template, _ = sjson.Set(template, "userAgent", "antigravity")
@@ -525,18 +679,27 @@ func geminiToAntigravity(modelName string, payload []byte) []byte {
template, _ = sjson.Delete(template, "request.safetySettings")
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
if !strings.HasPrefix(modelName, "gemini-3-") {
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
template, _ = sjson.Set(template, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
}
gjson.Get(template, "request.contents").ForEach(func(key, content gjson.Result) bool {
if content.Get("role").String() == "model" {
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
template, _ = sjson.Set(template, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
if strings.HasPrefix(modelName, "claude-sonnet-") {
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
if funcDecl.Get("parametersJsonSchema").Exists() {
template, _ = sjson.SetRaw(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()), funcDecl.Get("parametersJsonSchema").Raw)
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters.$schema", key.Int(), funKey.Int()))
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parametersJsonSchema", key.Int(), funKey.Int()))
}
return true
})
}
return true
})
return true
})
}
return []byte(template)
}
@@ -558,3 +721,39 @@ func generateProjectID() string {
randomPart := strings.ToLower(uuid.NewString())[:5]
return adj + "-" + noun + "-" + randomPart
}
func modelName2Alias(modelName string) string {
switch modelName {
case "rev19-uic3-1p":
return "gemini-2.5-computer-use-preview-10-2025"
case "gemini-3-pro-image":
return "gemini-3-pro-image-preview"
case "gemini-3-pro-high":
return "gemini-3-pro-preview"
case "claude-sonnet-4-5":
return "gemini-claude-sonnet-4-5"
case "claude-sonnet-4-5-thinking":
return "gemini-claude-sonnet-4-5-thinking"
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
return ""
default:
return modelName
}
}
func alias2ModelName(modelName string) string {
switch modelName {
case "gemini-2.5-computer-use-preview-10-2025":
return "rev19-uic3-1p"
case "gemini-3-pro-image-preview":
return "gemini-3-pro-image"
case "gemini-3-pro-preview":
return "gemini-3-pro-high"
case "gemini-claude-sonnet-4-5":
return "claude-sonnet-4-5"
case "gemini-claude-sonnet-4-5-thinking":
return "claude-sonnet-4-5-thinking"
default:
return modelName
}
}

View File

@@ -17,6 +17,7 @@ import (
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -58,18 +59,27 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
body, _ = sjson.SetBytes(body, "model", modelOverride)
modelForUpstream = modelOverride
}
// Inject thinking config based on model suffix for thinking variants
body = e.injectThinkingConfig(req.Model, body)
if !strings.HasPrefix(modelForUpstream, "claude-3-5-haiku") {
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
body = checkSystemInstructions(body)
}
body = applyPayloadConfig(e.cfg, req.Model, body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return resp, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -154,15 +164,24 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
body, _ = sjson.SetBytes(body, "model", modelOverride)
}
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
// Inject thinking config based on model suffix for thinking variants
body = e.injectThinkingConfig(req.Model, body)
body = checkSystemInstructions(body)
body = applyPayloadConfig(e.cfg, req.Model, body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
}
applyClaudeHeaders(httpReq, auth, apiKey, true)
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -283,15 +302,19 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}
if !strings.HasPrefix(modelForUpstream, "claude-3-5-haiku") {
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
body = checkSystemInstructions(body)
}
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return cliproxyexecutor.Response{}, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -383,10 +406,101 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (
return auth, nil
}
// extractAndRemoveBetas extracts the "betas" array from the body and removes it.
// Returns the extracted betas as a string slice and the modified body.
func extractAndRemoveBetas(body []byte) ([]string, []byte) {
betasResult := gjson.GetBytes(body, "betas")
if !betasResult.Exists() {
return nil, body
}
var betas []string
if betasResult.IsArray() {
for _, item := range betasResult.Array() {
if s := strings.TrimSpace(item.String()); s != "" {
betas = append(betas, s)
}
}
} else if s := strings.TrimSpace(betasResult.String()); s != "" {
betas = append(betas, s)
}
body, _ = sjson.DeleteBytes(body, "betas")
return betas, body
}
// injectThinkingConfig adds thinking configuration based on model name suffix
func (e *ClaudeExecutor) injectThinkingConfig(modelName string, body []byte) []byte {
// Only inject if thinking config is not already present
if gjson.GetBytes(body, "thinking").Exists() {
return body
}
var budgetTokens int
switch {
case strings.HasSuffix(modelName, "-thinking-low"):
budgetTokens = 1024
case strings.HasSuffix(modelName, "-thinking-medium"):
budgetTokens = 8192
case strings.HasSuffix(modelName, "-thinking-high"):
budgetTokens = 24576
case strings.HasSuffix(modelName, "-thinking"):
// Default thinking without suffix uses medium budget
budgetTokens = 8192
default:
return body
}
body, _ = sjson.SetBytes(body, "thinking.type", "enabled")
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", budgetTokens)
return body
}
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
// Anthropic API requires this constraint; violating it returns a 400 error.
// This function should be called after all thinking configuration is finalized.
// It looks up the model's MaxCompletionTokens from the registry to use as the cap.
func ensureMaxTokensForThinking(modelName string, body []byte) []byte {
thinkingType := gjson.GetBytes(body, "thinking.type").String()
if thinkingType != "enabled" {
return body
}
budgetTokens := gjson.GetBytes(body, "thinking.budget_tokens").Int()
if budgetTokens <= 0 {
return body
}
maxTokens := gjson.GetBytes(body, "max_tokens").Int()
// Look up the model's max completion tokens from the registry
maxCompletionTokens := 0
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(modelName); modelInfo != nil {
maxCompletionTokens = modelInfo.MaxCompletionTokens
}
// Fall back to budget + buffer if registry lookup fails or returns 0
const fallbackBuffer = 4000
requiredMaxTokens := budgetTokens + fallbackBuffer
if maxCompletionTokens > 0 {
requiredMaxTokens = int64(maxCompletionTokens)
}
if maxTokens < requiredMaxTokens {
body, _ = sjson.SetBytes(body, "max_tokens", requiredMaxTokens)
}
return body
}
func (e *ClaudeExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
if alias == "" {
return ""
}
// Hardcoded mappings for thinking models to actual Claude model names
switch alias {
case "claude-opus-4-5-thinking", "claude-opus-4-5-thinking-low", "claude-opus-4-5-thinking-medium", "claude-opus-4-5-thinking-high":
return "claude-opus-4-5-20251101"
case "claude-sonnet-4-5-thinking":
return "claude-sonnet-4-5-20250929"
}
entry := e.resolveClaudeConfig(auth)
if entry == nil {
return ""
@@ -530,7 +644,7 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
return body, nil
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool) {
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
r.Header.Set("Authorization", "Bearer "+apiKey)
r.Header.Set("Content-Type", "application/json")
@@ -539,15 +653,30 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
ginHeaders = ginCtx.Request.Header
}
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
baseBetas = val
if !strings.Contains(val, "oauth") {
val += ",oauth-2025-04-20"
baseBetas += ",oauth-2025-04-20"
}
r.Header.Set("Anthropic-Beta", val)
} else {
r.Header.Set("Anthropic-Beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14")
}
// Merge extra betas from request body
if len(extraBetas) > 0 {
existingSet := make(map[string]bool)
for _, b := range strings.Split(baseBetas, ",") {
existingSet[strings.TrimSpace(b)] = true
}
for _, beta := range extraBetas {
beta = strings.TrimSpace(beta)
if beta != "" && !existingSet[beta] {
baseBetas += "," + beta
existingSet[beta] = true
}
}
}
r.Header.Set("Anthropic-Beta", baseBetas)
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
@@ -590,3 +719,22 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
}
return
}
func checkSystemInstructions(payload []byte) []byte {
system := gjson.GetBytes(payload, "system")
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
if system.IsArray() {
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
system.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
}
return true
})
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
}
} else {
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
}
return payload
}

View File

@@ -62,15 +62,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if hasOverride && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride)
}
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
@@ -204,15 +197,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if hasOverride && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride)
}
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
@@ -408,16 +394,9 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
var lastStatus int
var lastBody []byte
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
if hasOverride && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
payload = util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
}
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")

View File

@@ -79,13 +79,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = applyThinkingMetadata(body, req.Metadata, req.Model)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
@@ -174,13 +168,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = applyThinkingMetadata(body, req.Metadata, req.Model)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
@@ -256,10 +244,15 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
filtered := FilterSSEUsageMetadata(line)
payload := jsonPayload(filtered)
if len(payload) == 0 {
continue
}
if detail, ok := parseGeminiStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(payload), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
@@ -283,13 +276,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = applyThinkingMetadata(translatedReq, req.Metadata, req.Model)
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)

View File

@@ -51,11 +51,238 @@ func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.A
// Execute handles non-streaming requests.
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return resp, errCreds
// Try API key authentication first
apiKey, baseURL := vertexAPICreds(auth)
// If no API key found, fall back to service account authentication
if apiKey == "" {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return resp, errCreds
}
return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
}
// Use API key authentication
return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
}
// ExecuteStream handles SSE streaming for Vertex.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
// Try API key authentication first
apiKey, baseURL := vertexAPICreds(auth)
// If no API key found, fall back to service account authentication
if apiKey == "" {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return nil, errCreds
}
return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
}
// Use API key authentication
return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
}
// CountTokens calls Vertex countTokens endpoint.
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
// Try API key authentication first
apiKey, baseURL := vertexAPICreds(auth)
// If no API key found, fall back to service account authentication
if apiKey == "" {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return cliproxyexecutor.Response{}, errCreds
}
return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
}
// Use API key authentication
return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
}
// countTokensWithServiceAccount handles token counting using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
return cliproxyexecutor.Response{}, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translatedReq,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
return cliproxyexecutor.Response{}, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translatedReq,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}
// Refresh is a no-op for service account based credentials.
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
return auth, nil
}
// executeWithServiceAccount handles authentication using service account credentials.
// This method contains the original service account authentication logic.
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -149,13 +376,104 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
return resp, nil
}
// ExecuteStream handles SSE streaming for Vertex.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return nil, errCreds
// executeWithAPIKey handles authentication using API key credentials.
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
action := "generateContent"
if req.Metadata != nil {
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
action = "countTokens"
}
}
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if errNewReq != nil {
return resp, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -266,42 +584,44 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return stream, nil
}
// CountTokens calls Vertex countTokens endpoint.
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return cliproxyexecutor.Response{}, errCreds
}
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if errNewReq != nil {
return cliproxyexecutor.Response{}, errNewReq
return nil, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
@@ -315,7 +635,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translatedReq,
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
@@ -327,38 +647,53 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
return nil, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}
// Refresh is a no-op for service account based credentials.
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
return auth, nil
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 20_971_520)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
@@ -401,6 +736,23 @@ func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccou
return projectID, location, saJSON, nil
}
// vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern.
func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
if a == nil {
return "", ""
}
if a.Attributes != nil {
apiKey = a.Attributes["api_key"]
baseURL = a.Attributes["base_url"]
}
if apiKey == "" && a.Metadata != nil {
if v, ok := a.Metadata["access_token"].(string); ok {
apiKey = v
}
}
return
}
func vertexBaseURL(location string) string {
loc := strings.TrimSpace(location)
if loc == "" {

View File

@@ -4,10 +4,42 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// applyThinkingMetadata applies thinking config from model suffix metadata (e.g., -reasoning, -thinking-N)
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
func applyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(metadata)
if !ok {
return payload
}
if !util.ModelSupportsThinking(model) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
}
// applyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., -reasoning, -thinking-N)
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(metadata)
if !ok {
return payload
}
if budgetOverride != nil && util.ModelSupportsThinking(model) {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
}
// applyPayloadConfig applies payload default and override rules from configuration
// to the given JSON payload for the specified model.
// Defaults only fill missing fields, while overrides always overwrite existing values.

View File

@@ -12,6 +12,7 @@ import (
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
type usageReporter struct {
@@ -36,7 +37,7 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox
}
if auth != nil {
reporter.authID = auth.ID
reporter.authIndex = auth.Index
reporter.authIndex = auth.EnsureIndex()
}
return reporter
}
@@ -364,6 +365,204 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true
}
func parseAntigravityUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata")
if !node.Exists() {
node = usageNode.Get("usageMetadata")
}
if !node.Exists() {
node = usageNode.Get("usage_metadata")
}
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
}
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
}
node := gjson.GetBytes(payload, "response.usageMetadata")
if !node.Exists() {
node = gjson.GetBytes(payload, "usageMetadata")
}
if !node.Exists() {
node = gjson.GetBytes(payload, "usage_metadata")
}
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
}
var stopChunkWithoutUsage sync.Map
func rememberStopWithoutUsage(traceID string) {
stopChunkWithoutUsage.Store(traceID, struct{}{})
time.AfterFunc(10*time.Minute, func() { stopChunkWithoutUsage.Delete(traceID) })
}
// FilterSSEUsageMetadata removes usageMetadata from SSE events that are not
// terminal (finishReason != "stop"). Stop chunks are left untouched. This
// function is shared between aistudio and antigravity executors.
func FilterSSEUsageMetadata(payload []byte) []byte {
if len(payload) == 0 {
return payload
}
lines := bytes.Split(payload, []byte("\n"))
modified := false
foundData := false
for idx, line := range lines {
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
foundData = true
dataIdx := bytes.Index(line, []byte("data:"))
if dataIdx < 0 {
continue
}
rawJSON := bytes.TrimSpace(line[dataIdx+5:])
traceID := gjson.GetBytes(rawJSON, "traceId").String()
if isStopChunkWithoutUsage(rawJSON) && traceID != "" {
rememberStopWithoutUsage(traceID)
continue
}
if traceID != "" {
if _, ok := stopChunkWithoutUsage.Load(traceID); ok && hasUsageMetadata(rawJSON) {
stopChunkWithoutUsage.Delete(traceID)
continue
}
}
cleaned, changed := StripUsageMetadataFromJSON(rawJSON)
if !changed {
continue
}
var rebuilt []byte
rebuilt = append(rebuilt, line[:dataIdx]...)
rebuilt = append(rebuilt, []byte("data:")...)
if len(cleaned) > 0 {
rebuilt = append(rebuilt, ' ')
rebuilt = append(rebuilt, cleaned...)
}
lines[idx] = rebuilt
modified = true
}
if !modified {
if !foundData {
// Handle payloads that are raw JSON without SSE data: prefix.
trimmed := bytes.TrimSpace(payload)
cleaned, changed := StripUsageMetadataFromJSON(trimmed)
if !changed {
return payload
}
return cleaned
}
return payload
}
return bytes.Join(lines, []byte("\n"))
}
// StripUsageMetadataFromJSON drops usageMetadata unless finishReason is present (terminal).
// It handles both formats:
// - Aistudio: candidates.0.finishReason
// - Antigravity: response.candidates.0.finishReason
func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
jsonBytes := bytes.TrimSpace(rawJSON)
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
return rawJSON, false
}
// Check for finishReason in both aistudio and antigravity formats
finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
if !finishReason.Exists() {
finishReason = gjson.GetBytes(jsonBytes, "response.candidates.0.finishReason")
}
terminalReason := finishReason.Exists() && strings.TrimSpace(finishReason.String()) != ""
usageMetadata := gjson.GetBytes(jsonBytes, "usageMetadata")
if !usageMetadata.Exists() {
usageMetadata = gjson.GetBytes(jsonBytes, "response.usageMetadata")
}
// Terminal chunk: keep as-is.
if terminalReason {
return rawJSON, false
}
// Nothing to strip
if !usageMetadata.Exists() {
return rawJSON, false
}
// Remove usageMetadata from both possible locations
cleaned := jsonBytes
var changed bool
if gjson.GetBytes(cleaned, "usageMetadata").Exists() {
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
changed = true
}
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() {
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
changed = true
}
return cleaned, changed
}
func hasUsageMetadata(jsonBytes []byte) bool {
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
return false
}
if gjson.GetBytes(jsonBytes, "usageMetadata").Exists() {
return true
}
if gjson.GetBytes(jsonBytes, "response.usageMetadata").Exists() {
return true
}
return false
}
func isStopChunkWithoutUsage(jsonBytes []byte) bool {
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
return false
}
finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
if !finishReason.Exists() {
finishReason = gjson.GetBytes(jsonBytes, "response.candidates.0.finishReason")
}
trimmed := strings.TrimSpace(finishReason.String())
if !finishReason.Exists() || trimmed == "" {
return false
}
return !hasUsageMetadata(jsonBytes)
}
func jsonPayload(line []byte) []byte {
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 {

View File

@@ -0,0 +1,188 @@
// Package claude provides request translation functionality for Claude Code API compatibility.
// This package handles the conversion of Claude Code API requests into Gemini CLI-compatible
// JSON format, transforming message contents, system instructions, and tool declarations
// into the format expected by Gemini CLI API clients. It performs JSON data transformation
// to ensure compatibility between Claude Code API format and Gemini CLI API's expected format.
package claude
import (
"bytes"
"encoding/json"
"strings"
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
// The function performs the following transformations:
// 1. Extracts the model information from the request
// 2. Restructures the JSON to match Gemini CLI API format
// 3. Converts system instructions to the expected format
// 4. Maps message contents with proper role transformations
// 5. Handles tool declarations and tool choices
// 6. Maps generation configuration parameters
//
// Parameters:
// - modelName: The name of the model to use for the request
// - rawJSON: The raw JSON request data from the Claude Code API
// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
//
// Returns:
// - []byte: The transformed request data in Gemini CLI API format
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// system instruction
var systemInstruction *client.Content
systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
systemPart := client.Part{Text: systemPrompt}
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
}
}
if len(systemInstruction.Parts) == 0 {
systemInstruction = nil
}
}
// contents
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
role := roleResult.String()
if role == "assistant" {
role = "model"
}
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{Name: functionName, Args: args},
ThoughtSignature: geminiCLIClaudeThoughtSignature,
})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
}
}
}
contents = append(contents, clientContent)
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
}
}
}
// tools
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict")
tool, _ = sjson.Delete(tool, "input_examples")
var toolDeclaration any
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
if systemInstruction != nil {
b, _ := json.Marshal(systemInstruction)
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
}
if len(contents) > 0 {
b, _ := json.Marshal(contents)
out, _ = sjson.SetRaw(out, "request.contents", string(b))
}
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
b, _ := json.Marshal(tools)
out, _ = sjson.SetRaw(out, "request.tools", string(b))
}
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) {
if t.Get("type").String() == "enabled" {
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
budget = util.NormalizeThinkingBudget(modelName, budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
}
}
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
}
outBytes := []byte(out)
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
return outBytes
}

View File

@@ -0,0 +1,447 @@
// Package claude provides response translation functionality for Claude Code API compatibility.
// This package handles the conversion of backend client responses into Claude Code-compatible
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
// different response types including text content, thinking processes, and function calls.
// The translation ensures proper sequencing of SSE events and maintains state across
// multiple response chunks to provide a seamless streaming experience.
package claude
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Params holds parameters for response conversion and maintains state across streaming chunks.
// This structure tracks the current state of the response translation process to ensure
// proper sequencing of SSE events and transitions between different content types.
type Params struct {
HasFirstResponse bool // Indicates if the initial message_start event has been sent
ResponseType int // Current response type: 0=none, 1=content, 2=thinking, 3=function
ResponseIndex int // Index counter for content blocks in the streaming response
HasFinishReason bool // Tracks whether a finish reason has been observed
FinishReason string // The finish reason string returned by the provider
HasUsageMetadata bool // Tracks whether usage metadata has been observed
PromptTokenCount int64 // Cached prompt token count from usage metadata
CandidatesTokenCount int64 // Cached candidate token count from usage metadata
ThoughtsTokenCount int64 // Cached thinking token count from usage metadata
TotalTokenCount int64 // Cached total token count from usage metadata
HasSentFinalEvents bool // Indicates if final content/message events have been sent
HasToolUse bool // Indicates if tool use was observed in the stream
}
// ConvertAntigravityResponseToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
// and handles state transitions between content blocks, thinking processes, and function calls.
//
// Response type states: 0=none, 1=content, 2=thinking, 3=function
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Gemini CLI API
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &Params{
HasFirstResponse: false,
ResponseType: 0,
ResponseIndex: 0,
}
}
params := (*param).(*Params)
if bytes.Equal(rawJSON, []byte("[DONE]")) {
output := ""
appendFinalEvents(params, &output, true)
return []string{
output + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
}
}
output := ""
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk to establish the streaming session
if !params.HasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values according to Claude Code API specification
// This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
params.HasFirstResponse = true
}
// Process the response parts array from the backend client
// Each part can contain text content, thinking content, or function calls
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
// Extract the different types of content from each part
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
// Handle text content (both regular content and thinking)
if partTextResult.Exists() {
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
// Continue existing thinking block if already in thinking state
if params.ResponseType == 2 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to thinking
// First, close any existing content block
if params.ResponseType != 0 {
if params.ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 2 // Set state to thinking
}
} else {
// Process regular text content (user-visible output)
// Continue existing text block if already in content state
if params.ResponseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to text content
// First, close any existing content block
if params.ResponseType != 0 {
if params.ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 1 // Set state to content
}
}
} else if functionCallResult.Exists() {
// Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude Code API compatibility
params.HasToolUse = true
fcName := functionCallResult.Get("name").String()
// Handle state transitions when switching to function calls
// Close any existing function call block first
if params.ResponseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
params.ResponseType = 0
}
// Special handling for thinking state transition
if params.ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
// Close any other existing content block
if params.ResponseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude Code format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
}
params.ResponseType = 3
}
}
}
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
params.HasFinishReason = true
params.FinishReason = finishReasonResult.String()
}
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
params.HasUsageMetadata = true
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int()
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 {
params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount
if params.CandidatesTokenCount < 0 {
params.CandidatesTokenCount = 0
}
}
}
if params.HasUsageMetadata && params.HasFinishReason {
appendFinalEvents(params, &output, false)
}
return []string{output}
}
func appendFinalEvents(params *Params, output *string, force bool) {
if params.HasSentFinalEvents {
return
}
if !params.HasUsageMetadata && !force {
return
}
if params.ResponseType != 0 {
*output = *output + "event: content_block_stop\n"
*output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
*output = *output + "\n\n\n"
params.ResponseType = 0
}
stopReason := resolveStopReason(params)
usageOutputTokens := params.CandidatesTokenCount + params.ThoughtsTokenCount
if usageOutputTokens == 0 && params.TotalTokenCount > 0 {
usageOutputTokens = params.TotalTokenCount - params.PromptTokenCount
if usageOutputTokens < 0 {
usageOutputTokens = 0
}
}
*output = *output + "event: message_delta\n"
*output = *output + "data: "
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
*output = *output + delta + "\n\n\n"
params.HasSentFinalEvents = true
}
func resolveStopReason(params *Params) string {
if params.HasToolUse {
return "tool_use"
}
switch params.FinishReason {
case "MAX_TOKENS":
return "max_tokens"
case "STOP", "FINISH_REASON_UNSPECIFIED", "UNKNOWN":
return "end_turn"
}
return "end_turn"
}
// ConvertAntigravityResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response.
//
// Parameters:
// - ctx: The context for the request.
// - modelName: The name of the model.
// - rawJSON: The raw JSON response from the Gemini CLI API.
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - string: A Claude-compatible JSON response.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
_ = originalRequestRawJSON
_ = requestRawJSON
root := gjson.ParseBytes(rawJSON)
promptTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int()
thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int()
totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int()
outputTokens := candidateTokens + thoughtTokens
if outputTokens == 0 && totalTokens > 0 {
outputTokens = totalTokens - promptTokens
if outputTokens < 0 {
outputTokens = 0
}
}
response := map[string]interface{}{
"id": root.Get("response.responseId").String(),
"type": "message",
"role": "assistant",
"model": root.Get("response.modelVersion").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": promptTokens,
"output_tokens": outputTokens,
},
}
parts := root.Get("response.candidates.0.content.parts")
var contentBlocks []interface{}
textBuilder := strings.Builder{}
thinkingBuilder := strings.Builder{}
toolIDCounter := 0
hasToolCall := false
flushText := func() {
if textBuilder.Len() == 0 {
return
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": textBuilder.String(),
})
textBuilder.Reset()
}
flushThinking := func() {
if thinkingBuilder.Len() == 0 {
return
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "thinking",
"thinking": thinkingBuilder.String(),
})
thinkingBuilder.Reset()
}
if parts.IsArray() {
for _, part := range parts.Array() {
if text := part.Get("text"); text.Exists() && text.String() != "" {
if part.Get("thought").Bool() {
flushText()
thinkingBuilder.WriteString(text.String())
continue
}
flushThinking()
textBuilder.WriteString(text.String())
continue
}
if functionCall := part.Get("functionCall"); functionCall.Exists() {
flushThinking()
flushText()
hasToolCall = true
name := functionCall.Get("name").String()
toolIDCounter++
toolBlock := map[string]interface{}{
"type": "tool_use",
"id": fmt.Sprintf("tool_%d", toolIDCounter),
"name": name,
"input": map[string]interface{}{},
}
if args := functionCall.Get("args"); args.Exists() {
var parsed interface{}
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
toolBlock["input"] = parsed
}
}
contentBlocks = append(contentBlocks, toolBlock)
continue
}
}
}
flushThinking()
flushText()
response["content"] = contentBlocks
stopReason := "end_turn"
if hasToolCall {
stopReason = "tool_use"
} else {
if finish := root.Get("response.candidates.0.finishReason"); finish.Exists() {
switch finish.String() {
case "MAX_TOKENS":
stopReason = "max_tokens"
case "STOP", "FINISH_REASON_UNSPECIFIED", "UNKNOWN":
stopReason = "end_turn"
default:
stopReason = "end_turn"
}
}
}
response["stop_reason"] = stopReason
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) {
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
delete(response, "usage")
}
}
encoded, err := json.Marshal(response)
if err != nil {
return ""
}
return string(encoded)
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
}

View File

@@ -0,0 +1,20 @@
package claude
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Claude,
Antigravity,
ConvertClaudeRequestToAntigravity,
interfaces.TranslateResponse{
Stream: ConvertAntigravityResponseToClaude,
NonStream: ConvertAntigravityResponseToClaudeNonStream,
TokenCount: ClaudeTokenCount,
},
)
}

View File

@@ -0,0 +1,293 @@
// Package gemini provides request translation functionality for Gemini CLI to Gemini API compatibility.
// It handles parsing and transforming Gemini CLI API requests into Gemini API format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between Gemini CLI API format and Gemini API's expected format.
package gemini
import (
"bytes"
"encoding/json"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiRequestToAntigravity parses and transforms a Gemini CLI API request into Gemini API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Gemini API.
// The function performs the following transformations:
// 1. Extracts the model information from the request
// 2. Restructures the JSON to match Gemini API format
// 3. Converts system instructions to the expected format
// 4. Fixes CLI tool response format and grouping
//
// Parameters:
// - modelName: The name of the model to use for the request (unused in current implementation)
// - rawJSON: The raw JSON request data from the Gemini CLI API
// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
//
// Returns:
// - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
template := ""
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template)
if errFixCLIToolResponse != nil {
return []byte{}
}
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJSON = []byte(template)
// Normalize roles in request.contents: default to valid values if missing/invalid
contents := gjson.GetBytes(rawJSON, "request.contents")
if contents.Exists() {
prevRole := ""
idx := 0
contents.ForEach(func(_ gjson.Result, value gjson.Result) bool {
role := value.Get("role").String()
valid := role == "user" || role == "model"
if role == "" || !valid {
var newRole string
if prevRole == "" {
newRole = "user"
} else if prevRole == "user" {
newRole = "model"
} else {
newRole = "user"
}
path := fmt.Sprintf("request.contents.%d.role", idx)
rawJSON, _ = sjson.SetBytes(rawJSON, path, newRole)
role = newRole
}
prevRole = role
idx++
return true
})
}
toolsResult := gjson.GetBytes(rawJSON, "request.tools")
if toolsResult.Exists() && toolsResult.IsArray() {
toolResults := toolsResult.Array()
for i := 0; i < len(toolResults); i++ {
functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations", i))
if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() {
functionDeclarationsResults := functionDeclarationsResult.Array()
for j := 0; j < len(functionDeclarationsResults); j++ {
parametersResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j))
if parametersResult.Exists() {
strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j), fmt.Sprintf("request.tools.%d.function_declarations.%d.parametersJsonSchema", i, j))
rawJSON = []byte(strJson)
}
}
}
}
}
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
if content.Get("role").String() == "model" {
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
} else if part.Get("thoughtSignature").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
}
return true
})
}
return true
})
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
}
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
// This function transforms the CLI tool response format by intelligently grouping function calls
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
// and their responses are properly associated and structured.
//
// Parameters:
// - input: The input JSON string to be processed
//
// Returns:
// - string: The processed JSON string with grouped function calls and responses
// - error: An error if the processing fails
func fixCLIToolResponse(input string) (string, error) {
// Parse the input JSON to extract the conversation structure
parsed := gjson.Parse(input)
// Extract the contents array which contains the conversation messages
contents := parsed.Get("request.contents")
if !contents.Exists() {
// log.Debugf(input)
return input, fmt.Errorf("contents not found in input")
}
// Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
// Process each content object in the conversation
// This iterates through messages and groups function calls with their responses
contents.ForEach(func(key, value gjson.Result) bool {
role := value.Get("role").String()
parts := value.Get("parts")
// Check if this content has function responses
var responsePartsInThisContent []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
responsePartsInThisContent = append(responsePartsInThisContent, part)
}
return true
})
// If this content has function responses, collect them
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
}
}
return true // Skip adding this content, responses are merged
}
// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
}
return true
})
if len(functionCallsInThisModel) > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
return true
})
// Handle any remaining pending groups with remaining responses
for _, group := range pendingGroups {
if len(collectedResponses) >= group.ResponsesNeeded {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil
}

View File

@@ -0,0 +1,86 @@
// Package gemini provides request translation functionality for Gemini to Gemini CLI API compatibility.
// It handles parsing and transforming Gemini API requests into Gemini CLI API format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between Gemini API format and Gemini CLI API's expected format.
package gemini
import (
"bytes"
"context"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertAntigravityResponseToGemini parses and transforms a Gemini CLI API request into Gemini API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Gemini API.
// The function performs the following transformations:
// 1. Extracts the response data from the request
// 2. Handles alternative response formats
// 3. Processes array responses by extracting individual response objects
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model to use for the request (unused in current implementation)
// - rawJSON: The raw JSON request data from the Gemini CLI API
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - []string: The transformed request data in Gemini API format
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
}
if alt, ok := ctx.Value("alt").(string); ok {
var chunk []byte
if alt == "" {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
chunk = []byte(responseResult.Raw)
}
} else {
chunkTemplate := "[]"
responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() {
responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
}
}
}
chunk = []byte(chunkTemplate)
}
return []string{string(chunk)}
}
return []string{}
}
// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
// This function processes the complete Gemini CLI request and transforms it into a single Gemini-compatible
// JSON response. It extracts the response data from the request and returns it in the expected format.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON request data from the Gemini CLI API
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing the response data
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
return responseResult.Raw
}
return string(rawJSON)
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}

View File

@@ -0,0 +1,20 @@
package gemini
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Gemini,
Antigravity,
ConvertGeminiRequestToAntigravity,
interfaces.TranslateResponse{
Stream: ConvertAntigravityResponseToGemini,
NonStream: ConvertAntigravityResponseToGeminiNonStream,
TokenCount: GeminiTokenCount,
},
)
}

View File

@@ -0,0 +1,386 @@
// Package openai provides request translation functionality for OpenAI to Gemini CLI API compatibility.
// It converts OpenAI Chat Completions requests into Gemini CLI compatible JSON using gjson/sjson only.
package chat_completions
import (
"bytes"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
// ConvertOpenAIRequestToAntigravity converts an OpenAI Chat Completions request (raw JSON)
// into a complete Gemini CLI request JSON. All JSON construction uses sjson and lookups use gjson.
//
// Parameters:
// - modelName: The name of the model to use for the request
// - rawJSON: The raw JSON request data from the OpenAI API
// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
//
// Returns:
// - []byte: The transformed request data in Gemini CLI API format
func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
// Base envelope (no default thinkingConfig)
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
// Model
out, _ = sjson.SetBytes(out, "model", modelName)
// Reasoning effort -> thinkingBudget/include_thoughts
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
re := gjson.GetBytes(rawJSON, "reasoning_effort")
hasOfficialThinking := re.Exists()
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
switch re.String() {
case "none":
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
case "auto":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "low":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "medium":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "high":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
default:
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
}
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
var setBudget bool
var normalized int
if v := tc.Get("thinkingBudget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
setBudget = true
} else if v := tc.Get("thinking_budget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
setBudget = true
}
if v := tc.Get("includeThoughts"); v.Exists() {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if v := tc.Get("include_thoughts"); v.Exists() {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if setBudget && normalized != 0 {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
}
}
// For gemini-3-pro-preview, always send default thinkingConfig when none specified.
// This matches the official Gemini CLI behavior which always sends:
// { thinkingBudget: -1, includeThoughts: true }
// See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
// Temperature/top_p/top_k
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
}
if tpr := gjson.GetBytes(rawJSON, "top_p"); tpr.Exists() && tpr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.topP", tpr.Num)
}
if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
}
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
var responseMods []string
for _, m := range mods.Array() {
switch strings.ToLower(m.String()) {
case "text":
responseMods = append(responseMods, "TEXT")
case "image":
responseMods = append(responseMods, "IMAGE")
}
}
if len(responseMods) > 0 {
out, _ = sjson.SetBytes(out, "request.generationConfig.responseModalities", responseMods)
}
}
// OpenRouter-style image_config support
// If the input uses top-level image_config.aspect_ratio, map it into request.generationConfig.imageConfig.aspectRatio.
if imgCfg := gjson.GetBytes(rawJSON, "image_config"); imgCfg.Exists() && imgCfg.IsObject() {
if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.aspectRatio", ar.Str)
}
if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.imageSize", size.Str)
}
}
// messages -> systemInstruction + contents
messages := gjson.GetBytes(rawJSON, "messages")
if messages.IsArray() {
arr := messages.Array()
// First pass: assistant tool_calls id->name map
tcID2Name := map[string]string{}
for i := 0; i < len(arr); i++ {
m := arr[i]
if m.Get("role").String() == "assistant" {
tcs := m.Get("tool_calls")
if tcs.IsArray() {
for _, tc := range tcs.Array() {
if tc.Get("type").String() == "function" {
id := tc.Get("id").String()
name := tc.Get("function.name").String()
if id != "" && name != "" {
tcID2Name[id] = name
}
}
}
}
}
}
// Second pass build systemInstruction/tool responses cache
toolResponses := map[string]string{} // tool_call_id -> response text
for i := 0; i < len(arr); i++ {
m := arr[i]
role := m.Get("role").String()
if role == "tool" {
toolCallID := m.Get("tool_call_id").String()
if toolCallID != "" {
c := m.Get("content")
toolResponses[toolCallID] = c.Raw
}
}
}
for i := 0; i < len(arr); i++ {
m := arr[i]
role := m.Get("role").String()
content := m.Get("content")
if role == "system" && len(arr) > 1 {
// system -> request.systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.String())
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
}
} else if role == "user" || (role == "system" && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
} else if content.IsArray() {
items := content.Array()
p := 0
for _, item := range items {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 {
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
case "file":
filename := item.Get("file.filename").String()
fileData := item.Get("file.file_data").String()
ext := ""
if sp := strings.Split(filename, "."); len(sp) > 1 {
ext = sp[len(sp)-1]
}
if mimeType, ok := misc.MimeTypes[ext]; ok {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
p++
} else {
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
}
}
}
}
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
} else if role == "assistant" {
if content.Type == gjson.String {
// Assistant text -> single model content
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
} else if !content.Exists() || content.Type == gjson.Null {
// Tool calls -> single model content with functionCall parts
tcs := m.Get("tool_calls")
if tcs.IsArray() {
node := []byte(`{"role":"model","parts":[]}`)
p := 0
fIDs := make([]string, 0)
for _, tc := range tcs.Array() {
if tc.Get("type").String() != "function" {
continue
}
fid := tc.Get("id").String()
fname := tc.Get("function.name").String()
fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
if fid != "" {
fIDs = append(fIDs, fid)
}
}
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
// Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`)
pp := 0
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
resp := toolResponses[fid]
if resp == "" {
resp = "{}"
}
// Handle non-JSON output gracefully (matches dev branch approach)
if resp != "null" {
parsed := gjson.Parse(resp)
if parsed.Type == gjson.JSON {
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
} else {
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
}
}
pp++
}
}
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
}
}
}
}
}
}
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
toolNode := []byte(`{}`)
hasTool := false
hasFunction := false
for _, t := range tools.Array() {
if t.Get("type").String() == "function" {
fn := t.Get("function")
if fn.Exists() && fn.IsObject() {
fnRaw := fn.Raw
if fn.Get("parameters").Exists() {
renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
} else {
fnRaw = renamed
}
} else {
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
}
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue
}
toolNode = tmp
hasFunction = true
hasTool = true
}
}
if gs := t.Get("google_search"); gs.Exists() {
var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue
}
hasTool = true
}
}
if hasTool {
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
}
}
return common.AttachDefaultSafetySettings(out, "request.safetySettings")
}
// itoa converts int to string without strconv import for few usages.
func itoa(i int) string { return fmt.Sprintf("%d", i) }
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
func quoteIfNeeded(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return "\"\""
}
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
return s
}
// escape quotes minimally
s = strings.ReplaceAll(s, "\\", "\\\\")
s = strings.ReplaceAll(s, "\"", "\\\"")
return "\"" + s + "\""
}

View File

@@ -0,0 +1,215 @@
// Package openai provides response translation functionality for Gemini CLI to OpenAI API compatibility.
// This package handles the conversion of Gemini CLI API responses into OpenAI Chat Completions-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by OpenAI API clients. It supports both streaming and non-streaming modes,
// handling text content, tool calls, reasoning content, and usage metadata appropriately.
package chat_completions
import (
"bytes"
"context"
"encoding/json"
"fmt"
"time"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// convertCliResponseToOpenAIChatParams holds parameters for response conversion.
type convertCliResponseToOpenAIChatParams struct {
UnixTimestamp int64
FunctionIndex int
}
// ConvertAntigravityResponseToOpenAI translates a single chunk of a streaming response from the
// Gemini CLI API format to the OpenAI Chat Completions streaming format.
// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses.
// The function handles text content, tool calls, reasoning content, and usage metadata, outputting
// responses that match the OpenAI API format. It supports incremental updates for streaming responses.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Gemini CLI API
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &convertCliResponseToOpenAIChatParams{
UnixTimestamp: 0,
FunctionIndex: 0,
}
}
if bytes.Equal(rawJSON, []byte("[DONE]")) {
return []string{}
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
} else {
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
}
// Extract and set the response ID.
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
}
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
hasFunctionCall := false
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
thoughtSignatureResult := partResult.Get("thoughtSignature")
if !thoughtSignatureResult.Exists() {
thoughtSignatureResult = partResult.Get("thought_signature")
}
inlineDataResult := partResult.Get("inlineData")
if !inlineDataResult.Exists() {
inlineDataResult = partResult.Get("inline_data")
}
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
// Ignore encrypted thoughtSignature but keep any actual content in the same part.
if hasThoughtSignature && !hasContentPayload {
continue
}
if partTextResult.Exists() {
textContent := partTextResult.String()
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
hasFunctionCall = true
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
functionCallIndex = len(toolCallsResult.Array())
} else {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
}
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
} else if inlineDataResult.Exists() {
data := inlineDataResult.Get("data").String()
if data == "" {
continue
}
mimeType := inlineDataResult.Get("mimeType").String()
if mimeType == "" {
mimeType = inlineDataResult.Get("mime_type").String()
}
if mimeType == "" {
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload))
}
}
}
if hasFunctionCall {
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
}
return []string{template}
}
// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
// This function processes the complete Gemini CLI response and transforms it into a single OpenAI-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
// the information into a single response that matches the OpenAI API format.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response
// - rawJSON: The raw JSON response from the Gemini CLI API
// - param: A pointer to a parameter object for the conversion
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
}
return ""
}

View File

@@ -0,0 +1,19 @@
package chat_completions
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI,
Antigravity,
ConvertOpenAIRequestToAntigravity,
interfaces.TranslateResponse{
Stream: ConvertAntigravityResponseToOpenAI,
NonStream: ConvertAntigravityResponseToOpenAINonStream,
},
)
}

View File

@@ -0,0 +1,14 @@
package responses
import (
"bytes"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
)
func ConvertOpenAIResponsesRequestToAntigravity(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
return ConvertGeminiRequestToAntigravity(modelName, rawJSON, stream)
}

View File

@@ -0,0 +1,35 @@
package responses
import (
"context"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
"github.com/tidwall/gjson"
)
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw)
}
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw)
}
requestResult := gjson.GetBytes(originalRequestRawJSON, "request")
if responseResult.Exists() {
originalRequestRawJSON = []byte(requestResult.Raw)
}
requestResult = gjson.GetBytes(requestRawJSON, "request")
if responseResult.Exists() {
requestRawJSON = []byte(requestResult.Raw)
}
return ConvertGeminiResponseToOpenAIResponsesNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}

View File

@@ -0,0 +1,19 @@
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenaiResponse,
Antigravity,
ConvertOpenAIResponsesRequestToAntigravity,
interfaces.TranslateResponse{
Stream: ConvertAntigravityResponseToOpenAIResponses,
NonStream: ConvertAntigravityResponseToOpenAIResponsesNonStream,
},
)
}

View File

@@ -98,6 +98,20 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
}
}
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
if content.Get("role").String() == "model" {
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
} else if part.Get("thoughtSignature").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
}
return true
})
}
return true
})
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
}

View File

@@ -131,6 +131,9 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.aspectRatio", ar.Str)
}
if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.imageSize", size.Str)
}
}
// messages -> systemInstruction + contents
@@ -281,11 +284,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
}
}
// tools -> request.tools[0].functionDeclarations
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "request.tools", []byte(`[{"functionDeclarations":[]}]`))
fdPath := "request.tools.0.functionDeclarations"
toolNode := []byte(`{}`)
hasTool := false
hasFunction := false
for _, t := range tools.Array() {
if t.Get("type").String() == "function" {
fn := t.Get("function")
@@ -323,14 +327,32 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
}
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
tmp, errSet := sjson.SetRawBytes(out, fdPath+".-1", []byte(fnRaw))
if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
}
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue
}
out = tmp
toolNode = tmp
hasFunction = true
hasTool = true
}
}
if gs := t.Get("google_search"); gs.Exists() {
var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue
}
hasTool = true
}
}
if hasTool {
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
}
}

View File

@@ -104,17 +104,31 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
partResult := partResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
thoughtSignatureResult := partResult.Get("thoughtSignature")
if !thoughtSignatureResult.Exists() {
thoughtSignatureResult = partResult.Get("thought_signature")
}
inlineDataResult := partResult.Get("inlineData")
if !inlineDataResult.Exists() {
inlineDataResult = partResult.Get("inline_data")
}
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
// Ignore encrypted thoughtSignature but keep any actual content in the same part.
if hasThoughtSignature && !hasContentPayload {
continue
}
if partTextResult.Exists() {
textContent := partTextResult.String()
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {

View File

@@ -46,5 +46,19 @@ func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []by
}
}
gjson.GetBytes(rawJSON, "contents").ForEach(func(key, content gjson.Result) bool {
if content.Get("role").String() == "model" {
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
} else if part.Get("thoughtSignature").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
}
return true
})
}
return true
})
return common.AttachDefaultSafetySettings(rawJSON, "safetySettings")
}

View File

@@ -30,6 +30,11 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
if toolsResult.Exists() && toolsResult.IsArray() {
toolResults := toolsResult.Array()
for i := 0; i < len(toolResults); i++ {
if gjson.GetBytes(rawJSON, fmt.Sprintf("tools.%d.functionDeclarations", i)).Exists() {
strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("tools.%d.functionDeclarations", i), fmt.Sprintf("tools.%d.function_declarations", i))
rawJSON = []byte(strJson)
}
functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("tools.%d.function_declarations", i))
if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() {
functionDeclarationsResults := functionDeclarationsResult.Array()
@@ -72,7 +77,25 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
return true
})
out = common.AttachDefaultSafetySettings(out, "safetySettings")
gjson.GetBytes(out, "contents").ForEach(func(key, content gjson.Result) bool {
if content.Get("role").String() == "model" {
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
out, _ = sjson.SetBytes(out, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
} else if part.Get("thoughtSignature").Exists() {
out, _ = sjson.SetBytes(out, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
}
return true
})
}
return true
})
if gjson.GetBytes(rawJSON, "generationConfig.responseSchema").Exists() {
strJson, _ := util.RenameKey(string(out), "generationConfig.responseSchema", "generationConfig.responseJsonSchema")
out = []byte(strJson)
}
out = common.AttachDefaultSafetySettings(out, "safetySettings")
return out
}

View File

@@ -122,6 +122,9 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String {
out, _ = sjson.SetBytes(out, "generationConfig.imageConfig.aspectRatio", ar.Str)
}
if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String {
out, _ = sjson.SetBytes(out, "generationConfig.imageConfig.imageSize", size.Str)
}
}
// messages -> systemInstruction + contents
@@ -297,11 +300,12 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
}
// tools -> tools[0].functionDeclarations
// tools -> tools[0].functionDeclarations + tools[0].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[{"functionDeclarations":[]}]`))
fdPath := "tools.0.functionDeclarations"
toolNode := []byte(`{}`)
hasTool := false
hasFunction := false
for _, t := range tools.Array() {
if t.Get("type").String() == "function" {
fn := t.Get("function")
@@ -311,6 +315,17 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
} else {
fnRaw = renamed
}
@@ -328,14 +343,32 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
tmp, errSet := sjson.SetRawBytes(out, fdPath+".-1", []byte(fnRaw))
if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
}
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue
}
out = tmp
toolNode = tmp
hasFunction = true
hasTool = true
}
}
if gs := t.Get("google_search"); gs.Exists() {
var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue
}
hasTool = true
}
}
if hasTool {
out, _ = sjson.SetRawBytes(out, "tools", []byte("[]"))
out, _ = sjson.SetRawBytes(out, "tools.0", toolNode)
}
}

View File

@@ -111,13 +111,26 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
if !inlineDataResult.Exists() {
inlineDataResult = partResult.Get("inline_data")
}
thoughtSignatureResult := partResult.Get("thoughtSignature")
if !thoughtSignatureResult.Exists() {
thoughtSignatureResult = partResult.Get("thought_signature")
}
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
// Skip pure thoughtSignature parts but keep any actual payload in the same part.
if hasThoughtSignature && !hasContentPayload {
continue
}
if partTextResult.Exists() {
text := partTextResult.String()
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", text)
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
template, _ = sjson.Set(template, "choices.0.delta.content", text)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {

View File

@@ -33,7 +33,83 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
// Convert input messages to Gemini contents format
if input := root.Get("input"); input.Exists() && input.IsArray() {
input.ForEach(func(_, item gjson.Result) bool {
items := input.Array()
// Normalize consecutive function calls and outputs so each call is immediately followed by its response
normalized := make([]gjson.Result, 0, len(items))
for i := 0; i < len(items); {
item := items[i]
itemType := item.Get("type").String()
itemRole := item.Get("role").String()
if itemType == "" && itemRole != "" {
itemType = "message"
}
if itemType == "function_call" {
var calls []gjson.Result
var outputs []gjson.Result
for i < len(items) {
next := items[i]
nextType := next.Get("type").String()
nextRole := next.Get("role").String()
if nextType == "" && nextRole != "" {
nextType = "message"
}
if nextType != "function_call" {
break
}
calls = append(calls, next)
i++
}
for i < len(items) {
next := items[i]
nextType := next.Get("type").String()
nextRole := next.Get("role").String()
if nextType == "" && nextRole != "" {
nextType = "message"
}
if nextType != "function_call_output" {
break
}
outputs = append(outputs, next)
i++
}
if len(calls) > 0 {
outputMap := make(map[string]gjson.Result, len(outputs))
for _, out := range outputs {
outputMap[out.Get("call_id").String()] = out
}
for _, call := range calls {
normalized = append(normalized, call)
callID := call.Get("call_id").String()
if resp, ok := outputMap[callID]; ok {
normalized = append(normalized, resp)
delete(outputMap, callID)
}
}
for _, out := range outputs {
if _, ok := outputMap[out.Get("call_id").String()]; ok {
normalized = append(normalized, out)
}
}
continue
}
}
if itemType == "function_call_output" {
normalized = append(normalized, item)
i++
continue
}
normalized = append(normalized, item)
i++
}
for _, item := range normalized {
itemType := item.Get("type").String()
itemRole := item.Get("role").String()
if itemType == "" && itemRole != "" {
@@ -59,7 +135,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
}
}
return true
continue
}
// Handle regular messages
@@ -67,39 +143,101 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
// even when the message.role is "user". We split such items into distinct Gemini messages
// with roles derived from the content type to match docs/convert-2.md.
if contentArray := item.Get("content"); contentArray.Exists() && contentArray.IsArray() {
currentRole := ""
var currentParts []string
flush := func() {
if currentRole == "" || len(currentParts) == 0 {
currentParts = nil
return
}
one := `{"role":"","parts":[]}`
one, _ = sjson.Set(one, "role", currentRole)
for _, part := range currentParts {
one, _ = sjson.SetRaw(one, "parts.-1", part)
}
out, _ = sjson.SetRaw(out, "contents.-1", one)
currentParts = nil
}
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
contentType := contentItem.Get("type").String()
if contentType == "" {
contentType = "input_text"
}
effRole := "user"
if itemRole != "" {
switch strings.ToLower(itemRole) {
case "assistant", "model":
effRole = "model"
default:
effRole = strings.ToLower(itemRole)
}
}
if contentType == "output_text" {
effRole = "model"
}
if effRole == "assistant" {
effRole = "model"
}
if currentRole != "" && effRole != currentRole {
flush()
currentRole = ""
}
if currentRole == "" {
currentRole = effRole
}
var partJSON string
switch contentType {
case "input_text", "output_text":
if text := contentItem.Get("text"); text.Exists() {
effRole := "user"
if itemRole != "" {
switch strings.ToLower(itemRole) {
case "assistant", "model":
effRole = "model"
default:
effRole = strings.ToLower(itemRole)
partJSON = `{"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
}
case "input_image":
imageURL := contentItem.Get("image_url").String()
if imageURL == "" {
imageURL = contentItem.Get("url").String()
}
if imageURL != "" {
mimeType := "application/octet-stream"
data := ""
if strings.HasPrefix(imageURL, "data:") {
trimmed := strings.TrimPrefix(imageURL, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mimeType = mediaAndData[0]
}
data = mediaAndData[1]
} else {
mediaAndData = strings.SplitN(trimmed, ",", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mimeType = mediaAndData[0]
}
data = mediaAndData[1]
}
}
}
if contentType == "output_text" {
effRole = "model"
if data != "" {
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
}
if effRole == "assistant" {
effRole = "model"
}
one := `{"role":"","parts":[]}`
one, _ = sjson.Set(one, "role", effRole)
textPart := `{"text":""}`
textPart, _ = sjson.Set(textPart, "text", text.String())
one, _ = sjson.SetRaw(one, "parts.-1", textPart)
out, _ = sjson.SetRaw(out, "contents.-1", one)
}
}
if partJSON != "" {
currentParts = append(currentParts, partJSON)
}
return true
})
flush()
}
case "function_call":
@@ -124,7 +262,8 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
case "function_call_output":
// Handle function call outputs - convert to function message with functionResponse
callID := item.Get("call_id").String()
output := item.Get("output").String()
// Use .Raw to preserve the JSON encoding (includes quotes for strings)
outputRaw := item.Get("output").Str
functionContent := `{"role":"function","parts":[]}`
functionResponse := `{"functionResponse":{"name":"","response":{}}}`
@@ -147,18 +286,19 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName)
// Parse output JSON string and set as response content
if output != "" {
outputResult := gjson.Parse(output)
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputResult.Raw)
// Set the raw JSON output directly (preserves string encoding)
if outputRaw != "" && outputRaw != "null" {
output := gjson.Parse(outputRaw)
if output.Type == gjson.JSON {
functionResponse, _ = sjson.SetRaw(functionResponse, "functionResponse.response.result", output.Raw)
} else {
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputRaw)
}
}
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
}
return true
})
}
} else if input.Exists() && input.Type == gjson.String {
// Simple string input conversion to user message
userContent := `{"role":"user","parts":[{"text":""}]}`

View File

@@ -433,12 +433,18 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// output tokens
if v := um.Get("candidatesTokenCount"); v.Exists() {
completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int())
} else {
completed, _ = sjson.Set(completed, "response.usage.output_tokens", 0)
}
if v := um.Get("thoughtsTokenCount"); v.Exists() {
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", v.Int())
} else {
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", 0)
}
if v := um.Get("totalTokenCount"); v.Exists() {
completed, _ = sjson.Set(completed, "response.usage.total_tokens", v.Int())
} else {
completed, _ = sjson.Set(completed, "response.usage.total_tokens", 0)
}
}

View File

@@ -28,4 +28,9 @@ import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini-cli"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
)

View File

@@ -202,6 +202,8 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
out, _ = sjson.Set(out, "reasoning_effort", "medium")
case "high":
out, _ = sjson.Set(out, "reasoning_effort", "high")
case "xhigh":
out, _ = sjson.Set(out, "reasoning_effort", "xhigh")
default:
out, _ = sjson.Set(out, "reasoning_effort", "auto")
}

View File

@@ -34,6 +34,15 @@ func ParseGeminiThinkingSuffix(model string) (string, *int, *bool, bool) {
return base, &budgetValue, &include, true
}
// Handle "-reasoning" suffix: enables thinking with dynamic budget (-1)
// Maps: gemini-2.5-flash-reasoning -> gemini-2.5-flash with thinkingBudget=-1
if strings.HasSuffix(lower, "-reasoning") {
base := model[:len(model)-len("-reasoning")]
budgetValue := -1 // Dynamic budget
include := true
return base, &budgetValue, &include, true
}
idx := strings.LastIndex(lower, "-thinking-")
if idx == -1 {
return model, nil, nil, false

View File

@@ -30,6 +30,16 @@ import (
log "github.com/sirupsen/logrus"
)
func matchProvider(provider string, targets []string) (string, bool) {
p := strings.ToLower(strings.TrimSpace(provider))
for _, t := range targets {
if strings.EqualFold(p, strings.TrimSpace(t)) {
return p, true
}
}
return p, false
}
// storePersister captures persistence-capable token store methods used by the watcher.
type storePersister interface {
PersistConfig(ctx context.Context) error
@@ -54,6 +64,7 @@ type Watcher struct {
lastConfigHash string
authQueue chan<- AuthUpdate
currentAuths map[string]*coreauth.Auth
runtimeAuths map[string]*coreauth.Auth
dispatchMu sync.Mutex
dispatchCond *sync.Cond
pendingUpdates map[string]AuthUpdate
@@ -169,7 +180,7 @@ func (w *Watcher) Start(ctx context.Context) error {
go w.processEvents(ctx)
// Perform an initial full reload based on current config and auth dir
w.reloadClients(true)
w.reloadClients(true, nil)
return nil
}
@@ -221,9 +232,57 @@ func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
}
}
// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths)
// to push auth updates through the same queue used by file/config watchers.
// Returns true if the update was enqueued; false if no queue is configured.
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
if w == nil {
return false
}
w.clientsMutex.Lock()
if w.runtimeAuths == nil {
w.runtimeAuths = make(map[string]*coreauth.Auth)
}
switch update.Action {
case AuthUpdateActionAdd, AuthUpdateActionModify:
if update.Auth != nil && update.Auth.ID != "" {
clone := update.Auth.Clone()
w.runtimeAuths[clone.ID] = clone
if w.currentAuths == nil {
w.currentAuths = make(map[string]*coreauth.Auth)
}
w.currentAuths[clone.ID] = clone.Clone()
}
case AuthUpdateActionDelete:
id := update.ID
if id == "" && update.Auth != nil {
id = update.Auth.ID
}
if id != "" {
delete(w.runtimeAuths, id)
if w.currentAuths != nil {
delete(w.currentAuths, id)
}
}
}
w.clientsMutex.Unlock()
if w.getAuthQueue() == nil {
return false
}
w.dispatchAuthUpdates([]AuthUpdate{update})
return true
}
func (w *Watcher) refreshAuthState() {
auths := w.SnapshotCoreAuths()
w.clientsMutex.Lock()
if len(w.runtimeAuths) > 0 {
for _, a := range w.runtimeAuths {
if a != nil {
auths = append(auths, a.Clone())
}
}
}
updates := w.prepareAuthUpdatesLocked(auths)
w.clientsMutex.Unlock()
w.dispatchAuthUpdates(updates)
@@ -437,6 +496,18 @@ func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str
return hex.EncodeToString(sum[:])
}
func computeVertexCompatModelsHash(models []config.VertexCompatModel) string {
if len(models) == 0 {
return ""
}
data, err := json.Marshal(models)
if err != nil || len(data) == 0 {
return ""
}
sum := sha256.Sum256(data)
return hex.EncodeToString(sum[:])
}
// computeClaudeModelsHash returns a stable hash for Claude model aliases.
func computeClaudeModelsHash(models []config.ClaudeModel) string {
if len(models) == 0 {
@@ -450,6 +521,142 @@ func computeClaudeModelsHash(models []config.ClaudeModel) string {
return hex.EncodeToString(sum[:])
}
func computeExcludedModelsHash(excluded []string) string {
if len(excluded) == 0 {
return ""
}
normalized := make([]string, 0, len(excluded))
for _, entry := range excluded {
if trimmed := strings.TrimSpace(entry); trimmed != "" {
normalized = append(normalized, strings.ToLower(trimmed))
}
}
if len(normalized) == 0 {
return ""
}
sort.Strings(normalized)
data, err := json.Marshal(normalized)
if err != nil || len(data) == 0 {
return ""
}
sum := sha256.Sum256(data)
return hex.EncodeToString(sum[:])
}
type excludedModelsSummary struct {
hash string
count int
}
func summarizeExcludedModels(list []string) excludedModelsSummary {
if len(list) == 0 {
return excludedModelsSummary{}
}
seen := make(map[string]struct{}, len(list))
normalized := make([]string, 0, len(list))
for _, entry := range list {
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
if _, exists := seen[trimmed]; exists {
continue
}
seen[trimmed] = struct{}{}
normalized = append(normalized, trimmed)
}
}
sort.Strings(normalized)
return excludedModelsSummary{
hash: computeExcludedModelsHash(normalized),
count: len(normalized),
}
}
func summarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary {
if len(entries) == 0 {
return nil
}
out := make(map[string]excludedModelsSummary, len(entries))
for k, v := range entries {
key := strings.ToLower(strings.TrimSpace(k))
if key == "" {
continue
}
out[key] = summarizeExcludedModels(v)
}
return out
}
func diffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) {
oldSummary := summarizeOAuthExcludedModels(oldMap)
newSummary := summarizeOAuthExcludedModels(newMap)
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
for k := range oldSummary {
keys[k] = struct{}{}
}
for k := range newSummary {
keys[k] = struct{}{}
}
changes := make([]string, 0, len(keys))
affected := make([]string, 0, len(keys))
for key := range keys {
oldInfo, okOld := oldSummary[key]
newInfo, okNew := newSummary[key]
switch {
case okOld && !okNew:
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key))
affected = append(affected, key)
case !okOld && okNew:
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count))
affected = append(affected, key)
case okOld && okNew && oldInfo.hash != newInfo.hash:
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
affected = append(affected, key)
}
}
sort.Strings(changes)
sort.Strings(affected)
return changes, affected
}
func applyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
if auth == nil || cfg == nil {
return
}
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
seen := make(map[string]struct{})
add := func(list []string) {
for _, entry := range list {
if trimmed := strings.TrimSpace(entry); trimmed != "" {
key := strings.ToLower(trimmed)
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
}
}
}
if authKindKey == "apikey" {
add(perKey)
} else if cfg.OAuthExcludedModels != nil {
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
add(cfg.OAuthExcludedModels[providerKey])
}
combined := make([]string, 0, len(seen))
for k := range seen {
combined = append(combined, k)
}
sort.Strings(combined)
hash := computeExcludedModelsHash(combined)
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
if hash != "" {
auth.Attributes["excluded_models_hash"] = hash
}
if authKind != "" {
auth.Attributes["auth_kind"] = authKind
}
}
// SetClients sets the file-based clients.
// SetClients removed
// SetAPIKeyClients removed
@@ -474,6 +681,33 @@ func (w *Watcher) processEvents(ctx context.Context) {
}
}
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
data, errRead := os.ReadFile(path)
if errRead != nil {
return false, errRead
}
if len(data) == 0 {
return false, nil
}
sum := sha256.Sum256(data)
curHash := hex.EncodeToString(sum[:])
w.clientsMutex.RLock()
prevHash, ok := w.lastAuthHashes[path]
w.clientsMutex.RUnlock()
if ok && prevHash == curHash {
return true, nil
}
return false, nil
}
func (w *Watcher) isKnownAuthFile(path string) bool {
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
_, ok := w.lastAuthHashes[path]
return ok
}
// handleEvent processes individual file system events
func (w *Watcher) handleEvent(event fsnotify.Event) {
// Filter only relevant events: config file or auth-dir JSON files.
@@ -497,19 +731,33 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
}
// Handle auth directory changes incrementally (.json only)
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
// Wait briefly; if the path exists again, treat as an update instead of removal.
time.Sleep(replaceCheckDelay)
if _, statErr := os.Stat(event.Name); statErr == nil {
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
return
}
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
w.addOrUpdateClient(event.Name)
return
}
if !w.isKnownAuthFile(event.Name) {
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
return
}
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
w.removeClient(event.Name)
return
}
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
return
}
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
w.addOrUpdateClient(event.Name)
}
}
@@ -593,6 +841,11 @@ func (w *Watcher) reloadConfig() bool {
w.config = newConfig
w.clientsMutex.Unlock()
var affectedOAuthProviders []string
if oldConfig != nil {
_, affectedOAuthProviders = diffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
}
// Always apply the current log level based on the latest config.
// This ensures logrus reflects the desired level even if change detection misses.
util.SetLogLevel(newConfig)
@@ -618,12 +871,12 @@ func (w *Watcher) reloadConfig() bool {
log.Infof("config successfully reloaded, triggering client reload")
// Reload clients with new config
w.reloadClients(authDirChanged)
w.reloadClients(authDirChanged, affectedOAuthProviders)
return true
}
// reloadClients performs a full scan and reload of all clients.
func (w *Watcher) reloadClients(rescanAuth bool) {
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) {
log.Debugf("starting full client load process")
w.clientsMutex.RLock()
@@ -635,12 +888,34 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
return
}
if len(affectedOAuthProviders) > 0 {
w.clientsMutex.Lock()
if w.currentAuths != nil {
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
for id, auth := range w.currentAuths {
if auth == nil {
continue
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if _, match := matchProvider(provider, affectedOAuthProviders); match {
continue
}
filtered[id] = auth
}
w.currentAuths = filtered
log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
} else {
w.currentAuths = nil
}
w.clientsMutex.Unlock()
}
// Unregister all old API key clients before creating new ones
// no legacy clients to unregister
// Create new API key clients based on the new config
geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
totalAPIKeyClients := geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
log.Debugf("loaded %d API key clients", totalAPIKeyClients)
var authFileCount int
@@ -683,7 +958,7 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
w.clientsMutex.Unlock()
}
totalNewClients := authFileCount + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
// Ensure consumers observe the new configuration before auth updates dispatch.
if w.reloadCallback != nil {
@@ -693,10 +968,11 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
w.refreshAuthState()
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
totalNewClients,
authFileCount,
geminiAPIKeyCount,
vertexCompatAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,
openAICompatCount,
@@ -808,8 +1084,10 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now,
UpdatedAt: now,
}
applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey")
out = append(out, a)
}
// Claude API keys -> synthesize auths
for i := range cfg.ClaudeKey {
ck := cfg.ClaudeKey[i]
@@ -841,6 +1119,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now,
UpdatedAt: now,
}
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
out = append(out, a)
}
// Codex API keys -> synthesize auths
@@ -870,6 +1149,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now,
UpdatedAt: now,
}
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
out = append(out, a)
}
for i := range cfg.OpenAICompatibility {
@@ -974,6 +1254,43 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
}
}
}
// Process Vertex API key providers (Vertex-compatible endpoints)
for i := range cfg.VertexCompatAPIKey {
compat := &cfg.VertexCompatAPIKey[i]
providerName := "vertex"
base := strings.TrimSpace(compat.BaseURL)
key := strings.TrimSpace(compat.APIKey)
proxyURL := strings.TrimSpace(compat.ProxyURL)
idKind := fmt.Sprintf("vertex:apikey:%s", base)
id, token := idGen.next(idKind, key, base, proxyURL)
attrs := map[string]string{
"source": fmt.Sprintf("config:vertex-apikey[%s]", token),
"base_url": base,
"provider_key": providerName,
}
if key != "" {
attrs["api_key"] = key
}
if hash := computeVertexCompatModelsHash(compat.Models); hash != "" {
attrs["models_hash"] = hash
}
addConfigHeadersToAttrs(compat.Headers, attrs)
a := &coreauth.Auth{
ID: id,
Provider: providerName,
Label: "vertex-apikey",
Status: coreauth.StatusActive,
ProxyURL: proxyURL,
Attributes: attrs,
CreatedAt: now,
UpdatedAt: now,
}
applyAuthExcludedModelsMeta(a, cfg, nil, "apikey")
out = append(out, a)
}
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
entries, _ := os.ReadDir(w.authDir)
for _, e := range entries {
@@ -1030,8 +1347,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now,
UpdatedAt: now,
}
applyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
if provider == "gemini-cli" {
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
applyAuthExcludedModelsMeta(v, cfg, nil, "oauth")
}
out = append(out, a)
out = append(out, virtuals...)
continue
@@ -1186,8 +1507,9 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int {
return authFileCount
}
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
geminiAPIKeyCount := 0
vertexCompatAPIKeyCount := 0
claudeAPIKeyCount := 0
codexAPIKeyCount := 0
openAICompatCount := 0
@@ -1196,6 +1518,9 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
// Stateless executor handles Gemini API keys; avoid constructing legacy clients.
geminiAPIKeyCount += len(cfg.GeminiKey)
}
if len(cfg.VertexCompatAPIKey) > 0 {
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
}
if len(cfg.ClaudeKey) > 0 {
claudeAPIKeyCount += len(cfg.ClaudeKey)
}
@@ -1213,7 +1538,7 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
}
}
}
return geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
}
func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
@@ -1378,6 +1703,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.RequestRetry != newCfg.RequestRetry {
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
}
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
}
if oldCfg.ProxyURL != newCfg.ProxyURL {
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
}
@@ -1420,6 +1748,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
}
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
newExcluded := summarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
}
}
if !reflect.DeepEqual(trimStrings(oldCfg.GlAPIKey), trimStrings(newCfg.GlAPIKey)) {
changes = append(changes, "generative-language-api-key: values updated (legacy view, redacted)")
@@ -1448,6 +1781,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
}
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
newExcluded := summarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
}
}
}
@@ -1473,9 +1811,18 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
}
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
newExcluded := summarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
}
}
}
if entries, _ := diffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
changes = append(changes, entries...)
}
// Remote management (never print the key)
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))

View File

@@ -69,6 +69,27 @@ func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
return
}
switch request.Action {
case "gemini-3-pro-preview":
c.JSON(http.StatusOK, gin.H{
"name": "models/gemini-3-pro-preview",
"version": "3",
"displayName": "Gemini 3 Pro Preview",
"description": "Gemini 3 Pro Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": []string{
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
)
case "gemini-2.5-pro":
c.JSON(http.StatusOK, gin.H{
"name": "models/gemini-2.5-pro",

View File

@@ -4,6 +4,7 @@
package handlers
import (
"bytes"
"fmt"
"net/http"
"strings"
@@ -120,11 +121,11 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
data := params[0]
switch data.(type) {
case []byte:
c.Set("API_RESPONSE", data.([]byte))
appendAPIResponse(c, data.([]byte))
case error:
c.Set("API_RESPONSE", []byte(data.(error).Error()))
appendAPIResponse(c, []byte(data.(error).Error()))
case string:
c.Set("API_RESPONSE", []byte(data.(string)))
appendAPIResponse(c, []byte(data.(string)))
case bool:
case nil:
}
@@ -135,6 +136,28 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
}
}
// appendAPIResponse preserves any previously captured API response and appends new data.
func appendAPIResponse(c *gin.Context, data []byte) {
if c == nil || len(data) == 0 {
return
}
if existing, exists := c.Get("API_RESPONSE"); exists {
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
combined := make([]byte, 0, len(existingBytes)+len(data)+1)
combined = append(combined, existingBytes...)
if existingBytes[len(existingBytes)-1] != '\n' {
combined = append(combined, '\n')
}
combined = append(combined, data...)
c.Set("API_RESPONSE", combined)
return
}
}
c.Set("API_RESPONSE", bytes.Clone(data))
}
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
@@ -297,7 +320,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
// Resolve "auto" model to an actual available model first
resolvedModelName := util.ResolveAutoModel(modelName)
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
// First, normalize the model name to handle suffixes like "-thinking-128"

View File

@@ -21,6 +21,7 @@ import (
const (
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
antigravityCallbackPort = 51121
)
var antigravityScopes = []string{
@@ -58,6 +59,8 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
opts = &LoginOptions{}
}
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
state, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("antigravity: failed to generate state: %w", err)
@@ -112,14 +115,14 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
return nil, fmt.Errorf("antigravity: missing authorization code")
}
tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI)
tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI, httpClient)
if errToken != nil {
return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken)
}
email := ""
if tokenResp.AccessToken != "" {
if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken); errInfo == nil && strings.TrimSpace(info.Email) != "" {
if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" {
email = strings.TrimSpace(info.Email)
}
}
@@ -160,7 +163,8 @@ type callbackResult struct {
}
func startAntigravityCallbackServer() (*http.Server, int, <-chan callbackResult, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
addr := fmt.Sprintf(":%d", antigravityCallbackPort)
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, 0, nil, err
}
@@ -200,7 +204,7 @@ type antigravityTokenResponse struct {
TokenType string `json:"token_type"`
}
func exchangeAntigravityCode(ctx context.Context, code, redirectURI string) (*antigravityTokenResponse, error) {
func exchangeAntigravityCode(ctx context.Context, code, redirectURI string, httpClient *http.Client) (*antigravityTokenResponse, error) {
data := url.Values{}
data.Set("code", code)
data.Set("client_id", antigravityClientID)
@@ -214,7 +218,7 @@ func exchangeAntigravityCode(ctx context.Context, code, redirectURI string) (*an
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := http.DefaultClient.Do(req)
resp, errDo := httpClient.Do(req)
if errDo != nil {
return nil, errDo
}
@@ -238,7 +242,7 @@ type antigravityUserInfo struct {
Email string `json:"email"`
}
func fetchAntigravityUserInfo(ctx context.Context, accessToken string) (*antigravityUserInfo, error) {
func fetchAntigravityUserInfo(ctx context.Context, accessToken string, httpClient *http.Client) (*antigravityUserInfo, error) {
if strings.TrimSpace(accessToken) == "" {
return &antigravityUserInfo{}, nil
}
@@ -248,7 +252,7 @@ func fetchAntigravityUserInfo(ctx context.Context, accessToken string) (*antigra
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, errDo := http.DefaultClient.Do(req)
resp, errDo := httpClient.Do(req)
if errDo != nil {
return nil, errDo
}

View File

@@ -106,6 +106,10 @@ type Manager struct {
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
providerOffsets map[string]int
// Retry controls request retry behavior.
requestRetry atomic.Int32
maxRetryInterval atomic.Int64
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider
@@ -145,6 +149,21 @@ func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
m.mu.Unlock()
}
// SetRetryConfig updates retry attempts and cooldown wait interval.
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration) {
if m == nil {
return
}
if retry < 0 {
retry = 0
}
if maxRetryInterval < 0 {
maxRetryInterval = 0
}
m.requestRetry.Store(int32(retry))
m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds())
}
// RegisterExecutor registers a provider executor with the manager.
func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
if executor == nil {
@@ -188,8 +207,12 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
if auth == nil || auth.ID == "" {
return nil, nil
}
auth.EnsureIndex()
m.mu.Lock()
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == 0 {
auth.Index = existing.Index
auth.indexAssigned = existing.indexAssigned
}
auth.EnsureIndex()
m.auths[auth.ID] = auth.Clone()
m.mu.Unlock()
_ = m.persist(ctx, auth)
@@ -229,13 +252,28 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
var lastErr error
for _, provider := range rotated {
resp, errExec := m.executeWithProvider(ctx, provider, req, opts)
for attempt := 0; attempt < attempts; attempt++ {
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
return m.executeWithProvider(execCtx, provider, req, opts)
})
if errExec == nil {
return resp, nil
}
lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
if !shouldRetry {
break
}
if errWait := waitForCooldown(ctx, wait); errWait != nil {
return cliproxyexecutor.Response{}, errWait
}
}
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
@@ -253,13 +291,28 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
var lastErr error
for _, provider := range rotated {
resp, errExec := m.executeCountWithProvider(ctx, provider, req, opts)
for attempt := 0; attempt < attempts; attempt++ {
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
return m.executeCountWithProvider(execCtx, provider, req, opts)
})
if errExec == nil {
return resp, nil
}
lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
if !shouldRetry {
break
}
if errWait := waitForCooldown(ctx, wait); errWait != nil {
return cliproxyexecutor.Response{}, errWait
}
}
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
@@ -277,13 +330,28 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
var lastErr error
for _, provider := range rotated {
chunks, errStream := m.executeStreamWithProvider(ctx, provider, req, opts)
for attempt := 0; attempt < attempts; attempt++ {
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
return m.executeStreamWithProvider(execCtx, provider, req, opts)
})
if errStream == nil {
return chunks, nil
}
lastErr = errStream
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
if !shouldRetry {
break
}
if errWait := waitForCooldown(ctx, wait); errWait != nil {
return nil, errWait
}
}
if lastErr != nil {
return nil, lastErr
@@ -507,6 +575,123 @@ func (m *Manager) advanceProviderCursor(model string, providers []string) {
m.mu.Unlock()
}
func (m *Manager) retrySettings() (int, time.Duration) {
if m == nil {
return 0, 0
}
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
}
func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) {
if m == nil || len(providers) == 0 {
return 0, false
}
now := time.Now()
providerSet := make(map[string]struct{}, len(providers))
for i := range providers {
key := strings.TrimSpace(strings.ToLower(providers[i]))
if key == "" {
continue
}
providerSet[key] = struct{}{}
}
m.mu.RLock()
defer m.mu.RUnlock()
var (
found bool
minWait time.Duration
)
for _, auth := range m.auths {
if auth == nil {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
if _, ok := providerSet[providerKey]; !ok {
continue
}
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
if !blocked || next.IsZero() || reason == blockReasonDisabled {
continue
}
wait := next.Sub(now)
if wait < 0 {
continue
}
if !found || wait < minWait {
minWait = wait
found = true
}
}
return minWait, found
}
func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
if err == nil || attempt >= maxAttempts-1 {
return 0, false
}
if maxWait <= 0 {
return 0, false
}
if status := statusCodeFromError(err); status == http.StatusOK {
return 0, false
}
wait, found := m.closestCooldownWait(providers, model)
if !found || wait > maxWait {
return 0, false
}
return wait, true
}
func waitForCooldown(ctx context.Context, wait time.Duration) error {
if wait <= 0 {
return nil
}
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
func (m *Manager) executeProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (cliproxyexecutor.Response, error)) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
var lastErr error
for _, provider := range providers {
resp, errExec := fn(ctx, provider)
if errExec == nil {
return resp, nil
}
lastErr = errExec
}
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
}
func (m *Manager) executeStreamProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (<-chan cliproxyexecutor.StreamChunk, error)) (<-chan cliproxyexecutor.StreamChunk, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
var lastErr error
for _, provider := range providers {
chunks, errExec := fn(ctx, provider)
if errExec == nil {
return chunks, nil
}
lastErr = errExec
}
if lastErr != nil {
return nil, lastErr
}
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
// MarkResult records an execution result and notifies hooks.
func (m *Manager) MarkResult(ctx context.Context, result Result) {
if result.AuthID == "" {
@@ -762,6 +947,20 @@ func cloneError(err *Error) *Error {
}
}
func statusCodeFromError(err error) int {
if err == nil {
return 0
}
type statusCoder interface {
StatusCode() int
}
var sc statusCoder
if errors.As(err, &sc) && sc != nil {
return sc.StatusCode()
}
return 0
}
func retryAfterFromError(err error) *time.Duration {
if err == nil {
return nil
@@ -919,6 +1118,14 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
}
authCopy := selected.Clone()
m.mu.RUnlock()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, nil
}

View File

@@ -29,7 +29,7 @@ func NewAPIKeyClientProvider() APIKeyClientProvider {
type apiKeyClientProvider struct{}
func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) {
geminiCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg)
geminiCount, vertexCompatCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg)
if ctx != nil {
select {
case <-ctx.Done():
@@ -38,9 +38,10 @@ func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*A
}
}
return &APIKeyClientResult{
GeminiKeyCount: geminiCount,
ClaudeKeyCount: claudeCount,
CodexKeyCount: codexCount,
OpenAICompatCount: openAICompat,
GeminiKeyCount: geminiCount,
VertexCompatKeyCount: vertexCompatCount,
ClaudeKeyCount: claudeCount,
CodexKeyCount: codexCount,
OpenAICompatCount: openAICompat,
}, nil
}

View File

@@ -146,6 +146,27 @@ func (s *Service) consumeAuthUpdates(ctx context.Context) {
}
}
func (s *Service) emitAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
if s == nil {
return
}
if ctx == nil {
ctx = context.Background()
}
if s.watcher != nil && s.watcher.DispatchRuntimeAuthUpdate(update) {
return
}
if s.authUpdates != nil {
select {
case s.authUpdates <- update:
return
default:
log.Debugf("auth update queue saturated, applying inline action=%v id=%s", update.Action, update.ID)
}
}
s.handleAuthUpdate(ctx, update)
}
func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
if s == nil {
return
@@ -220,7 +241,11 @@ func (s *Service) wsOnConnected(channelID string) {
Metadata: map[string]any{"email": channelID}, // metadata drives logging and usage tracking
}
log.Infof("websocket provider connected: %s", channelID)
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
s.emitAuthUpdate(context.Background(), watcher.AuthUpdate{
Action: watcher.AuthUpdateActionAdd,
ID: auth.ID,
Auth: auth,
})
}
func (s *Service) wsOnDisconnected(channelID string, reason error) {
@@ -237,7 +262,10 @@ func (s *Service) wsOnDisconnected(channelID string, reason error) {
log.Infof("websocket provider disconnected: %s", channelID)
}
ctx := context.Background()
s.applyCoreAuthRemoval(ctx, channelID)
s.emitAuthUpdate(ctx, watcher.AuthUpdate{
Action: watcher.AuthUpdateActionDelete,
ID: channelID,
})
}
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
@@ -281,6 +309,14 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
}
}
func (s *Service) applyRetryConfig(cfg *config.Config) {
if s == nil || s.coreManager == nil || cfg == nil {
return
}
maxInterval := time.Duration(cfg.MaxRetryInterval) * time.Second
s.coreManager.SetRetryConfig(cfg.RequestRetry, maxInterval)
}
func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName string, ok bool) {
if a == nil {
return "", "", false
@@ -288,7 +324,7 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
if len(a.Attributes) > 0 {
providerKey = strings.TrimSpace(a.Attributes["provider_key"])
compatName = strings.TrimSpace(a.Attributes["compat_name"])
if providerKey != "" || compatName != "" {
if compatName != "" {
if providerKey == "" {
providerKey = compatName
}
@@ -394,6 +430,8 @@ func (s *Service) Run(ctx context.Context) error {
return err
}
s.applyRetryConfig(s.cfg)
if s.coreManager != nil {
if errLoad := s.coreManager.Load(ctx); errLoad != nil {
log.Warnf("failed to load auth store: %v", errLoad)
@@ -460,7 +498,7 @@ func (s *Service) Run(ctx context.Context) error {
}()
time.Sleep(100 * time.Millisecond)
fmt.Println("API server started successfully")
fmt.Printf("API server started successfully on: %d\n", s.cfg.Port)
if s.hooks.OnAfterStart != nil {
s.hooks.OnAfterStart(s)
@@ -476,6 +514,7 @@ func (s *Service) Run(ctx context.Context) error {
if newCfg == nil {
return
}
s.applyRetryConfig(newCfg)
if s.server != nil {
s.server.UpdateClients(newCfg)
}
@@ -606,6 +645,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if a == nil || a.ID == "" {
return
}
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
if a.Attributes != nil {
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
GlobalModelRegistry().UnregisterClient(a.ID)
@@ -625,32 +665,62 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if compatDetected {
provider = "openai-compatibility"
}
excluded := s.oauthExcludedModels(provider, authKind)
var models []*ModelInfo
switch provider {
case "gemini":
models = registry.GetGeminiModels()
if entry := s.resolveConfigGeminiKey(a); entry != nil {
if authKind == "apikey" {
excluded = entry.ExcludedModels
}
}
models = applyExcludedModels(models, excluded)
case "vertex":
// Vertex AI Gemini supports the same model identifiers as Gemini.
models = registry.GetGeminiVertexModels()
if authKind == "apikey" {
if entry := s.resolveConfigVertexCompatKey(a); entry != nil && len(entry.Models) > 0 {
models = buildVertexCompatConfigModels(entry)
}
}
models = applyExcludedModels(models, excluded)
case "gemini-cli":
models = registry.GetGeminiCLIModels()
models = applyExcludedModels(models, excluded)
case "aistudio":
models = registry.GetAIStudioModels()
models = applyExcludedModels(models, excluded)
case "antigravity":
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
models = executor.FetchAntigravityModels(ctx, a, s.cfg)
cancel()
models = applyExcludedModels(models, excluded)
case "claude":
models = registry.GetClaudeModels()
if entry := s.resolveConfigClaudeKey(a); entry != nil && len(entry.Models) > 0 {
models = buildClaudeConfigModels(entry)
if entry := s.resolveConfigClaudeKey(a); entry != nil {
if len(entry.Models) > 0 {
models = buildClaudeConfigModels(entry)
}
if authKind == "apikey" {
excluded = entry.ExcludedModels
}
}
models = applyExcludedModels(models, excluded)
case "codex":
models = registry.GetOpenAIModels()
if entry := s.resolveConfigCodexKey(a); entry != nil {
if authKind == "apikey" {
excluded = entry.ExcludedModels
}
}
models = applyExcludedModels(models, excluded)
case "qwen":
models = registry.GetQwenModels()
models = applyExcludedModels(models, excluded)
case "iflow":
models = registry.GetIFlowModels()
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
if s.cfg != nil {
@@ -738,7 +808,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
key = strings.ToLower(strings.TrimSpace(a.Provider))
}
GlobalModelRegistry().RegisterClient(a.ID, key, models)
return
}
GlobalModelRegistry().UnregisterClient(a.ID)
}
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
@@ -780,6 +853,222 @@ func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey
return nil
}
func (s *Service) resolveConfigGeminiKey(auth *coreauth.Auth) *config.GeminiKey {
if auth == nil || s.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range s.cfg.GeminiKey {
entry := &s.cfg.GeminiKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
return nil
}
func (s *Service) resolveConfigVertexCompatKey(auth *coreauth.Auth) *config.VertexCompatKey {
if auth == nil || s.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range s.cfg.VertexCompatAPIKey {
entry := &s.cfg.VertexCompatAPIKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range s.cfg.VertexCompatAPIKey {
entry := &s.cfg.VertexCompatAPIKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}
func (s *Service) resolveConfigCodexKey(auth *coreauth.Auth) *config.CodexKey {
if auth == nil || s.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range s.cfg.CodexKey {
entry := &s.cfg.CodexKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
return nil
}
func (s *Service) oauthExcludedModels(provider, authKind string) []string {
cfg := s.cfg
if cfg == nil {
return nil
}
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
providerKey := strings.ToLower(strings.TrimSpace(provider))
if authKindKey == "apikey" {
return nil
}
return cfg.OAuthExcludedModels[providerKey]
}
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
if len(models) == 0 || len(excluded) == 0 {
return models
}
patterns := make([]string, 0, len(excluded))
for _, item := range excluded {
if trimmed := strings.TrimSpace(item); trimmed != "" {
patterns = append(patterns, strings.ToLower(trimmed))
}
}
if len(patterns) == 0 {
return models
}
filtered := make([]*ModelInfo, 0, len(models))
for _, model := range models {
if model == nil {
continue
}
modelID := strings.ToLower(strings.TrimSpace(model.ID))
blocked := false
for _, pattern := range patterns {
if matchWildcard(pattern, modelID) {
blocked = true
break
}
}
if !blocked {
filtered = append(filtered, model)
}
}
return filtered
}
// matchWildcard performs case-insensitive wildcard matching where '*' matches any substring.
func matchWildcard(pattern, value string) bool {
if pattern == "" {
return false
}
// Fast path for exact match (no wildcard present).
if !strings.Contains(pattern, "*") {
return pattern == value
}
parts := strings.Split(pattern, "*")
// Handle prefix.
if prefix := parts[0]; prefix != "" {
if !strings.HasPrefix(value, prefix) {
return false
}
value = value[len(prefix):]
}
// Handle suffix.
if suffix := parts[len(parts)-1]; suffix != "" {
if !strings.HasSuffix(value, suffix) {
return false
}
value = value[:len(value)-len(suffix)]
}
// Handle middle segments in order.
for i := 1; i < len(parts)-1; i++ {
segment := parts[i]
if segment == "" {
continue
}
idx := strings.Index(value, segment)
if idx < 0 {
return false
}
value = value[idx+len(segment):]
}
return true
}
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 {
return nil
}
now := time.Now().Unix()
out := make([]*ModelInfo, 0, len(entry.Models))
seen := make(map[string]struct{}, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if alias == "" {
alias = name
}
if alias == "" {
continue
}
key := strings.ToLower(alias)
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
display := name
if display == "" {
display = alias
}
out = append(out, &ModelInfo{
ID: alias,
Object: "model",
Created: now,
OwnedBy: "vertex",
Type: "vertex",
DisplayName: display,
})
}
return out
}
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 {
return nil

View File

@@ -49,19 +49,21 @@ type APIKeyClientProvider interface {
Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error)
}
// APIKeyClientResult contains API key based clients along with type counts.
// It provides metadata about the number of clients loaded for each provider type.
// APIKeyClientResult is returned by APIKeyClientProvider.Load()
type APIKeyClientResult struct {
// GeminiKeyCount is the number of Gemini API key clients loaded.
// GeminiKeyCount is the number of Gemini API keys loaded
GeminiKeyCount int
// ClaudeKeyCount is the number of Claude API key clients loaded.
// VertexCompatKeyCount is the number of Vertex-compatible API keys loaded
VertexCompatKeyCount int
// ClaudeKeyCount is the number of Claude API keys loaded
ClaudeKeyCount int
// CodexKeyCount is the number of Codex API key clients loaded.
// CodexKeyCount is the number of Codex API keys loaded
CodexKeyCount int
// OpenAICompatCount is the number of OpenAI-compatible API key clients loaded.
// OpenAICompatCount is the number of OpenAI compatibility API keys loaded
OpenAICompatCount int
}
@@ -83,9 +85,10 @@ type WatcherWrapper struct {
start func(ctx context.Context) error
stop func() error
setConfig func(cfg *config.Config)
snapshotAuths func() []*coreauth.Auth
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
setConfig func(cfg *config.Config)
snapshotAuths func() []*coreauth.Auth
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool
}
// Start proxies to the underlying watcher Start implementation.
@@ -112,6 +115,16 @@ func (w *WatcherWrapper) SetConfig(cfg *config.Config) {
w.setConfig(cfg)
}
// DispatchRuntimeAuthUpdate forwards runtime auth updates (e.g., websocket providers)
// into the watcher-managed auth update queue when available.
// Returns true if the update was enqueued successfully.
func (w *WatcherWrapper) DispatchRuntimeAuthUpdate(update watcher.AuthUpdate) bool {
if w == nil || w.dispatchRuntimeUpdate == nil {
return false
}
return w.dispatchRuntimeUpdate(update)
}
// SetClients updates the watcher file-backed clients registry.
// SetClients and SetAPIKeyClients removed; watcher manages its own caches

View File

@@ -28,5 +28,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi
setUpdateQueue: func(queue chan<- watcher.AuthUpdate) {
w.SetAuthUpdateQueue(queue)
},
dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool {
return w.DispatchRuntimeAuthUpdate(update)
},
}, nil
}