Compare commits


63 Commits

Author SHA1 Message Date
Luis Pater
2e6a2b655c Merge pull request #1132 from XYenon/fix/gemini-models-displayname-override
fix(gemini): preserve displayName and description in models list
2026-01-25 03:40:04 +08:00
Luis Pater
cb47ac21bf Merge pull request #1179 from mallendeo/main
fix(claude): skip built-in tools in OAuth tool prefix
2026-01-25 03:31:58 +08:00
Luis Pater
a1394b4596 Merge pull request #1183 from Darley-Wey/fix/api-align
fix(api): enhance ClaudeModels response to align with api.anthropic.com
2026-01-25 03:30:14 +08:00
Luis Pater
9e97948f03 Merge pull request #1185 from router-for-me/auth
Refactor authentication handling for Antigravity, Claude, Codex, and Gemini
2026-01-25 03:28:53 +08:00
Darley
46c6fb1e7a fix(api): enhance ClaudeModels response to align with api.anthropic.com 2026-01-24 04:41:08 +03:30
hkfires
9f9fec5d4c fix(auth): improve antigravity token exchange errors 2026-01-24 09:04:15 +08:00
hkfires
e95be10485 fix(auth): validate antigravity token userinfo email 2026-01-24 08:33:52 +08:00
hkfires
f3d58fa0ce fix(auth): correct antigravity oauth redirect and expiry 2026-01-24 08:33:52 +08:00
hkfires
8c0eaa1f71 refactor(auth): export Gemini constants and use in handler 2026-01-24 08:33:52 +08:00
hkfires
405df58f72 refactor(auth): export Codex constants and slim down handler 2026-01-24 08:33:52 +08:00
hkfires
e7f13aa008 refactor(api): slim down RequestAnthropicToken to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
7cb6a9b89a refactor(auth): export Claude OAuth constants for reuse 2026-01-24 08:33:51 +08:00
hkfires
9aa5344c29 refactor(api): slim down RequestAntigravityToken to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
8ba0ebbd2a refactor(sdk): slim down Antigravity authenticator to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
c65407ab9f refactor(auth): extract Antigravity OAuth constants to internal/auth 2026-01-24 08:33:51 +08:00
hkfires
9e59685212 refactor(auth): implement Antigravity AuthService in internal/auth 2026-01-24 08:33:51 +08:00
hkfires
4a4dfaa910 refactor(auth): replace sanitizeAntigravityFileName with antigravity.CredentialFileName 2026-01-24 08:33:51 +08:00
Luis Pater
0d6ecb0191 Fixed: #1077
refactor(translator): improve tools handling by separating functionDeclarations and googleSearch nodes
2026-01-24 05:51:11 +08:00
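A minimal sketch of what the separated tools layout looks like in a Gemini-style generateContent payload; the declaration contents are illustrative, and the exact node names are assumptions based on the public Gemini API rather than the translator's own types.

package main

import "fmt"

func main() {
	// Each capability gets its own tool node instead of being mixed together:
	// function declarations in one entry, googleSearch in another.
	tools := []map[string]any{
		{"functionDeclarations": []map[string]any{
			{"name": "get_weather", "description": "Look up the weather"},
		}},
		{"googleSearch": map[string]any{}},
	}
	fmt.Println(len(tools)) // 2 separate nodes
}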
Mauricio Allende
f16461bfe7 fix(claude): skip built-in tools in OAuth tool prefix 2026-01-23 21:29:39 +00:00
Luis Pater
c32e2a8196 fix(auth): handle context cancellation in executor methods 2026-01-24 04:56:55 +08:00
Luis Pater
873d41582f Merge pull request #1125 from NightHammer1000/dev
Filter out Top_P when Temp is set on Claude
2026-01-24 02:03:33 +08:00
Luis Pater
6fb7d85558 Merge pull request #1137 from augustVino/fix/remove_empty_systemmsg
fix(translator): ensure system message is only added if it contains c…
2026-01-24 02:02:18 +08:00
hkfires
d5e3e32d58 fix(auth): normalize plan type filenames to lowercase 2026-01-23 20:13:09 +08:00
Chén Mù
f353a54555 Merge pull request #1171 from router-for-me/auth
refactor(auth): remove unused provider execution helpers
2026-01-23 19:43:42 +08:00
Chén Mù
1d6e2e751d Merge pull request #1140 from sxjeru/main
fix(auth): handle quota cooldown in retry logic for transient errors
2026-01-23 19:43:17 +08:00
hkfires
cc50b63422 refactor(auth): remove unused provider execution helpers 2026-01-23 19:12:55 +08:00
Luis Pater
15ae83a15b Merge pull request #1169 from router-for-me/payload
feat(executor): apply payload rules using requested model
2026-01-23 18:41:31 +08:00
hkfires
81b369aed9 fix(auth): include requested model in executor metadata 2026-01-23 18:30:08 +08:00
hkfires
ecc850bfb7 feat(executor): apply payload rules using requested model 2026-01-23 16:38:41 +08:00
Chén Mù
19b4ef33e0 Merge pull request #1102 from aldinokemal/main
feat(management): add PATCH endpoint to enable/disable auth files
2026-01-23 09:05:24 +08:00
hkfires
7ca045d8b9 fix(executor): adjust model-specific request payload 2026-01-22 20:28:08 +08:00
hkfires
abfca6aab2 refactor(util): reorder gemini schema cleaner helpers 2026-01-22 18:38:48 +08:00
Chén Mù
3c71c075db Merge pull request #1131 from sowar1987/fix/gemini-malformed-function-call
Fix Gemini tool calling for Antigravity (malformed_function_call)
2026-01-22 18:07:03 +08:00
sowar1987
9c2992bfb2 test: align signature cache tests with cache behavior
Co-Authored-By: Warp <agent@warp.dev>
2026-01-22 17:12:47 +08:00
sowar1987
269a1c5452 refactor: reuse placeholder reason description
Co-Authored-By: Warp <agent@warp.dev>
2026-01-22 17:12:47 +08:00
sowar1987
22ce65ac72 test: update signature cache tests
Revert gemini translator changes for scheme A

Co-Authored-By: Warp <agent@warp.dev>
2026-01-22 17:12:47 +08:00
sowar1987
a2f8f59192 Fix Gemini function-calling INVALID_ARGUMENT by relaxing Gemini tool validation and cleaning schema 2026-01-22 17:11:07 +08:00
XYenon
8c7c446f33 fix(gemini): preserve displayName and description in models list
Previously GeminiModels handler unconditionally overwrote displayName
and description with the model name, losing the original values defined
in model definitions (e.g., 'Gemini 3 Pro Preview').

Now only set these fields as fallback when they are missing or empty.
2026-01-22 15:19:27 +08:00
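A minimal sketch of the fallback behavior this commit describes; the struct and field names below are illustrative, not the handler's actual types.

package main

import "fmt"

type modelInfo struct {
	Name        string
	DisplayName string
	Description string
}

func main() {
	m := modelInfo{Name: "models/gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview"}
	// Only fall back to the model name when the original values are missing.
	if m.DisplayName == "" {
		m.DisplayName = m.Name
	}
	if m.Description == "" {
		m.Description = m.Name
	}
	fmt.Println(m.DisplayName) // keeps "Gemini 3 Pro Preview" instead of overwriting it
}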
sxjeru
30a59168d7 fix(auth): handle quota cooldown in retry logic for transient errors 2026-01-21 21:48:23 +08:00
hkfires
c8884f5e25 refactor(translator): enhance signature handling in Claude and Gemini requests, streamline cache usage and remove unnecessary tests 2026-01-21 20:21:49 +08:00
Luis Pater
d9c6317c84 refactor(cache, translator): refine signature caching logic and tests, replace session-based logic with model group handling 2026-01-21 18:30:05 +08:00
Vino
d29ec95526 fix(translator): ensure system message is only added if it contains content 2026-01-21 16:45:50 +08:00
Luis Pater
ef4508dbc8 refactor(cache, translator): remove session ID from signature caching and clean up logic 2026-01-21 13:37:10 +08:00
Luis Pater
f775e46fe2 refactor(translator): remove session ID logic from signature caching and associated tests 2026-01-21 12:45:07 +08:00
Luis Pater
65ad5c0c9d refactor(cache): simplify signature caching by removing sessionID parameter 2026-01-21 12:38:05 +08:00
Luis Pater
88bf4e77ec fix(translator): update HasValidSignature to require modelName parameter for improved validation 2026-01-21 11:31:37 +08:00
Luis Pater
a4f8015caa test(logging): add unit tests for GinLogrusRecovery middleware panic handling 2026-01-21 10:57:27 +08:00
Luis Pater
ffd129909e Merge pull request #1130 from router-for-me/agty
fix(executor): only strip maxOutputTokens for non-claude models
2026-01-21 10:50:39 +08:00
hkfires
9332316383 fix(translator): preserve thinking blocks by skipping signature 2026-01-21 10:49:20 +08:00
hkfires
6dcbbf64c3 fix(executor): only strip maxOutputTokens for non-claude models 2026-01-21 10:49:20 +08:00
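A rough illustration of the rule in this fix; the field name follows the Gemini-style request payload and the model check is an assumption, not the executor's exact code.

package main

import (
	"fmt"
	"strings"
)

func main() {
	model := "gemini-3-pro-preview"
	generationConfig := map[string]any{"maxOutputTokens": 4096}
	// Strip maxOutputTokens only when the requested model is not a Claude model.
	if !strings.Contains(strings.ToLower(model), "claude") {
		delete(generationConfig, "maxOutputTokens")
	}
	fmt.Println(generationConfig)
}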
Luis Pater
2ce3553612 feat(cache): handle gemini family in signature cache with fallback validator logic 2026-01-21 10:11:21 +08:00
Luis Pater
2e14f787d4 feat(translator): enhance ConvertGeminiRequestToAntigravity with model name and refine reasoning block handling 2026-01-21 08:31:23 +08:00
Luis Pater
523b41ccd2 test(responses): add comprehensive tests for SSE event ordering and response transformations 2026-01-21 07:08:59 +08:00
N1GHT
09970dc7af Accept Geminis Review Suggestion 2026-01-20 17:51:36 +01:00
N1GHT
d81abd401c Returned the Code Comment I trashed 2026-01-20 17:36:27 +01:00
N1GHT
a6cba25bc1 Small fix to filter out Top_P when Temperature is set on Claude to make requests go through 2026-01-20 17:34:26 +01:00
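A hedged sketch of the filtering these commits describe; the payload keys follow the Anthropic Messages API, and the surrounding translator code is not shown here.

package main

import "fmt"

func main() {
	payload := map[string]any{"temperature": 0.7, "top_p": 0.9}
	// When temperature is present, drop top_p; per the commit, sending both
	// could cause the upstream Claude request to be rejected.
	if _, ok := payload["temperature"]; ok {
		delete(payload, "top_p")
	}
	fmt.Println(payload)
}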
Luis Pater
c6fa1d0e67 Merge pull request #1117 from router-for-me/cache
fix(translator): enhance signature cache clearing logic and update test cases with model name
2026-01-20 23:18:48 +08:00
Luis Pater
ac56e1e88b Merge pull request #1116 from bexcodex/fix/antigravity
Fix antigravity malformed_function_call
2026-01-20 22:40:00 +08:00
hkfires
9b72ea9efa fix(translator): enhance signature cache clearing logic and update test cases with model name 2026-01-20 20:02:29 +08:00
bexcodex
9f364441e8 Fix antigravity malformed_function_call 2026-01-20 19:54:54 +08:00
Luis Pater
e49a1c07bf chore(translator): update cache functions to include model name parameter in tests 2026-01-20 18:36:51 +08:00
Aldino Kemal
2f6004d74a perf(management): optimize auth lookup in PatchAuthFileStatus
Use GetByID() for O(1) map lookup first, falling back to iteration
only for FileName matching. Consistent with pattern in disableAuth().
2026-01-19 20:05:37 +07:00
Aldino Kemal
a1634909e8 feat(management): add PATCH endpoint to enable/disable auth files
Add new PATCH /v0/management/auth-files/status endpoint that allows
toggling the disabled state of auth files without deleting them.
This enables users to temporarily disable credentials from the
management UI.
2026-01-19 19:50:36 +07:00
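A usage sketch for the new endpoint against a local instance; the port, the example auth-file name, and the omission of any management auth header are assumptions to adjust for your deployment.

package main

import (
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := `{"name":"antigravity-user@example.com.json","disabled":true}`
	req, err := http.NewRequest(http.MethodPatch,
		"http://localhost:8317/v0/management/auth-files/status", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status) // on success the handler returns {"status":"ok","disabled":true}
}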
49 changed files with 1766 additions and 1390 deletions

go.mod
View File

@@ -21,6 +21,7 @@ require (
golang.org/x/crypto v0.45.0
golang.org/x/net v0.47.0
golang.org/x/oauth2 v0.30.0
golang.org/x/text v0.31.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
)
@@ -70,7 +71,6 @@ require (
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)

View File

@@ -11,7 +11,6 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"sort"
@@ -21,6 +20,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
@@ -232,14 +232,6 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
log.Infof("callback forwarder on port %d stopped", port)
}
func sanitizeAntigravityFileName(email string) string {
if strings.TrimSpace(email) == "" {
return "antigravity.json"
}
replacer := strings.NewReplacer("@", "_", ".", "_")
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
}
func (h *Handler) managementCallbackURL(path string) (string, error) {
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
return "", fmt.Errorf("server port is not configured")
@@ -749,6 +741,72 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
return err
}
// PatchAuthFileStatus toggles the disabled state of an auth file
func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
var req struct {
Name string `json:"name"`
Disabled *bool `json:"disabled"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
name := strings.TrimSpace(req.Name)
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
if req.Disabled == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "disabled is required"})
return
}
ctx := c.Request.Context()
// Find auth by name or ID
var targetAuth *coreauth.Auth
if auth, ok := h.authManager.GetByID(name); ok {
targetAuth = auth
} else {
auths := h.authManager.List()
for _, auth := range auths {
if auth.FileName == name {
targetAuth = auth
break
}
}
}
if targetAuth == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
return
}
// Update disabled state
targetAuth.Disabled = *req.Disabled
if *req.Disabled {
targetAuth.Status = coreauth.StatusDisabled
targetAuth.StatusMessage = "disabled via management API"
} else {
targetAuth.Status = coreauth.StatusActive
targetAuth.StatusMessage = ""
}
targetAuth.UpdatedAt = time.Now()
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h == nil || h.authManager == nil {
return
@@ -915,67 +973,14 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
rawCode := resultMap["code"]
code := strings.Split(rawCode, "#")[0]
// Exchange code for tokens (replicate logic using updated redirect_uri)
// Extract client_id from the modified auth URL
clientID := ""
if u2, errP := url.Parse(authURL); errP == nil {
clientID = u2.Query().Get("client_id")
}
// Build request
bodyMap := map[string]any{
"code": code,
"state": state,
"grant_type": "authorization_code",
"client_id": clientID,
"redirect_uri": "http://localhost:54545/callback",
"code_verifier": pkceCodes.CodeVerifier,
}
bodyJSON, _ := json.Marshal(bodyMap)
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
// Exchange code for tokens using internal auth service
bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes)
if errExchange != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("failed to close response body: %v", errClose)
}
}()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
return
}
var tResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
Account struct {
EmailAddress string `json:"email_address"`
} `json:"account"`
}
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
log.Errorf("failed to parse token response: %v", errU)
SetOAuthSessionError(state, "Failed to parse token response")
return
}
bundle := &claude.ClaudeAuthBundle{
TokenData: claude.ClaudeTokenData{
AccessToken: tResp.AccessToken,
RefreshToken: tResp.RefreshToken,
Email: tResp.Account.EmailAddress,
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
}
// Create token storage
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
@@ -1015,17 +1020,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
fmt.Println("Initializing Google authentication...")
// OAuth2 configuration (mirrors internal/auth/gemini)
// OAuth2 configuration using exported constants from internal/auth/gemini
conf := &oauth2.Config{
ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com",
ClientSecret: "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",
RedirectURL: "http://localhost:8085/oauth2callback",
Scopes: []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
ClientID: geminiAuth.ClientID,
ClientSecret: geminiAuth.ClientSecret,
RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort),
Scopes: geminiAuth.Scopes,
Endpoint: google.Endpoint,
}
// Build authorization URL and return it immediately
@@ -1147,13 +1148,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
}
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
ifToken["client_secret"] = "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"
ifToken["scopes"] = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
ifToken["client_id"] = geminiAuth.ClientID
ifToken["client_secret"] = geminiAuth.ClientSecret
ifToken["scopes"] = geminiAuth.Scopes
ifToken["universe_domain"] = "googleapis.com"
ts := geminiAuth.GeminiTokenStorage{
@@ -1340,73 +1337,25 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
}
log.Debug("Authorization code received, exchanging for tokens...")
// Extract client_id from authURL
clientID := ""
if u2, errP := url.Parse(authURL); errP == nil {
clientID = u2.Query().Get("client_id")
}
// Exchange code for tokens with redirect equal to mgmtRedirect
form := url.Values{
"grant_type": {"authorization_code"},
"client_id": {clientID},
"code": {code},
"redirect_uri": {"http://localhost:1455/auth/callback"},
"code_verifier": {pkceCodes.CodeVerifier},
}
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
// Exchange code for tokens using internal auth service
bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes)
if errExchange != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
return
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
return
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
ExpiresIn int `json:"expires_in"`
}
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
SetOAuthSessionError(state, "Failed to parse token response")
log.Errorf("failed to parse token response: %v", errU)
return
}
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
email := ""
accountID := ""
// Extract additional info for filename generation
claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken)
planType := ""
if claims != nil {
email = claims.GetUserEmail()
accountID = claims.GetAccountID()
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
}
hashAccountID := ""
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
// Build bundle compatible with existing storage
bundle := &codex.CodexAuthBundle{
TokenData: codex.CodexTokenData{
IDToken: tokenResp.IDToken,
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
AccountID: accountID,
Email: email,
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
if claims != nil {
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
if accountID := claims.GetAccountID(); accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
// Create token storage and persist
@@ -1441,23 +1390,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
}
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
const (
antigravityCallbackPort = 51121
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
)
var antigravityScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
ctx := context.Background()
fmt.Println("Initializing Antigravity authentication...")
authSvc := antigravity.NewAntigravityAuth(h.cfg, nil)
state, errState := misc.GenerateRandomState()
if errState != nil {
log.Errorf("Failed to generate state parameter: %v", errState)
@@ -1465,17 +1403,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
return
}
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", antigravityClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(antigravityScopes, " "))
params.Set("state", state)
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort)
authURL := authSvc.BuildAuthURL(state, redirectURI)
RegisterOAuthSession(state, "antigravity")
@@ -1489,7 +1418,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
@@ -1498,7 +1427,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
@@ -1538,93 +1467,36 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
time.Sleep(500 * time.Millisecond)
}
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
form := url.Values{}
form.Set("code", authCode)
form.Set("client_id", antigravityClientID)
form.Set("client_secret", antigravityClientSecret)
form.Set("redirect_uri", redirectURI)
form.Set("grant_type", "authorization_code")
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
if errNewRequest != nil {
log.Errorf("Failed to build token request: %v", errNewRequest)
SetOAuthSessionError(state, "Failed to build token request")
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.Errorf("Failed to execute token request: %v", errDo)
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI)
if errToken != nil {
log.Errorf("Failed to exchange token: %v", errToken)
SetOAuthSessionError(state, "Failed to exchange token")
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange close error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, _ := io.ReadAll(resp.Body)
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
accessToken := strings.TrimSpace(tokenResp.AccessToken)
if accessToken == "" {
log.Error("antigravity: token exchange returned empty access token")
SetOAuthSessionError(state, "Failed to exchange token")
return
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
log.Errorf("Failed to parse token response: %v", errDecode)
SetOAuthSessionError(state, "Failed to parse token response")
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
if errInfo != nil {
log.Errorf("Failed to fetch user info: %v", errInfo)
SetOAuthSessionError(state, "Failed to fetch user info")
return
}
email := ""
if strings.TrimSpace(tokenResp.AccessToken) != "" {
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errInfoReq != nil {
log.Errorf("Failed to build user info request: %v", errInfoReq)
SetOAuthSessionError(state, "Failed to build user info request")
return
}
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
infoResp, errInfo := httpClient.Do(infoReq)
if errInfo != nil {
log.Errorf("Failed to execute user info request: %v", errInfo)
SetOAuthSessionError(state, "Failed to execute user info request")
return
}
defer func() {
if errClose := infoResp.Body.Close(); errClose != nil {
log.Errorf("antigravity user info close error: %v", errClose)
}
}()
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
var infoPayload struct {
Email string `json:"email"`
}
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
email = strings.TrimSpace(infoPayload.Email)
}
} else {
bodyBytes, _ := io.ReadAll(infoResp.Body)
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
return
}
email = strings.TrimSpace(email)
if email == "" {
log.Error("antigravity: user info returned empty email")
SetOAuthSessionError(state, "Failed to fetch user info")
return
}
projectID := ""
if strings.TrimSpace(tokenResp.AccessToken) != "" {
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
if accessToken != "" {
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else {
@@ -1649,7 +1521,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
metadata["project_id"] = projectID
}
fileName := sanitizeAntigravityFileName(email)
fileName := antigravity.CredentialFileName(email)
label := strings.TrimSpace(email)
if label == "" {
label = "antigravity"

View File

@@ -610,6 +610,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)

View File

@@ -0,0 +1,344 @@
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
package antigravity
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// TokenResponse represents OAuth token response from Google
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
// userInfo represents Google user profile
type userInfo struct {
Email string `json:"email"`
}
// AntigravityAuth handles Antigravity OAuth authentication
type AntigravityAuth struct {
httpClient *http.Client
}
// NewAntigravityAuth creates a new Antigravity auth service.
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
if httpClient != nil {
return &AntigravityAuth{httpClient: httpClient}
}
if cfg == nil {
cfg = &config.Config{}
}
return &AntigravityAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
}
}
// BuildAuthURL generates the OAuth authorization URL.
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
if strings.TrimSpace(redirectURI) == "" {
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
}
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", ClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(Scopes, " "))
params.Set("state", state)
return AuthEndpoint + "?" + params.Encode()
}
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
data := url.Values{}
data.Set("code", code)
data.Set("client_id", ClientID)
data.Set("client_secret", ClientSecret)
data.Set("redirect_uri", redirectURI)
data.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("antigravity token exchange: create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange: close body error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
if errRead != nil {
return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
}
body := strings.TrimSpace(string(bodyBytes))
if body == "" {
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
}
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body)
}
var token TokenResponse
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
}
return &token, nil
}
// FetchUserInfo retrieves user email from Google
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
accessToken = strings.TrimSpace(accessToken)
if accessToken == "" {
return "", fmt.Errorf("antigravity userinfo: missing access token")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
if err != nil {
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity userinfo: close body error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
if errRead != nil {
return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
}
body := strings.TrimSpace(string(bodyBytes))
if body == "" {
return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
}
return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body)
}
var info userInfo
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
}
email := strings.TrimSpace(info.Email)
if email == "" {
return "", fmt.Errorf("antigravity userinfo: response missing email")
}
return email, nil
}
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
loadReqBody := map[string]any{
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(loadReqBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", APIUserAgent)
req.Header.Set("X-Goog-Api-Client", APIClient)
req.Header.Set("Client-Metadata", ClientMetadata)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var loadResp map[string]any
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
// Extract projectID from response
projectID := ""
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
projectID = strings.TrimSpace(id)
}
if projectID == "" {
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
if id, okID := projectMap["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
if projectID == "" {
tierID := "legacy-tier"
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
for _, rawTier := range tiers {
tier, okTier := rawTier.(map[string]any)
if !okTier {
continue
}
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
tierID = strings.TrimSpace(id)
break
}
}
}
}
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
if err != nil {
return "", err
}
return projectID, nil
}
return projectID, nil
}
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
requestBody := map[string]any{
"tierId": tierID,
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(requestBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
reqCtx := ctx
var cancel context.CancelFunc
if reqCtx == nil {
reqCtx = context.Background()
}
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if errRequest != nil {
cancel()
return "", fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", APIUserAgent)
req.Header.Set("X-Goog-Api-Client", APIClient)
req.Header.Set("Client-Metadata", ClientMetadata)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
cancel()
return "", fmt.Errorf("execute request: %w", errDo)
}
bodyBytes, errRead := io.ReadAll(resp.Body)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("close body error: %v", errClose)
}
cancel()
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode == http.StatusOK {
var data map[string]any
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
if done, okDone := data["done"].(bool); okDone && done {
projectID := ""
if responseData, okResp := data["response"].(map[string]any); okResp {
switch projectValue := responseData["cloudaicompanionProject"].(type) {
case map[string]any:
if id, okID := projectValue["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
case string:
projectID = strings.TrimSpace(projectValue)
}
}
if projectID != "" {
log.Infof("Successfully fetched project_id: %s", projectID)
return projectID, nil
}
return "", fmt.Errorf("no project_id in response")
}
time.Sleep(2 * time.Second)
continue
}
responsePreview := strings.TrimSpace(string(bodyBytes))
if len(responsePreview) > 500 {
responsePreview = responsePreview[:500]
}
responseErr := responsePreview
if len(responseErr) > 200 {
responseErr = responseErr[:200]
}
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
}
return "", nil
}

View File

@@ -0,0 +1,34 @@
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
package antigravity
// OAuth client credentials and configuration
const (
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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"
CallbackPort = 51121
)
// Scopes defines the OAuth scopes required for Antigravity authentication
var Scopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
// OAuth2 endpoints for Google authentication
const (
TokenEndpoint = "https://oauth2.googleapis.com/token"
AuthEndpoint = "https://accounts.google.com/o/oauth2/v2/auth"
UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
)
// Antigravity API configuration
const (
APIEndpoint = "https://cloudcode-pa.googleapis.com"
APIVersion = "v1internal"
APIUserAgent = "google-api-nodejs-client/9.15.1"
APIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)

View File

@@ -0,0 +1,16 @@
package antigravity
import (
"fmt"
"strings"
)
// CredentialFileName returns the filename used to persist Antigravity credentials.
// It uses the email as a suffix to disambiguate accounts.
func CredentialFileName(email string) string {
email = strings.TrimSpace(email)
if email == "" {
return "antigravity.json"
}
return fmt.Sprintf("antigravity-%s.json", email)
}
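For illustration, a small test-style sketch of the filenames the function above produces (assumed to live alongside it, since the package is internal):

package antigravity

import "testing"

// Illustrative only: the email is kept verbatim in the filename, unlike the
// removed handler helper that replaced "@" and "." with "_".
func TestCredentialFileNameExamples(t *testing.T) {
	if got := CredentialFileName("user@example.com"); got != "antigravity-user@example.com.json" {
		t.Fatalf("unexpected filename: %s", got)
	}
	if got := CredentialFileName(" "); got != "antigravity.json" {
		t.Fatalf("unexpected filename: %s", got)
	}
}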

View File

@@ -18,11 +18,12 @@ import (
log "github.com/sirupsen/logrus"
)
// OAuth configuration constants for Claude/Anthropic
const (
anthropicAuthURL = "https://claude.ai/oauth/authorize"
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
redirectURI = "http://localhost:54545/callback"
AuthURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
)
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
@@ -82,16 +83,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
params := url.Values{
"code": {"true"},
"client_id": {anthropicClientID},
"client_id": {ClientID},
"response_type": {"code"},
"redirect_uri": {redirectURI},
"redirect_uri": {RedirectURI},
"scope": {"org:create_api_key user:profile user:inference"},
"code_challenge": {pkceCodes.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},
}
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
return authURL, state, nil
}
@@ -137,8 +138,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
"code": newCode,
"state": state,
"grant_type": "authorization_code",
"client_id": anthropicClientID,
"redirect_uri": redirectURI,
"client_id": ClientID,
"redirect_uri": RedirectURI,
"code_verifier": pkceCodes.CodeVerifier,
}
@@ -154,7 +155,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
// log.Debugf("Token exchange request: %s", string(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -221,7 +222,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
}
reqBody := map[string]interface{}{
"client_id": anthropicClientID,
"client_id": ClientID,
"grant_type": "refresh_token",
"refresh_token": refreshToken,
}
@@ -231,7 +232,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err)
}

View File

@@ -4,9 +4,6 @@ import (
"fmt"
"strings"
"unicode"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
@@ -43,15 +40,7 @@ func normalizePlanTypeForFilename(planType string) string {
}
for i, part := range parts {
parts[i] = titleToken(part)
parts[i] = strings.ToLower(strings.TrimSpace(part))
}
return strings.Join(parts, "-")
}
func titleToken(token string) string {
token = strings.TrimSpace(token)
if token == "" {
return ""
}
return cases.Title(language.English).String(token)
}

View File

@@ -19,11 +19,12 @@ import (
log "github.com/sirupsen/logrus"
)
// OAuth configuration constants for OpenAI Codex
const (
openaiAuthURL = "https://auth.openai.com/oauth/authorize"
openaiTokenURL = "https://auth.openai.com/oauth/token"
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
redirectURI = "http://localhost:1455/auth/callback"
AuthURL = "https://auth.openai.com/oauth/authorize"
TokenURL = "https://auth.openai.com/oauth/token"
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
RedirectURI = "http://localhost:1455/auth/callback"
)
// CodexAuth handles the OpenAI OAuth2 authentication flow.
@@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
}
params := url.Values{
"client_id": {openaiClientID},
"client_id": {ClientID},
"response_type": {"code"},
"redirect_uri": {redirectURI},
"redirect_uri": {RedirectURI},
"scope": {"openid email profile offline_access"},
"state": {state},
"code_challenge": {pkceCodes.CodeChallenge},
@@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
"codex_cli_simplified_flow": {"true"},
}
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
return authURL, nil
}
@@ -77,13 +78,13 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"client_id": {openaiClientID},
"client_id": {ClientID},
"code": {code},
"redirect_uri": {redirectURI},
"redirect_uri": {RedirectURI},
"code_verifier": {pkceCodes.CodeVerifier},
}
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -163,13 +164,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
}
data := url.Values{
"client_id": {openaiClientID},
"client_id": {ClientID},
"grant_type": {"refresh_token"},
"refresh_token": {refreshToken},
"scope": {"openid profile email"},
}
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err)
}

View File

@@ -28,19 +28,19 @@ import (
"golang.org/x/oauth2/google"
)
// OAuth configuration constants for Gemini
const (
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOauthClientSecret = "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"
geminiDefaultCallbackPort = 8085
ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
ClientSecret = "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"
DefaultCallbackPort = 8085
)
var (
geminiOauthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
)
// OAuth scopes for Gemini authentication
var Scopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
@@ -74,7 +74,7 @@ func NewGeminiAuth() *GeminiAuth {
// - *http.Client: An HTTP client configured with authentication
// - error: An error if the client configuration fails, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
callbackPort := geminiDefaultCallbackPort
callbackPort := DefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
@@ -112,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
// Configure the OAuth2 client.
conf := &oauth2.Config{
ClientID: geminiOauthClientID,
ClientSecret: geminiOauthClientSecret,
ClientID: ClientID,
ClientSecret: ClientSecret,
RedirectURL: callbackURL, // This will be used by the local server.
Scopes: geminiOauthScopes,
Scopes: Scopes,
Endpoint: google.Endpoint,
}
@@ -198,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
}
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
ifToken["client_id"] = geminiOauthClientID
ifToken["client_secret"] = geminiOauthClientSecret
ifToken["scopes"] = geminiOauthScopes
ifToken["client_id"] = ClientID
ifToken["client_secret"] = ClientSecret
ifToken["scopes"] = Scopes
ifToken["universe_domain"] = "googleapis.com"
ts := GeminiTokenStorage{
@@ -226,7 +226,7 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
callbackPort := geminiDefaultCallbackPort
callbackPort := DefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}

View File

@@ -3,7 +3,6 @@ package cache
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"time"
@@ -25,18 +24,18 @@ const (
// MinValidSignatureLen is the minimum length for a signature to be considered valid
MinValidSignatureLen = 50
// SessionCleanupInterval controls how often stale sessions are purged
SessionCleanupInterval = 10 * time.Minute
// CacheCleanupInterval controls how often stale entries are purged
CacheCleanupInterval = 10 * time.Minute
)
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
var signatureCache sync.Map
// sessionCleanupOnce ensures the background cleanup goroutine starts only once
var sessionCleanupOnce sync.Once
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
var cacheCleanupOnce sync.Once
// sessionCache is the inner map type
type sessionCache struct {
// groupCache is the inner map type
type groupCache struct {
mu sync.RWMutex
entries map[string]SignatureEntry
}
@@ -47,36 +46,36 @@ func hashText(text string) string {
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
}
// getOrCreateSession gets or creates a session cache
func getOrCreateSession(sessionID string) *sessionCache {
// getOrCreateGroupCache gets or creates a cache bucket for a model group
func getOrCreateGroupCache(groupKey string) *groupCache {
// Start background cleanup on first access
sessionCleanupOnce.Do(startSessionCleanup)
cacheCleanupOnce.Do(startCacheCleanup)
if val, ok := signatureCache.Load(sessionID); ok {
return val.(*sessionCache)
if val, ok := signatureCache.Load(groupKey); ok {
return val.(*groupCache)
}
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
return actual.(*sessionCache)
sc := &groupCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(groupKey, sc)
return actual.(*groupCache)
}
// startSessionCleanup launches a background goroutine that periodically
// removes sessions where all entries have expired.
func startSessionCleanup() {
// startCacheCleanup launches a background goroutine that periodically
// removes caches where all entries have expired.
func startCacheCleanup() {
go func() {
ticker := time.NewTicker(SessionCleanupInterval)
ticker := time.NewTicker(CacheCleanupInterval)
defer ticker.Stop()
for range ticker.C {
purgeExpiredSessions()
purgeExpiredCaches()
}
}()
}
// purgeExpiredSessions removes sessions with no valid (non-expired) entries.
func purgeExpiredSessions() {
// purgeExpiredCaches removes caches with no valid (non-expired) entries.
func purgeExpiredCaches() {
now := time.Now()
signatureCache.Range(func(key, value any) bool {
sc := value.(*sessionCache)
sc := value.(*groupCache)
sc.mu.Lock()
// Remove expired entries
for k, entry := range sc.entries {
@@ -86,7 +85,7 @@ func purgeExpiredSessions() {
}
isEmpty := len(sc.entries) == 0
sc.mu.Unlock()
// Remove session if empty
// Remove cache bucket if empty
if isEmpty {
signatureCache.Delete(key)
}
@@ -94,19 +93,19 @@ func purgeExpiredSessions() {
})
}
// CacheSignature stores a thinking signature for a given session and text.
// CacheSignature stores a thinking signature for a given model group and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(modelName, sessionID, text, signature string) {
if sessionID == "" || text == "" || signature == "" {
func CacheSignature(modelName, text, signature string) {
if text == "" || signature == "" {
return
}
if len(signature) < MinValidSignatureLen {
return
}
sc := getOrCreateSession(fmt.Sprintf("%s#%s", GetModelGroup(modelName), sessionID))
groupKey := GetModelGroup(modelName)
textHash := hashText(text)
sc := getOrCreateGroupCache(groupKey)
sc.mu.Lock()
defer sc.mu.Unlock()
@@ -116,18 +115,25 @@ func CacheSignature(modelName, sessionID, text, signature string) {
}
}
// GetCachedSignature retrieves a cached signature for a given session and text.
// GetCachedSignature retrieves a cached signature for a given model group and text.
// Returns empty string if not found or expired.
func GetCachedSignature(modelName, sessionID, text string) string {
if sessionID == "" || text == "" {
return ""
}
func GetCachedSignature(modelName, text string) string {
groupKey := GetModelGroup(modelName)
val, ok := signatureCache.Load(fmt.Sprintf("%s#%s", GetModelGroup(modelName), sessionID))
if !ok {
if text == "" {
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
sc := val.(*sessionCache)
val, ok := signatureCache.Load(groupKey)
if !ok {
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
sc := val.(*groupCache)
textHash := hashText(text)
@@ -137,11 +143,17 @@ func GetCachedSignature(modelName, sessionID, text string) string {
entry, exists := sc.entries[textHash]
if !exists {
sc.mu.Unlock()
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, textHash)
sc.mu.Unlock()
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
@@ -153,21 +165,22 @@ func GetCachedSignature(modelName, sessionID, text string) string {
return entry.Signature
}
// ClearSignatureCache clears signature cache for a specific session or all sessions.
func ClearSignatureCache(sessionID string) {
if sessionID != "" {
signatureCache.Delete(sessionID)
} else {
// ClearSignatureCache clears signature cache for a specific model group or all groups.
func ClearSignatureCache(modelName string) {
if modelName == "" {
signatureCache.Range(func(key, _ any) bool {
signatureCache.Delete(key)
return true
})
return
}
groupKey := GetModelGroup(modelName)
signatureCache.Delete(groupKey)
}
// HasValidSignature checks if a signature is valid (non-empty and long enough)
func HasValidSignature(signature string) bool {
return signature != "" && len(signature) >= MinValidSignatureLen
func HasValidSignature(modelName, signature string) bool {
return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini")
}
func GetModelGroup(modelName string) string {

View File

@@ -5,38 +5,40 @@ import (
"time"
)
const testModelName = "claude-sonnet-4-5"
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-session-1"
text := "This is some thinking text content"
signature := "abc123validSignature1234567890123456789012345678901234567890"
// Store signature
CacheSignature(sessionID, text, signature)
CacheSignature(testModelName, text, signature)
// Retrieve signature
retrieved := GetCachedSignature(sessionID, text)
retrieved := GetCachedSignature(testModelName, text)
if retrieved != signature {
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
}
}
func TestCacheSignature_DifferentSessions(t *testing.T) {
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
ClearSignatureCache("")
text := "Same text in different sessions"
text := "Same text across models"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature("session-a", text, sig1)
CacheSignature("session-b", text, sig2)
geminiModel := "gemini-3-pro-preview"
CacheSignature(testModelName, text, sig1)
CacheSignature(geminiModel, text, sig2)
if GetCachedSignature("session-a", text) != sig1 {
t.Error("Session-a signature mismatch")
if GetCachedSignature(testModelName, text) != sig1 {
t.Error("Claude signature mismatch")
}
if GetCachedSignature("session-b", text) != sig2 {
t.Error("Session-b signature mismatch")
if GetCachedSignature(geminiModel, text) != sig2 {
t.Error("Gemini signature mismatch")
}
}
@@ -44,13 +46,13 @@ func TestCacheSignature_NotFound(t *testing.T) {
ClearSignatureCache("")
// Non-existent session
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
if got := GetCachedSignature(testModelName, "some text"); got != "" {
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
}
// Existing session but different text
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature("session-x", "text-b"); got != "" {
CacheSignature(testModelName, "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature(testModelName, "text-b"); got != "" {
t.Errorf("Expected empty string for different text, got '%s'", got)
}
}
@@ -59,12 +61,11 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
ClearSignatureCache("")
// All empty/invalid inputs should be no-ops
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "text", "")
CacheSignature("session", "text", "short") // Too short
CacheSignature(testModelName, "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature(testModelName, "text", "")
CacheSignature(testModelName, "text", "short") // Too short
if got := GetCachedSignature("session", "text"); got != "" {
if got := GetCachedSignature(testModelName, "text"); got != "" {
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
}
}
@@ -72,31 +73,27 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-short-sig"
text := "Some text"
shortSig := "abc123" // Less than 50 chars
CacheSignature(sessionID, text, shortSig)
CacheSignature(testModelName, text, shortSig)
if got := GetCachedSignature(sessionID, text); got != "" {
if got := GetCachedSignature(testModelName, text); got != "" {
t.Errorf("Short signature should be rejected, got '%s'", got)
}
}
func TestClearSignatureCache_SpecificSession(t *testing.T) {
func TestClearSignatureCache_ModelGroup(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
CacheSignature(testModelName, "text", sig)
CacheSignature(testModelName, "text-2", sig)
ClearSignatureCache("session-1")
if got := GetCachedSignature("session-1", "text"); got != "" {
t.Error("session-1 should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != sig {
t.Error("session-2 should still exist")
if got := GetCachedSignature(testModelName, "text"); got != sig {
t.Error("signature should remain when clearing unknown session")
}
}
@@ -104,35 +101,37 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
CacheSignature(testModelName, "text", sig)
CacheSignature(testModelName, "text-2", sig)
ClearSignatureCache("")
if got := GetCachedSignature("session-1", "text"); got != "" {
t.Error("session-1 should be cleared")
if got := GetCachedSignature(testModelName, "text"); got != "" {
t.Error("text should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != "" {
t.Error("session-2 should be cleared")
if got := GetCachedSignature(testModelName, "text-2"); got != "" {
t.Error("text-2 should be cleared")
}
}
func TestHasValidSignature(t *testing.T) {
tests := []struct {
name string
modelName string
signature string
expected bool
}{
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
{"empty string", "", false},
{"short signature", "abc", false},
{"valid long signature", testModelName, "abc123validSignature1234567890123456789012345678901234567890", true},
{"exactly 50 chars", testModelName, "12345678901234567890123456789012345678901234567890", true},
{"49 chars - invalid", testModelName, "1234567890123456789012345678901234567890123456789", false},
{"empty string", testModelName, "", false},
{"short signature", testModelName, "abc", false},
{"gemini sentinel", "gemini-3-pro-preview", "skip_thought_signature_validator", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := HasValidSignature(tt.signature)
result := HasValidSignature(tt.modelName, tt.signature)
if result != tt.expected {
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
}
@@ -143,21 +142,19 @@ func TestHasValidSignature(t *testing.T) {
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
ClearSignatureCache("")
sessionID := "hash-test-session"
// Different texts should produce different hashes
text1 := "First thinking text"
text2 := "Second thinking text"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text1, sig1)
CacheSignature(sessionID, text2, sig2)
CacheSignature(testModelName, text1, sig1)
CacheSignature(testModelName, text2, sig2)
if GetCachedSignature(sessionID, text1) != sig1 {
if GetCachedSignature(testModelName, text1) != sig1 {
t.Error("text1 signature mismatch")
}
if GetCachedSignature(sessionID, text2) != sig2 {
if GetCachedSignature(testModelName, text2) != sig2 {
t.Error("text2 signature mismatch")
}
}
@@ -165,13 +162,12 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
func TestCacheSignature_UnicodeText(t *testing.T) {
ClearSignatureCache("")
sessionID := "unicode-session"
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
sig := "unicodeSig123456789012345678901234567890123456789012345"
CacheSignature(sessionID, text, sig)
CacheSignature(testModelName, text, sig)
if got := GetCachedSignature(sessionID, text); got != sig {
if got := GetCachedSignature(testModelName, text); got != sig {
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
}
}
@@ -179,15 +175,14 @@ func TestCacheSignature_UnicodeText(t *testing.T) {
func TestCacheSignature_Overwrite(t *testing.T) {
ClearSignatureCache("")
sessionID := "overwrite-session"
text := "Same text"
sig1 := "firstSignature12345678901234567890123456789012345678901"
sig2 := "secondSignature1234567890123456789012345678901234567890"
CacheSignature(sessionID, text, sig1)
CacheSignature(sessionID, text, sig2) // Overwrite
CacheSignature(testModelName, text, sig1)
CacheSignature(testModelName, text, sig2) // Overwrite
if got := GetCachedSignature(sessionID, text); got != sig2 {
if got := GetCachedSignature(testModelName, text); got != sig2 {
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
}
}
@@ -199,14 +194,13 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
// This test verifies the expiration check exists
// In a real scenario, we'd mock time.Now()
sessionID := "expiration-test"
text := "text"
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text, sig)
CacheSignature(testModelName, text, sig)
// Fresh entry should be retrievable
if got := GetCachedSignature(sessionID, text); got != sig {
if got := GetCachedSignature(testModelName, text); got != sig {
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
}

View File

@@ -4,6 +4,7 @@
package logging
import (
"errors"
"fmt"
"net/http"
"runtime/debug"
@@ -112,6 +113,11 @@ func isAIAPIPath(path string) bool {
// - gin.HandlerFunc: A middleware handler for panic recovery
func GinLogrusRecovery() gin.HandlerFunc {
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
// Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs.
panic(http.ErrAbortHandler)
}
log.WithFields(log.Fields{
"panic": recovered,
"stack": string(debug.Stack()),

View File

@@ -0,0 +1,60 @@
package logging
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
engine.Use(GinLogrusRecovery())
engine.GET("/abort", func(c *gin.Context) {
panic(http.ErrAbortHandler)
})
req := httptest.NewRequest(http.MethodGet, "/abort", nil)
recorder := httptest.NewRecorder()
defer func() {
recovered := recover()
if recovered == nil {
t.Fatalf("expected panic, got nil")
}
err, ok := recovered.(error)
if !ok {
t.Fatalf("expected error panic, got %T", recovered)
}
if !errors.Is(err, http.ErrAbortHandler) {
t.Fatalf("expected ErrAbortHandler, got %v", err)
}
if err != http.ErrAbortHandler {
t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
}
}()
engine.ServeHTTP(recorder, req)
}
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
engine.Use(GinLogrusRecovery())
engine.GET("/panic", func(c *gin.Context) {
panic("boom")
})
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
recorder := httptest.NewRecorder()
engine.ServeHTTP(recorder, req)
if recorder.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", recorder.Code)
}
}

View File

@@ -1033,10 +1033,10 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
"owned_by": model.OwnedBy,
}
if model.Created > 0 {
result["created"] = model.Created
result["created_at"] = model.Created
}
if model.Type != "" {
result["type"] = model.Type
result["type"] = "model"
}
if model.DisplayName != "" {
result["display_name"] = model.DisplayName

View File

@@ -398,7 +398,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
return nil, translatedPayload{}, err
}
payload = fixGeminiImageAspectRatio(baseModel, payload)
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")

View File

@@ -142,7 +142,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
return resp, err
}
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -261,7 +262,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
return resp, err
}
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -627,7 +629,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
return nil, err
}
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -994,7 +997,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
now := time.Now().Unix()
modelConfig := registry.GetAntigravityModelConfig()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for originalName := range result.Map() {
for originalName, modelData := range result.Map() {
modelID := strings.TrimSpace(originalName)
if modelID == "" {
continue
@@ -1004,12 +1007,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
continue
}
modelCfg := modelConfig[modelID]
modelName := modelID
// Extract displayName from the upstream response, falling back to modelID
displayName := modelData.Get("displayName").String()
if displayName == "" {
displayName = modelID
}
modelInfo := &registry.ModelInfo{
ID: modelID,
Name: modelName,
Description: modelID,
DisplayName: modelID,
Name: modelID,
Description: displayName,
DisplayName: displayName,
Version: modelID,
Object: "model",
Created: now,
@@ -1202,7 +1211,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
payload = geminiToAntigravity(modelName, payload, projectID)
payload, _ = sjson.SetBytes(payload, "model", modelName)
if strings.Contains(modelName, "claude") {
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
strJSON := string(payload)
paths := make([]string, 0)
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
@@ -1213,7 +1222,17 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
// Use the centralized schema cleaner to handle unsupported keywords,
// const->enum conversion, and flattening of types/anyOf.
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
payload = []byte(strJSON)
} else {
strJSON := string(payload)
paths := make([]string, 0)
util.Walk(gjson.Parse(strJSON), "", "parametersJsonSchema", &paths)
for _, p := range paths {
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
}
// Clean tool schemas for Gemini to remove unsupported JSON Schema keywords
// without adding empty-schema placeholders.
strJSON = util.CleanJSONSchemaForGemini(strJSON)
payload = []byte(strJSON)
}
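For illustration, the rename in this branch (and in the claude branch above) rewrites each parametersJsonSchema path collected by util.Walk into its sibling parameters path; the concrete path below is hypothetical, while the slicing and the util.RenameKey call are copied from the code:

p := "request.tools.0.functionDeclarations.2.parametersJsonSchema"
newPath := p[:len(p)-len("parametersJsonSchema")] + "parameters"
// newPath == "request.tools.0.functionDeclarations.2.parameters"
strJSON, _ = util.RenameKey(strJSON, p, newPath)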
@@ -1230,6 +1249,12 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
}
}
if strings.Contains(modelName, "claude") {
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
} else {
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
if errReq != nil {
return nil, errReq
@@ -1405,24 +1430,10 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
template, _ = sjson.Delete(template, "request.safetySettings")
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
if strings.Contains(modelName, "claude") {
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
if funcDecl.Get("parametersJsonSchema").Exists() {
template, _ = sjson.SetRaw(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()), funcDecl.Get("parametersJsonSchema").Raw)
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters.$schema", key.Int(), funKey.Int()))
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parametersJsonSchema", key.Int(), funKey.Int()))
}
return true
})
return true
})
} else {
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
template, _ = sjson.Delete(template, "toolConfig")
}
return []byte(template)
}

View File

@@ -114,7 +114,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
@@ -245,7 +246,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
@@ -731,6 +733,11 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
// Skip built-in tools (web_search, code_execution, etc.) which have
// a "type" field and require their name to remain unchanged.
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
return true
}
name := tool.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true

View File

@@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
t.Fatalf("built-in tool name should not be prefixed: tools.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_my_custom_tool" {
t.Fatalf("custom tool should be prefixed: tools.1.name = %q, want %q", got, "proxy_my_custom_tool")
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")

View File

@@ -101,7 +101,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
return resp, err
}
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
@@ -213,7 +214,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err
}
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")

View File

@@ -129,7 +129,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
action := "generateContent"
if req.Metadata != nil {
@@ -278,7 +279,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
projectID := resolveGeminiProjectID(auth)

View File

@@ -126,7 +126,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := "generateContent"
@@ -228,7 +229,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
baseURL := resolveGeminiBaseURL(auth)

View File

@@ -325,7 +325,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
}
@@ -438,7 +439,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, false)
@@ -541,7 +543,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
@@ -664,7 +667,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)

View File

@@ -98,7 +98,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
}
body = preserveReasoningContentInMessages(body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -201,7 +202,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body)
}
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint

View File

@@ -90,7 +90,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
@@ -185,7 +186,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {

View File

@@ -5,6 +5,8 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -12,8 +14,9 @@ import (
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
// and restricts matches to the given protocol when supplied. Defaults are checked
// against the original payload when provided.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte) []byte {
// against the original payload when provided. requestedModel carries the client-visible
// model name before alias resolution so payload rules can target aliases precisely.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
if cfg == nil || len(payload) == 0 {
return payload
}
@@ -22,10 +25,11 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
return payload
}
model = strings.TrimSpace(model)
if model == "" {
requestedModel = strings.TrimSpace(requestedModel)
if model == "" && requestedModel == "" {
return payload
}
candidates := payloadModelCandidates(cfg, model, protocol)
candidates := payloadModelCandidates(model, requestedModel)
out := payload
source := original
if len(source) == 0 {
@@ -163,65 +167,42 @@ func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) b
return false
}
func payloadModelCandidates(cfg *config.Config, model, protocol string) []string {
func payloadModelCandidates(model, requestedModel string) []string {
model = strings.TrimSpace(model)
if model == "" {
requestedModel = strings.TrimSpace(requestedModel)
if model == "" && requestedModel == "" {
return nil
}
candidates := []string{model}
if cfg == nil {
return candidates
}
aliases := payloadModelAliases(cfg, model, protocol)
if len(aliases) == 0 {
return candidates
}
seen := map[string]struct{}{strings.ToLower(model): struct{}{}}
for _, alias := range aliases {
alias = strings.TrimSpace(alias)
if alias == "" {
continue
candidates := make([]string, 0, 3)
seen := make(map[string]struct{}, 3)
addCandidate := func(value string) {
value = strings.TrimSpace(value)
if value == "" {
return
}
key := strings.ToLower(alias)
key := strings.ToLower(value)
if _, ok := seen[key]; ok {
continue
return
}
seen[key] = struct{}{}
candidates = append(candidates, alias)
candidates = append(candidates, value)
}
if model != "" {
addCandidate(model)
}
if requestedModel != "" {
parsed := thinking.ParseSuffix(requestedModel)
base := strings.TrimSpace(parsed.ModelName)
if base != "" {
addCandidate(base)
}
if parsed.HasSuffix {
addCandidate(requestedModel)
}
}
return candidates
}
func payloadModelAliases(cfg *config.Config, model, protocol string) []string {
if cfg == nil {
return nil
}
model = strings.TrimSpace(model)
if model == "" {
return nil
}
channel := strings.ToLower(strings.TrimSpace(protocol))
if channel == "" {
return nil
}
entries := cfg.OAuthModelAlias[channel]
if len(entries) == 0 {
return nil
}
aliases := make([]string, 0, 2)
for _, entry := range entries {
if !strings.EqualFold(strings.TrimSpace(entry.Name), model) {
continue
}
alias := strings.TrimSpace(entry.Alias)
if alias == "" {
continue
}
aliases = append(aliases, alias)
}
return aliases
}
// buildPayloadPath combines an optional root path with a relative parameter path.
// When root is empty, the parameter path is used as-is. When root is non-empty,
// the parameter path is treated as relative to root.
@@ -258,6 +239,35 @@ func payloadRawValue(value any) ([]byte, bool) {
}
}
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
fallback = strings.TrimSpace(fallback)
if len(opts.Metadata) == 0 {
return fallback
}
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
if !ok || raw == nil {
return fallback
}
switch v := raw.(type) {
case string:
if strings.TrimSpace(v) == "" {
return fallback
}
return strings.TrimSpace(v)
case []byte:
if len(v) == 0 {
return fallback
}
trimmed := strings.TrimSpace(string(v))
if trimmed == "" {
return fallback
}
return trimmed
default:
return fallback
}
}
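A brief in-package sketch of the new plumbing, assuming Options.Metadata is a plain map[string]any and RequestedModelMetadataKey is an exported string constant; the alias name is hypothetical:

opts := cliproxyexecutor.Options{Metadata: map[string]any{
	cliproxyexecutor.RequestedModelMetadataKey: "sonnet-alias",
}}
_ = payloadRequestedModel(opts, "claude-sonnet-4-5")                       // "sonnet-alias" (metadata wins)
_ = payloadRequestedModel(cliproxyexecutor.Options{}, "claude-sonnet-4-5") // falls back to the resolved model
// payloadModelCandidates("claude-sonnet-4-5", "sonnet-alias") then yields both names,
// so a payload rule can target the client-visible alias directly.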
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
// Examples:
//

View File

@@ -91,7 +91,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -184,7 +185,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
}
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))

View File

@@ -7,8 +7,6 @@ package claude
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
@@ -19,37 +17,6 @@ import (
"github.com/tidwall/sjson"
)
// deriveSessionID generates a stable session ID from the request.
// Uses the hash of the first user message to identify the conversation.
func deriveSessionID(rawJSON []byte) string {
userIDResult := gjson.GetBytes(rawJSON, "metadata.user_id")
if userIDResult.Exists() {
userID := userIDResult.String()
idx := strings.Index(userID, "session_")
if idx != -1 {
return userID[idx+8:]
}
}
messages := gjson.GetBytes(rawJSON, "messages")
if !messages.IsArray() {
return ""
}
for _, msg := range messages.Array() {
if msg.Get("role").String() == "user" {
content := msg.Get("content").String()
if content == "" {
// Try to get text from content array
content = msg.Get("content.0.text").String()
}
if content != "" {
h := sha256.Sum256([]byte(content))
return hex.EncodeToString(h[:16])
}
}
}
return ""
}
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
@@ -72,9 +39,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
enableThoughtTranslate := true
rawJSON := bytes.Clone(inputRawJSON)
// Derive session ID for signature caching
sessionID := deriveSessionID(rawJSON)
// system instruction
systemInstructionJSON := ""
hasSystemInstruction := false
@@ -137,8 +101,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if sessionID != "" && thinkingText != "" {
if cachedSig := cache.GetCachedSignature(modelName, sessionID, thinkingText); cachedSig != "" {
if thinkingText != "" {
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
signature = cachedSig
// log.Debugf("Using cached signature for thinking block")
}
@@ -156,19 +120,19 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
}
}
if cache.HasValidSignature(clientSignature) {
if cache.HasValidSignature(modelName, clientSignature) {
signature = clientSignature
}
// log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(signature) {
if cache.HasValidSignature(modelName, signature) {
currentMessageThinkingSignature = signature
}
// Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(signature)
isUnsigned := !cache.HasValidSignature(modelName, signature)
// If unsigned, skip entirely (don't convert to text)
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
@@ -223,7 +187,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// This is the approach used in opencode-google-antigravity-auth for Gemini
// and also works for Claude through Antigravity API
const skipSentinel = "skip_thought_signature_validator"
if cache.HasValidSignature(currentMessageThinkingSignature) {
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
} else {
// No valid signature - use skip sentinel to bypass validation

View File

@@ -74,13 +74,13 @@ func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
}
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
cache.ClearSignatureCache("")
// Valid signature must be at least 50 characters
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think..."
// Pre-cache the signature (simulating a response from the same session)
// The session ID is derived from the first user message hash
// Since there's no user message in this test, we need to add one
// Pre-cache the signature (simulating a previous response for the same thinking text)
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
@@ -98,10 +98,7 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
]
}`)
// Derive session ID and cache the signature
sessionID := deriveSessionID(inputJSON)
cache.CacheSignature(sessionID, thinkingText, validSignature)
defer cache.ClearSignatureCache(sessionID)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
@@ -120,6 +117,8 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
}
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
cache.ClearSignatureCache("")
// Unsigned thinking blocks should be removed entirely (not converted to text)
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
@@ -241,6 +240,8 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
}
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
cache.ClearSignatureCache("")
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think..."
@@ -266,10 +267,7 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
]
}`)
// Derive session ID and cache the signature
sessionID := deriveSessionID(inputJSON)
cache.CacheSignature(sessionID, thinkingText, validSignature)
defer cache.ClearSignatureCache(sessionID)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
@@ -285,6 +283,8 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
}
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
cache.ClearSignatureCache("")
// Case: text block followed by thinking block -> should be reordered to thinking first
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Planning..."
@@ -306,10 +306,7 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
]
}`)
// Derive session ID and cache the signature
sessionID := deriveSessionID(inputJSON)
cache.CacheSignature(sessionID, thinkingText, validSignature)
defer cache.ClearSignatureCache(sessionID)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
@@ -496,6 +493,8 @@ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *t
}
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
cache.ClearSignatureCache("")
// Last assistant message ends with signed thinking block - should be kept
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Valid thinking..."
@@ -517,10 +516,7 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin
]
}`)
// Derive session ID and cache the signature
sessionID := deriveSessionID(inputJSON)
cache.CacheSignature(sessionID, thinkingText, validSignature)
defer cache.ClearSignatureCache(sessionID)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)

View File

@@ -41,7 +41,6 @@ type Params struct {
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
// Signature caching support
SessionID string // Session ID derived from request for signature caching
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
}
@@ -70,7 +69,6 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
HasFirstResponse: false,
ResponseType: 0,
ResponseIndex: 0,
SessionID: deriveSessionID(originalRequestRawJSON),
}
}
modelName := gjson.GetBytes(requestRawJSON, "model").String()
@@ -139,9 +137,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
// log.Debug("Branch: signature_delta")
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(modelName, params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
if params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
// log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset()
}

View File

@@ -12,10 +12,10 @@ import (
// Signature Caching Tests
// ============================================================================
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
func TestConvertAntigravityResponseToClaude_ParamsInitialized(t *testing.T) {
cache.ClearSignatureCache("")
// Request with user message - should derive session ID
// Request with user message - should initialize params
requestJSON := []byte(`{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
@@ -37,10 +37,12 @@ func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
ctx := context.Background()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, &param)
// Verify session ID was set
params := param.(*Params)
if params.SessionID == "" {
t.Error("SessionID should be derived from request")
if !params.HasFirstResponse {
t.Error("HasFirstResponse should be set after first chunk")
}
if params.CurrentThinkingText.Len() == 0 {
t.Error("Thinking text should be accumulated")
}
}
@@ -97,6 +99,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
}`)
@@ -129,12 +132,8 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
// Process thinking chunk
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, &param)
params := param.(*Params)
sessionID := params.SessionID
thinkingText := params.CurrentThinkingText.String()
if sessionID == "" {
t.Fatal("SessionID should be set")
}
if thinkingText == "" {
t.Fatal("Thinking text should be accumulated")
}
@@ -143,7 +142,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, &param)
// Verify signature was cached
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
cachedSig := cache.GetCachedSignature("claude-sonnet-4-5-thinking", thinkingText)
if cachedSig != validSignature {
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
}
@@ -158,6 +157,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
}`)
@@ -221,13 +221,12 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
// Process first thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, &param)
params := param.(*Params)
sessionID := params.SessionID
firstThinkingText := params.CurrentThinkingText.String()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, &param)
// Verify first signature cached
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", firstThinkingText) != validSig1 {
t.Error("First thinking block signature should be cached")
}
@@ -241,76 +240,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, &param)
// Verify second signature cached
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", secondThinkingText) != validSig2 {
t.Error("Second thinking block signature should be cached")
}
}
func TestDeriveSessionIDFromRequest(t *testing.T) {
tests := []struct {
name string
input []byte
wantEmpty bool
}{
{
name: "valid user message",
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
wantEmpty: false,
},
{
name: "user message with content array",
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
wantEmpty: false,
},
{
name: "no user message",
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
wantEmpty: true,
},
{
name: "empty messages",
input: []byte(`{"messages": []}`),
wantEmpty: true,
},
{
name: "no messages field",
input: []byte(`{}`),
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := deriveSessionID(tt.input)
if tt.wantEmpty && result != "" {
t.Errorf("Expected empty session ID, got '%s'", result)
}
if !tt.wantEmpty && result == "" {
t.Error("Expected non-empty session ID")
}
})
}
}
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
id1 := deriveSessionID(input)
id2 := deriveSessionID(input)
if id1 != id2 {
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
}
}
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
id1 := deriveSessionID(input1)
id2 := deriveSessionID(input2)
if id1 == id2 {
t.Error("Different messages should produce different session IDs")
}
}

View File

@@ -8,6 +8,7 @@ package gemini
import (
"bytes"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -32,12 +33,12 @@ import (
//
// Returns:
// - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []byte {
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
template := ""
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template)
@@ -97,37 +98,40 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
}
}
// Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
const skipSentinel = "skip_thought_signature_validator"
// Gemini-specific handling for non-Claude models:
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
if !strings.Contains(modelName, "claude") {
const skipSentinel = "skip_thought_signature_validator"
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" {
// First pass: collect indices of thinking parts to remove
var thinkingIndicesToRemove []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Mark thinking blocks for removal
if part.Get("thought").Bool() {
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
}
// Add skip sentinel to functionCall parts
if part.Get("functionCall").Exists() {
existingSig := part.Get("thoughtSignature").String()
if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" {
// First pass: collect indices of thinking parts to mark with skip sentinel
var thinkingIndicesToSkipSignature []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Collect indices of thinking blocks to mark with skip sentinel
if part.Get("thought").Bool() {
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
}
}
return true
})
// Add skip sentinel to functionCall parts
if part.Get("functionCall").Exists() {
existingSig := part.Get("thoughtSignature").String()
if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
}
}
return true
})
// Remove thinking blocks in reverse order to preserve indices
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
idx := thinkingIndicesToRemove[i]
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
// Annotate the collected thinking blocks with the skip_thought_signature_validator sentinel
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
idx := thinkingIndicesToSkipSignature[i]
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
}
}
}
return true
})
return true
})
}
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
}
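An illustrative counterpart to the RemoveThinkingBlocks test removed below, mirroring the annotate-instead-of-delete behavior introduced above (a sketch only; not verified against the parts of the converter elided from this diff):

inputJSON := []byte(`{
	"model": "gemini-3-pro-preview",
	"contents": [
		{"role": "model", "parts": [
			{"thought": true, "text": "Thinking..."},
			{"text": "Here is my response"}
		]}
	]
}`)
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
parts := gjson.Get(string(output), "request.contents.0.parts").Array()
// Both parts survive; the thinking part is annotated rather than dropped:
// len(parts) == 2 && parts[0].Get("thoughtSignature").String() == "skip_thought_signature_validator"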

View File

@@ -62,40 +62,6 @@ func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *test
}
}
func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) {
// Thinking blocks should be removed entirely for Gemini
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(fmt.Sprintf(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"thought": true, "text": "Thinking...", "thoughtSignature": "%s"},
{"text": "Here is my response"}
]
}
]
}`, validSignature))
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
// Check that thinking block is removed
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
}
// Only text part should remain
if parts[0].Get("thought").Bool() {
t.Error("Thinking block should be removed for Gemini")
}
if parts[0].Get("text").String() != "Here is my response" {
t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String())
}
}
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
// Multiple functionCalls should all get skip_thought_signature_validator
inputJSON := []byte(`{

View File

@@ -305,12 +305,12 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
}
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
toolNode := []byte(`{}`)
hasTool := false
functionToolNode := []byte(`{}`)
hasFunction := false
googleSearchNodes := make([][]byte, 0)
for _, t := range tools.Array() {
if t.Get("type").String() == "function" {
fn := t.Get("function")
@@ -349,31 +349,37 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
}
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue
}
toolNode = tmp
functionToolNode = tmp
hasFunction = true
hasTool = true
}
}
if gs := t.Get("google_search"); gs.Exists() {
googleToolNode := []byte(`{}`)
var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue
}
hasTool = true
googleSearchNodes = append(googleSearchNodes, googleToolNode)
}
}
if hasTool {
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
if hasFunction || len(googleSearchNodes) > 0 {
toolsNode := []byte("[]")
if hasFunction {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
}
for _, googleNode := range googleSearchNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
}
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
}
}
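A rough shape check for the new tool layout (test-style sketch with a hypothetical get_weather function; it assumes the declaration-building code elided from this hunk keeps the function entry intact):

input := []byte(`{"model":"gemini-3-pro-preview","tools":[
	{"type":"function","function":{"name":"get_weather","parameters":{"type":"object","properties":{}}}},
	{"google_search":{}}
]}`)
out := ConvertOpenAIRequestToAntigravity("gemini-3-pro-preview", input, false)
// functionDeclarations and googleSearch now live in separate request.tools entries:
// gjson.GetBytes(out, "request.tools.0.functionDeclarations").IsArray() == true
// gjson.GetBytes(out, "request.tools.1.googleSearch").Exists()          == true

The same split is applied to the OpenAI-to-Gemini CLI and OpenAI-to-Gemini translators further down.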

View File

@@ -98,9 +98,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
// Temperature setting for controlling response randomness
if temp := genConfig.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float())
}
// Top P setting for nucleus sampling
if topP := genConfig.Get("topP"); topP.Exists() {
} else if topP := genConfig.Get("topP"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float())
}
// Stop sequences configuration for custom termination conditions

View File

@@ -110,10 +110,8 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
// Temperature setting for controlling response randomness
if temp := root.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float())
}
// Top P setting for nucleus sampling
if topP := root.Get("top_p"); topP.Exists() {
} else if topP := root.Get("top_p"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float())
}
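The same temperature-over-top_p precedence is applied in both the Gemini-to-Claude and OpenAI-to-Claude converters; a rough check (the converter's []byte return and bool stream parameter are assumed from the sibling translators):

in := []byte(`{"model":"claude-sonnet-4-5","messages":[],"temperature":0.2,"top_p":0.9}`)
out := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", in, false)
// gjson.GetBytes(out, "temperature").Exists() == true
// gjson.GetBytes(out, "top_p").Exists()       == false  (dropped because temperature was provided)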

View File

@@ -283,12 +283,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
}
}
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
toolNode := []byte(`{}`)
hasTool := false
functionToolNode := []byte(`{}`)
hasFunction := false
googleSearchNodes := make([][]byte, 0)
for _, t := range tools.Array() {
if t.Get("type").String() == "function" {
fn := t.Get("function")
@@ -327,31 +327,37 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
}
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue
}
toolNode = tmp
functionToolNode = tmp
hasFunction = true
hasTool = true
}
}
if gs := t.Get("google_search"); gs.Exists() {
googleToolNode := []byte(`{}`)
var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue
}
hasTool = true
googleSearchNodes = append(googleSearchNodes, googleToolNode)
}
}
if hasTool {
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
if hasFunction || len(googleSearchNodes) > 0 {
toolsNode := []byte("[]")
if hasFunction {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
}
for _, googleNode := range googleSearchNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
}
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
}
}

View File

@@ -289,12 +289,12 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
}
// tools -> tools[0].functionDeclarations + tools[0].googleSearch passthrough
// tools -> tools[].functionDeclarations + tools[].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
toolNode := []byte(`{}`)
hasTool := false
functionToolNode := []byte(`{}`)
hasFunction := false
googleSearchNodes := make([][]byte, 0)
for _, t := range tools.Array() {
if t.Get("type").String() == "function" {
fn := t.Get("function")
@@ -333,31 +333,37 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
}
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue
}
toolNode = tmp
functionToolNode = tmp
hasFunction = true
hasTool = true
}
}
if gs := t.Get("google_search"); gs.Exists() {
googleToolNode := []byte(`{}`)
var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue
}
hasTool = true
googleSearchNodes = append(googleSearchNodes, googleToolNode)
}
}
if hasTool {
out, _ = sjson.SetRawBytes(out, "tools", []byte("[]"))
out, _ = sjson.SetRawBytes(out, "tools.0", toolNode)
if hasFunction || len(googleSearchNodes) > 0 {
toolsNode := []byte("[]")
if hasFunction {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
}
for _, googleNode := range googleSearchNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
}
out, _ = sjson.SetRawBytes(out, "tools", toolsNode)
}
}

View File

@@ -298,6 +298,15 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
}
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
case "reasoning":
thoughtContent := `{"role":"model","parts":[]}`
thought := `{"text":"","thoughtSignature":"","thought":true}`
thought, _ = sjson.Set(thought, "text", item.Get("summary.0.text").String())
thought, _ = sjson.Set(thought, "thoughtSignature", item.Get("encrypted_content").String())
thoughtContent, _ = sjson.SetRaw(thoughtContent, "parts.-1", thought)
out, _ = sjson.SetRaw(out, "contents.-1", thoughtContent)
}
}
} else if input.Exists() && input.Type == gjson.String {
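The new "reasoning" case maps an OpenAI Responses reasoning item back into a Gemini thought part, carrying the first summary text and the encrypted_content signature. A minimal standalone sketch of the same mapping on a sample item (field values are illustrative):

package main

import (
	"fmt"

	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

func main() {
	item := gjson.Parse(`{"type":"reasoning","summary":[{"type":"summary_text","text":"plan the edit"}],"encrypted_content":"sig-123"}`)

	thought := `{"text":"","thoughtSignature":"","thought":true}`
	thought, _ = sjson.Set(thought, "text", item.Get("summary.0.text").String())
	thought, _ = sjson.Set(thought, "thoughtSignature", item.Get("encrypted_content").String())

	content, _ := sjson.SetRaw(`{"role":"model","parts":[]}`, "parts.-1", thought)
	fmt.Println(content)
	// {"role":"model","parts":[{"text":"plan the edit","thoughtSignature":"sig-123","thought":true}]}
}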

View File

@@ -20,6 +20,7 @@ type geminiToResponsesState struct {
// message aggregation
MsgOpened bool
MsgClosed bool
MsgIndex int
CurrentMsgID string
TextBuf strings.Builder
@@ -29,6 +30,7 @@ type geminiToResponsesState struct {
ReasoningOpened bool
ReasoningIndex int
ReasoningItemID string
ReasoningEnc string
ReasoningBuf strings.Builder
ReasoningClosed bool
@@ -37,6 +39,7 @@ type geminiToResponsesState struct {
FuncArgsBuf map[int]*strings.Builder
FuncNames map[int]string
FuncCallIDs map[int]string
FuncDone map[int]bool
}
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
@@ -45,6 +48,39 @@ var responseIDCounter uint64
// funcCallIDCounter provides a process-wide unique counter for function call identifiers.
var funcCallIDCounter uint64
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
if len(originalRequestRawJSON) > 0 && gjson.ValidBytes(originalRequestRawJSON) {
return originalRequestRawJSON
}
if len(requestRawJSON) > 0 && gjson.ValidBytes(requestRawJSON) {
return requestRawJSON
}
return nil
}
func unwrapRequestRoot(root gjson.Result) gjson.Result {
req := root.Get("request")
if !req.Exists() {
return root
}
if req.Get("model").Exists() || req.Get("input").Exists() || req.Get("instructions").Exists() {
return req
}
return root
}
func unwrapGeminiResponseRoot(root gjson.Result) gjson.Result {
resp := root.Get("response")
if !resp.Exists() {
return root
}
// Vertex-style Gemini responses wrap the actual payload in a "response" object.
if resp.Get("candidates").Exists() || resp.Get("responseId").Exists() || resp.Get("usageMetadata").Exists() {
return resp
}
return root
}
func emitEvent(event string, payload string) string {
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
}
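emitEvent frames every converter payload as a named SSE event. A standalone sketch of the wire format it produces for a single text delta (the event name matches this file's hunks; the payload values are made up):

package main

import "fmt"

func emitEvent(event, payload string) string {
	return fmt.Sprintf("event: %s\ndata: %s", event, payload)
}

func main() {
	fmt.Println(emitEvent("response.output_text.delta",
		`{"type":"response.output_text.delta","sequence_number":3,"item_id":"msg_resp_1_1","output_index":1,"content_index":0,"delta":"hi","logprobs":[]}`))
	// event: response.output_text.delta
	// data: {"type":"response.output_text.delta","sequence_number":3,...,"delta":"hi","logprobs":[]}
}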
@@ -56,18 +92,37 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
FuncArgsBuf: make(map[int]*strings.Builder),
FuncNames: make(map[int]string),
FuncCallIDs: make(map[int]string),
FuncDone: make(map[int]bool),
}
}
st := (*param).(*geminiToResponsesState)
if st.FuncArgsBuf == nil {
st.FuncArgsBuf = make(map[int]*strings.Builder)
}
if st.FuncNames == nil {
st.FuncNames = make(map[int]string)
}
if st.FuncCallIDs == nil {
st.FuncCallIDs = make(map[int]string)
}
if st.FuncDone == nil {
st.FuncDone = make(map[int]bool)
}
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
}
rawJSON = bytes.TrimSpace(rawJSON)
if len(rawJSON) == 0 || bytes.Equal(rawJSON, []byte("[DONE]")) {
return []string{}
}
root := gjson.ParseBytes(rawJSON)
if !root.Exists() {
return []string{}
}
root = unwrapGeminiResponseRoot(root)
var out []string
nextSeq := func() int { st.Seq++; return st.Seq }
@@ -98,19 +153,54 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.Set(itemDone, "item.id", st.ReasoningItemID)
itemDone, _ = sjson.Set(itemDone, "output_index", st.ReasoningIndex)
itemDone, _ = sjson.Set(itemDone, "item.encrypted_content", st.ReasoningEnc)
itemDone, _ = sjson.Set(itemDone, "item.summary.0.text", full)
out = append(out, emitEvent("response.output_item.done", itemDone))
st.ReasoningClosed = true
}
// Helper to finalize the assistant message in the correct order.
// It emits response.output_text.done, response.content_part.done,
// and response.output_item.done exactly once.
finalizeMessage := func() {
if !st.MsgOpened || st.MsgClosed {
return
}
fullText := st.ItemTextBuf.String()
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
done, _ = sjson.Set(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
partDone, _ = sjson.Set(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
final, _ = sjson.Set(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
final, _ = sjson.Set(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final))
st.MsgClosed = true
}
// Initialize per-response fields and emit created/in_progress once
if !st.Started {
if v := root.Get("responseId"); v.Exists() {
st.ResponseID = v.String()
st.ResponseID = root.Get("responseId").String()
if st.ResponseID == "" {
st.ResponseID = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
}
if !strings.HasPrefix(st.ResponseID, "resp_") {
st.ResponseID = fmt.Sprintf("resp_%s", st.ResponseID)
}
if v := root.Get("createTime"); v.Exists() {
if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil {
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
st.CreatedAt = t.Unix()
}
}
@@ -143,15 +233,21 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// Ignore any late thought chunks after reasoning is finalized.
return true
}
if sig := part.Get("thoughtSignature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
st.ReasoningEnc = sig.String()
} else if sig = part.Get("thought_signature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
st.ReasoningEnc = sig.String()
}
if !st.ReasoningOpened {
st.ReasoningOpened = true
st.ReasoningIndex = st.NextIndex
st.NextIndex++
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, st.ReasoningIndex)
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","encrypted_content":"","summary":[]}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", st.ReasoningIndex)
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
item, _ = sjson.Set(item, "item.encrypted_content", st.ReasoningEnc)
out = append(out, emitEvent("response.output_item.added", item))
partAdded := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
partAdded, _ = sjson.Set(partAdded, "sequence_number", nextSeq())
@@ -191,9 +287,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
out = append(out, emitEvent("response.content_part.added", partAdded))
st.ItemTextBuf.Reset()
st.ItemTextBuf.WriteString(t.String())
}
st.TextBuf.WriteString(t.String())
st.ItemTextBuf.WriteString(t.String())
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
@@ -205,8 +301,10 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// Function call
if fc := part.Get("functionCall"); fc.Exists() {
// Before emitting function-call outputs, finalize reasoning if open.
// Before emitting function-call outputs, finalize reasoning and the message (if open).
// Responses streaming requires message done events before the next output_item.added.
finalizeReasoning()
finalizeMessage()
name := fc.Get("name").String()
idx := st.NextIndex
st.NextIndex++
@@ -219,6 +317,14 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
}
st.FuncNames[idx] = name
argsJSON := "{}"
if args := fc.Get("args"); args.Exists() {
argsJSON = args.Raw
}
if st.FuncArgsBuf[idx].Len() == 0 && argsJSON != "" {
st.FuncArgsBuf[idx].WriteString(argsJSON)
}
// Emit item.added for function call
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
@@ -228,10 +334,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
item, _ = sjson.Set(item, "item.name", name)
out = append(out, emitEvent("response.output_item.added", item))
// Emit arguments delta (full args in one chunk)
if args := fc.Get("args"); args.Exists() {
argsJSON := args.Raw
st.FuncArgsBuf[idx].WriteString(argsJSON)
// Emit arguments delta (full args in one chunk).
// When Gemini omits args, emit "{}" to keep Responses streaming event order consistent.
if argsJSON != "" {
ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
ad, _ = sjson.Set(ad, "sequence_number", nextSeq())
ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
@@ -240,6 +345,27 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
out = append(out, emitEvent("response.function_call_arguments.delta", ad))
}
// Gemini emits the full function call payload at once, so we can finalize it immediately.
if !st.FuncDone[idx] {
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
fcDone, _ = sjson.Set(fcDone, "output_index", idx)
fcDone, _ = sjson.Set(fcDone, "arguments", argsJSON)
out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
itemDone, _ = sjson.Set(itemDone, "item.arguments", argsJSON)
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
out = append(out, emitEvent("response.output_item.done", itemDone))
st.FuncDone[idx] = true
}
return true
}
@@ -251,28 +377,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
if fr := root.Get("candidates.0.finishReason"); fr.Exists() && fr.String() != "" {
// Finalize reasoning first to keep ordering tight with last delta
finalizeReasoning()
// Close message output if opened
if st.MsgOpened {
fullText := st.ItemTextBuf.String()
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
done, _ = sjson.Set(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
partDone, _ = sjson.Set(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
final, _ = sjson.Set(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
final, _ = sjson.Set(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final))
}
finalizeMessage()
// Close function calls
if len(st.FuncArgsBuf) > 0 {
@@ -289,6 +394,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
}
}
for _, idx := range idxs {
if st.FuncDone[idx] {
continue
}
args := "{}"
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
args = b.String()
@@ -308,6 +416,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
out = append(out, emitEvent("response.output_item.done", itemDone))
st.FuncDone[idx] = true
}
}
@@ -319,8 +429,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.Set(completed, "response.instructions", v.String())
}
@@ -383,41 +493,34 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
}
}
// Compose outputs in encountered order: reasoning, message, function_calls
// Compose outputs in output_index order.
outputsWrapper := `{"arr":[]}`
if st.ReasoningOpened {
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
if st.MsgOpened {
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
if len(st.FuncArgsBuf) > 0 {
idxs := make([]int, 0, len(st.FuncArgsBuf))
for idx := range st.FuncArgsBuf {
idxs = append(idxs, idx)
for idx := 0; idx < st.NextIndex; idx++ {
if st.ReasoningOpened && idx == st.ReasoningIndex {
item := `{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
item, _ = sjson.Set(item, "encrypted_content", st.ReasoningEnc)
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
continue
}
for i := 0; i < len(idxs); i++ {
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
if st.MsgOpened && idx == st.MsgIndex {
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
continue
}
for _, idx := range idxs {
args := ""
if b := st.FuncArgsBuf[idx]; b != nil {
if callID, ok := st.FuncCallIDs[idx]; ok && callID != "" {
args := "{}"
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
args = b.String()
}
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", st.FuncCallIDs[idx])
item, _ = sjson.Set(item, "call_id", callID)
item, _ = sjson.Set(item, "name", st.FuncNames[idx])
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
@@ -431,8 +534,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
// cached_tokens not provided by Gemini; default to 0 for structure compatibility
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
// cached token details: align with OpenAI "cached_tokens" semantics.
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
// output tokens
if v := um.Get("candidatesTokenCount"); v.Exists() {
completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int())
@@ -460,6 +563,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// ConvertGeminiResponseToOpenAIResponsesNonStream aggregates Gemini response JSON into a single OpenAI Responses JSON object.
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
root := gjson.ParseBytes(rawJSON)
root = unwrapGeminiResponseRoot(root)
// Base response scaffold
resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`
@@ -478,15 +582,15 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// created_at: map from createTime if available
createdAt := time.Now().Unix()
if v := root.Get("createTime"); v.Exists() {
if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil {
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
createdAt = t.Unix()
}
}
resp, _ = sjson.Set(resp, "created_at", createdAt)
// Echo request fields when present; fallback model from response modelVersion
if len(requestRawJSON) > 0 {
req := gjson.ParseBytes(requestRawJSON)
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
if v := req.Get("instructions"); v.Exists() {
resp, _ = sjson.Set(resp, "instructions", v.String())
}
@@ -636,8 +740,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
// cached_tokens not provided by Gemini; default to 0 for structure compatibility
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", 0)
// cached token details: align with OpenAI "cached_tokens" semantics.
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
// output tokens
if v := um.Get("candidatesTokenCount"); v.Exists() {
resp, _ = sjson.Set(resp, "usage.output_tokens", v.Int())
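Usage mapping in both converters: input_tokens is promptTokenCount plus thoughtsTokenCount, and cached_tokens now comes from Gemini's cachedContentTokenCount rather than defaulting to 0. A standalone sketch with made-up counts:

package main

import (
	"fmt"

	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

func main() {
	um := gjson.Parse(`{"promptTokenCount":10,"thoughtsTokenCount":4,"candidatesTokenCount":5,"cachedContentTokenCount":2}`)

	resp := `{"usage":{}}`
	resp, _ = sjson.Set(resp, "usage.input_tokens", um.Get("promptTokenCount").Int()+um.Get("thoughtsTokenCount").Int())
	resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
	resp, _ = sjson.Set(resp, "usage.output_tokens", um.Get("candidatesTokenCount").Int())

	fmt.Println(gjson.Get(resp, "usage.input_tokens").Int(),
		gjson.Get(resp, "usage.input_tokens_details.cached_tokens").Int(),
		gjson.Get(resp, "usage.output_tokens").Int()) // 14 2 5
}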

View File

@@ -0,0 +1,353 @@
package responses
import (
"context"
"strings"
"testing"
"github.com/tidwall/gjson"
)
func parseSSEEvent(t *testing.T, chunk string) (string, gjson.Result) {
t.Helper()
lines := strings.Split(chunk, "\n")
if len(lines) < 2 {
t.Fatalf("unexpected SSE chunk: %q", chunk)
}
event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
if !gjson.Valid(dataLine) {
t.Fatalf("invalid SSE data JSON: %q", dataLine)
}
return event, gjson.Parse(dataLine)
}
func TestConvertGeminiResponseToOpenAIResponses_UnwrapAndAggregateText(t *testing.T) {
// Vertex-style Gemini stream wraps the actual response payload under "response".
// This test ensures we unwrap and that output_text.done contains the full text.
in := []string{
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"让"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"我先"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"了解"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"mcp__serena__list_dir","args":{"recursive":false,"relative_path":"internal"},"id":"toolu_1"}}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":2},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
}
originalReq := []byte(`{"instructions":"test instructions","model":"gpt-5","max_output_tokens":123}`)
var param any
var out []string
for _, line := range in {
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", originalReq, nil, []byte(line), &param)...)
}
var (
gotTextDone bool
gotMessageDone bool
gotResponseDone bool
gotFuncDone bool
textDone string
messageText string
responseID string
instructions string
cachedTokens int64
funcName string
funcArgs string
posTextDone = -1
posPartDone = -1
posMessageDone = -1
posFuncAdded = -1
)
for i, chunk := range out {
ev, data := parseSSEEvent(t, chunk)
switch ev {
case "response.output_text.done":
gotTextDone = true
if posTextDone == -1 {
posTextDone = i
}
textDone = data.Get("text").String()
case "response.content_part.done":
if posPartDone == -1 {
posPartDone = i
}
case "response.output_item.done":
switch data.Get("item.type").String() {
case "message":
gotMessageDone = true
if posMessageDone == -1 {
posMessageDone = i
}
messageText = data.Get("item.content.0.text").String()
case "function_call":
gotFuncDone = true
funcName = data.Get("item.name").String()
funcArgs = data.Get("item.arguments").String()
}
case "response.output_item.added":
if data.Get("item.type").String() == "function_call" && posFuncAdded == -1 {
posFuncAdded = i
}
case "response.completed":
gotResponseDone = true
responseID = data.Get("response.id").String()
instructions = data.Get("response.instructions").String()
cachedTokens = data.Get("response.usage.input_tokens_details.cached_tokens").Int()
}
}
if !gotTextDone {
t.Fatalf("missing response.output_text.done event")
}
if posTextDone == -1 || posPartDone == -1 || posMessageDone == -1 || posFuncAdded == -1 {
t.Fatalf("missing ordering events: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
}
if !(posTextDone < posPartDone && posPartDone < posMessageDone && posMessageDone < posFuncAdded) {
t.Fatalf("unexpected message/function ordering: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
}
if !gotMessageDone {
t.Fatalf("missing message response.output_item.done event")
}
if !gotFuncDone {
t.Fatalf("missing function_call response.output_item.done event")
}
if !gotResponseDone {
t.Fatalf("missing response.completed event")
}
if textDone != "让我先了解" {
t.Fatalf("unexpected output_text.done text: got %q", textDone)
}
if messageText != "让我先了解" {
t.Fatalf("unexpected message done text: got %q", messageText)
}
if responseID != "resp_req_vrtx_1" {
t.Fatalf("unexpected response id: got %q", responseID)
}
if instructions != "test instructions" {
t.Fatalf("unexpected instructions echo: got %q", instructions)
}
if cachedTokens != 2 {
t.Fatalf("unexpected cached token count: got %d", cachedTokens)
}
if funcName != "mcp__serena__list_dir" {
t.Fatalf("unexpected function name: got %q", funcName)
}
if !gjson.Valid(funcArgs) {
t.Fatalf("invalid function arguments JSON: %q", funcArgs)
}
if gjson.Get(funcArgs, "recursive").Bool() != false {
t.Fatalf("unexpected recursive arg: %v", gjson.Get(funcArgs, "recursive").Value())
}
if gjson.Get(funcArgs, "relative_path").String() != "internal" {
t.Fatalf("unexpected relative_path arg: %q", gjson.Get(funcArgs, "relative_path").String())
}
}
func TestConvertGeminiResponseToOpenAIResponses_ReasoningEncryptedContent(t *testing.T) {
sig := "RXE0RENrZ0lDeEFDR0FJcVFOZDdjUzlleGFuRktRdFcvSzNyZ2MvWDNCcDQ4RmxSbGxOWUlOVU5kR1l1UHMrMGdkMVp0Vkg3ekdKU0g4YVljc2JjN3lNK0FrdGpTNUdqamI4T3Z0VVNETzdQd3pmcFhUOGl3U3hXUEJvTVFRQ09mWTFyMEtTWGZxUUlJakFqdmFGWk83RW1XRlBKckJVOVpkYzdDKw=="
in := []string{
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"thoughtSignature":"` + sig + `","text":""}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"a"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hello"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
}
var param any
var out []string
for _, line := range in {
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), &param)...)
}
var (
addedEnc string
doneEnc string
)
for _, chunk := range out {
ev, data := parseSSEEvent(t, chunk)
switch ev {
case "response.output_item.added":
if data.Get("item.type").String() == "reasoning" {
addedEnc = data.Get("item.encrypted_content").String()
}
case "response.output_item.done":
if data.Get("item.type").String() == "reasoning" {
doneEnc = data.Get("item.encrypted_content").String()
}
}
}
if addedEnc != sig {
t.Fatalf("unexpected encrypted_content in response.output_item.added: got %q", addedEnc)
}
if doneEnc != sig {
t.Fatalf("unexpected encrypted_content in response.output_item.done: got %q", doneEnc)
}
}
func TestConvertGeminiResponseToOpenAIResponses_FunctionCallEventOrder(t *testing.T) {
in := []string{
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool1"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool2","args":{"a":1}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
}
var param any
var out []string
for _, line := range in {
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), &param)...)
}
posAdded := []int{-1, -1, -1}
posArgsDelta := []int{-1, -1, -1}
posArgsDone := []int{-1, -1, -1}
posItemDone := []int{-1, -1, -1}
posCompleted := -1
deltaByIndex := map[int]string{}
for i, chunk := range out {
ev, data := parseSSEEvent(t, chunk)
switch ev {
case "response.output_item.added":
if data.Get("item.type").String() != "function_call" {
continue
}
idx := int(data.Get("output_index").Int())
if idx >= 0 && idx < len(posAdded) {
posAdded[idx] = i
}
case "response.function_call_arguments.delta":
idx := int(data.Get("output_index").Int())
if idx >= 0 && idx < len(posArgsDelta) {
posArgsDelta[idx] = i
deltaByIndex[idx] = data.Get("delta").String()
}
case "response.function_call_arguments.done":
idx := int(data.Get("output_index").Int())
if idx >= 0 && idx < len(posArgsDone) {
posArgsDone[idx] = i
}
case "response.output_item.done":
if data.Get("item.type").String() != "function_call" {
continue
}
idx := int(data.Get("output_index").Int())
if idx >= 0 && idx < len(posItemDone) {
posItemDone[idx] = i
}
case "response.completed":
posCompleted = i
output := data.Get("response.output")
if !output.Exists() || !output.IsArray() {
t.Fatalf("missing response.output in response.completed")
}
if len(output.Array()) != 3 {
t.Fatalf("unexpected response.output length: got %d", len(output.Array()))
}
if data.Get("response.output.0.name").String() != "tool0" || data.Get("response.output.0.arguments").String() != "{}" {
t.Fatalf("unexpected output[0]: %s", data.Get("response.output.0").Raw)
}
if data.Get("response.output.1.name").String() != "tool1" || data.Get("response.output.1.arguments").String() != "{}" {
t.Fatalf("unexpected output[1]: %s", data.Get("response.output.1").Raw)
}
if data.Get("response.output.2.name").String() != "tool2" {
t.Fatalf("unexpected output[2] name: %s", data.Get("response.output.2").Raw)
}
if !gjson.Valid(data.Get("response.output.2.arguments").String()) {
t.Fatalf("unexpected output[2] arguments: %q", data.Get("response.output.2.arguments").String())
}
}
}
if posCompleted == -1 {
t.Fatalf("missing response.completed event")
}
for idx := 0; idx < 3; idx++ {
if posAdded[idx] == -1 || posArgsDelta[idx] == -1 || posArgsDone[idx] == -1 || posItemDone[idx] == -1 {
t.Fatalf("missing function call events for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
}
if !(posAdded[idx] < posArgsDelta[idx] && posArgsDelta[idx] < posArgsDone[idx] && posArgsDone[idx] < posItemDone[idx]) {
t.Fatalf("unexpected ordering for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
}
if idx > 0 && !(posItemDone[idx-1] < posAdded[idx]) {
t.Fatalf("function call events overlap between %d and %d: prevDone=%d nextAdded=%d", idx-1, idx, posItemDone[idx-1], posAdded[idx])
}
}
if deltaByIndex[0] != "{}" {
t.Fatalf("unexpected delta for output_index 0: got %q", deltaByIndex[0])
}
if deltaByIndex[1] != "{}" {
t.Fatalf("unexpected delta for output_index 1: got %q", deltaByIndex[1])
}
if deltaByIndex[2] == "" || !gjson.Valid(deltaByIndex[2]) || gjson.Get(deltaByIndex[2], "a").Int() != 1 {
t.Fatalf("unexpected delta for output_index 2: got %q", deltaByIndex[2])
}
if !(posItemDone[2] < posCompleted) {
t.Fatalf("response.completed should be after last output_item.done: last=%d completed=%d", posItemDone[2], posCompleted)
}
}
func TestConvertGeminiResponseToOpenAIResponses_ResponseOutputOrdering(t *testing.T) {
in := []string{
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0","args":{"x":"y"}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hi"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
}
var param any
var out []string
for _, line := range in {
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), &param)...)
}
posFuncDone := -1
posMsgAdded := -1
posCompleted := -1
for i, chunk := range out {
ev, data := parseSSEEvent(t, chunk)
switch ev {
case "response.output_item.done":
if data.Get("item.type").String() == "function_call" && data.Get("output_index").Int() == 0 {
posFuncDone = i
}
case "response.output_item.added":
if data.Get("item.type").String() == "message" && data.Get("output_index").Int() == 1 {
posMsgAdded = i
}
case "response.completed":
posCompleted = i
if data.Get("response.output.0.type").String() != "function_call" {
t.Fatalf("expected response.output[0] to be function_call: %s", data.Get("response.output.0").Raw)
}
if data.Get("response.output.1.type").String() != "message" {
t.Fatalf("expected response.output[1] to be message: %s", data.Get("response.output.1").Raw)
}
if data.Get("response.output.1.content.0.text").String() != "hi" {
t.Fatalf("unexpected message text in response.output[1]: %s", data.Get("response.output.1").Raw)
}
}
}
if posFuncDone == -1 || posMsgAdded == -1 || posCompleted == -1 {
t.Fatalf("missing required events: funcDone=%d msgAdded=%d completed=%d", posFuncDone, posMsgAdded, posCompleted)
}
if !(posFuncDone < posMsgAdded) {
t.Fatalf("expected function_call to complete before message is added: funcDone=%d msgAdded=%d", posFuncDone, posMsgAdded)
}
if !(posMsgAdded < posCompleted) {
t.Fatalf("expected response.completed after message added: msgAdded=%d completed=%d", posMsgAdded, posCompleted)
}
}

View File

@@ -89,12 +89,14 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Handle system message first
systemMsgJSON := `{"role":"system","content":[]}`
hasSystemContent := false
if system := root.Get("system"); system.Exists() {
if system.Type == gjson.String {
if system.String() != "" {
oldSystem := `{"type":"text","text":""}`
oldSystem, _ = sjson.Set(oldSystem, "text", system.String())
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", oldSystem)
hasSystemContent = true
}
} else if system.Type == gjson.JSON {
if system.IsArray() {
@@ -102,12 +104,16 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
for i := 0; i < len(systemResults); i++ {
if contentItem, ok := convertClaudeContentPart(systemResults[i]); ok {
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", contentItem)
hasSystemContent = true
}
}
}
}
}
messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
// Only add system message if it has content
if hasSystemContent {
messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
}
// Process Anthropic messages
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
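With the hasSystemContent guard, an empty Anthropic system field no longer produces an empty system message in the OpenAI payload. A standalone sketch of the string-only case (array content parts are omitted here for brevity):

package main

import (
	"fmt"

	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

func buildMessages(claudeReq string) string {
	messages := "[]"
	systemMsg := `{"role":"system","content":[]}`
	hasSystemContent := false
	if system := gjson.Get(claudeReq, "system"); system.Exists() && system.Type == gjson.String && system.String() != "" {
		part, _ := sjson.Set(`{"type":"text","text":""}`, "text", system.String())
		systemMsg, _ = sjson.SetRaw(systemMsg, "content.-1", part)
		hasSystemContent = true
	}
	// Only add the system message if it has content.
	if hasSystemContent {
		messages, _ = sjson.SetRaw(messages, "-1", systemMsg)
	}
	return messages
}

func main() {
	fmt.Println(buildMessages(`{"system":""}`))         // []
	fmt.Println(buildMessages(`{"system":"be terse"}`)) // [{"role":"system","content":[{"type":"text","text":"be terse"}]}]
}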

View File

@@ -12,10 +12,23 @@ import (
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
const placeholderReasonDescription = "Brief explanation of why you are calling this tool"
// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
// It handles unsupported keywords, type flattening, and schema simplification while preserving
// semantic information as description hints.
func CleanJSONSchemaForAntigravity(jsonStr string) string {
return cleanJSONSchema(jsonStr, true)
}
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini tool calling.
// It removes unsupported keywords and simplifies schemas, without adding empty-schema placeholders.
func CleanJSONSchemaForGemini(jsonStr string) string {
return cleanJSONSchema(jsonStr, false)
}
// cleanJSONSchema performs the core cleaning operations on the JSON schema.
func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
// Phase 1: Convert and add hints
jsonStr = convertRefsToHints(jsonStr)
jsonStr = convertConstToEnum(jsonStr)
@@ -31,10 +44,94 @@ func CleanJSONSchemaForAntigravity(jsonStr string) string {
// Phase 3: Cleanup
jsonStr = removeUnsupportedKeywords(jsonStr)
if !addPlaceholder {
// Gemini schema cleanup: remove nullable/title and placeholder-only fields.
jsonStr = removeKeywords(jsonStr, []string{"nullable", "title"})
jsonStr = removePlaceholderFields(jsonStr)
}
jsonStr = cleanupRequiredFields(jsonStr)
// Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
jsonStr = addEmptySchemaPlaceholder(jsonStr)
if addPlaceholder {
jsonStr = addEmptySchemaPlaceholder(jsonStr)
}
return jsonStr
}
// removeKeywords removes all occurrences of specified keywords from the JSON schema.
func removeKeywords(jsonStr string, keywords []string) string {
for _, key := range keywords {
for _, p := range findPaths(jsonStr, key) {
if isPropertyDefinition(trimSuffix(p, "."+key)) {
continue
}
jsonStr, _ = sjson.Delete(jsonStr, p)
}
}
return jsonStr
}
// removePlaceholderFields removes placeholder-only properties ("_" and "reason") and their required entries.
func removePlaceholderFields(jsonStr string) string {
// Remove "_" placeholder properties.
paths := findPaths(jsonStr, "_")
sortByDepth(paths)
for _, p := range paths {
if !strings.HasSuffix(p, ".properties._") {
continue
}
jsonStr, _ = sjson.Delete(jsonStr, p)
parentPath := trimSuffix(p, ".properties._")
reqPath := joinPath(parentPath, "required")
req := gjson.Get(jsonStr, reqPath)
if req.IsArray() {
var filtered []string
for _, r := range req.Array() {
if r.String() != "_" {
filtered = append(filtered, r.String())
}
}
if len(filtered) == 0 {
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
} else {
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
}
}
}
// Remove placeholder-only "reason" objects.
reasonPaths := findPaths(jsonStr, "reason")
sortByDepth(reasonPaths)
for _, p := range reasonPaths {
if !strings.HasSuffix(p, ".properties.reason") {
continue
}
parentPath := trimSuffix(p, ".properties.reason")
props := gjson.Get(jsonStr, joinPath(parentPath, "properties"))
if !props.IsObject() || len(props.Map()) != 1 {
continue
}
desc := gjson.Get(jsonStr, p+".description").String()
if desc != placeholderReasonDescription {
continue
}
jsonStr, _ = sjson.Delete(jsonStr, p)
reqPath := joinPath(parentPath, "required")
req := gjson.Get(jsonStr, reqPath)
if req.IsArray() {
var filtered []string
for _, r := range req.Array() {
if r.String() != "reason" {
filtered = append(filtered, r.String())
}
}
if len(filtered) == 0 {
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
} else {
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
}
}
}
return jsonStr
}
@@ -409,7 +506,7 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
// Add placeholder "reason" property
reasonPath := joinPath(propsPath, "reason")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", placeholderReasonDescription)
// Add to required array
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
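removeKeywords and removePlaceholderFields give the Gemini path a leaner schema than the Antigravity path, which keeps the "reason" placeholder. A simplified standalone sketch of stripping a keyword such as "title" (the real helper walks the whole schema via findPaths and skips property definitions; this only touches the top level and direct properties):

package main

import (
	"fmt"

	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

func main() {
	schema := `{"type":"object","title":"Args","properties":{"path":{"type":"string","title":"Path"}}}`

	// Drop "title" at the top level and on each direct property. Keys containing
	// "." or "*" would need escaping, as gjsonPathKeyReplacer does in the real code.
	schema, _ = sjson.Delete(schema, "title")
	gjson.Get(schema, "properties").ForEach(func(key, _ gjson.Result) bool {
		schema, _ = sjson.Delete(schema, "properties."+key.String()+".title")
		return true
	})

	fmt.Println(schema) // {"type":"object","properties":{"path":{"type":"string"}}}
}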

View File

@@ -128,8 +128,23 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
// Parameters:
// - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
models := h.Models()
firstID := ""
lastID := ""
if len(models) > 0 {
if id, ok := models[0]["id"].(string); ok {
firstID = id
}
if id, ok := models[len(models)-1]["id"].(string); ok {
lastID = id
}
}
c.JSON(http.StatusOK, gin.H{
"data": h.Models(),
"data": models,
"has_more": false,
"first_id": firstID,
"last_id": lastID,
})
}
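The handler now returns the same envelope as api.anthropic.com's models list: data plus has_more, first_id, and last_id. A standalone sketch of the resulting JSON using plain encoding/json and two hypothetical model entries:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	models := []map[string]any{
		{"id": "claude-sonnet-4-5", "type": "model"},
		{"id": "claude-opus-4-1", "type": "model"},
	}
	resp := map[string]any{
		"data":     models,
		"has_more": false,
		"first_id": models[0]["id"],
		"last_id":  models[len(models)-1]["id"],
	}
	b, _ := json.Marshal(resp)
	fmt.Println(string(b))
	// {"data":[...],"first_id":"claude-sonnet-4-5","has_more":false,"last_id":"claude-opus-4-1"}
}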

View File

@@ -60,8 +60,12 @@ func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
if !strings.HasPrefix(name, "models/") {
normalizedModel["name"] = "models/" + name
}
normalizedModel["displayName"] = name
normalizedModel["description"] = name
if displayName, _ := normalizedModel["displayName"].(string); displayName == "" {
normalizedModel["displayName"] = name
}
if description, _ := normalizedModel["description"].(string); description == "" {
normalizedModel["description"] = name
}
}
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
normalizedModel["supportedGenerationMethods"] = defaultMethods
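displayName and description are now defaulted to the model name only when the upstream entry left them empty, so provider-supplied values are preserved. A tiny standalone sketch of the guard (the model entry is made up):

package main

import "fmt"

func main() {
	model := map[string]any{"name": "models/gemini-2.0-flash", "displayName": "Gemini 2.0 Flash"}
	name, _ := model["name"].(string)
	if displayName, _ := model["displayName"].(string); displayName == "" {
		model["displayName"] = name
	}
	if description, _ := model["description"].(string); description == "" {
		model["description"] = name
	}
	fmt.Println(model["displayName"], "/", model["description"])
	// Gemini 2.0 Flash / models/gemini-2.0-flash (displayName preserved, description defaulted)
}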

View File

@@ -385,6 +385,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
return nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),
@@ -423,6 +424,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
return nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),
@@ -464,6 +466,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return nil, errChan
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),

View File

@@ -2,15 +2,13 @@ package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
@@ -19,20 +17,6 @@ import (
log "github.com/sirupsen/logrus"
)
const (
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
antigravityCallbackPort = 51121
)
var antigravityScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
// AntigravityAuthenticator implements OAuth login for the antigravity provider.
type AntigravityAuthenticator struct{}
@@ -60,12 +44,12 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
opts = &LoginOptions{}
}
callbackPort := antigravityCallbackPort
callbackPort := antigravity.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
authSvc := antigravity.NewAntigravityAuth(cfg, nil)
state, err := misc.GenerateRandomState()
if err != nil {
@@ -83,7 +67,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
}()
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port)
authURL := buildAntigravityAuthURL(redirectURI, state)
authURL := authSvc.BuildAuthURL(state, redirectURI)
if !opts.NoBrowser {
fmt.Println("Opening browser for antigravity authentication")
@@ -164,22 +148,29 @@ waitForCallback:
return nil, fmt.Errorf("antigravity: missing authorization code")
}
tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI, httpClient)
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, cbRes.Code, redirectURI)
if errToken != nil {
return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken)
}
email := ""
if tokenResp.AccessToken != "" {
if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" {
email = strings.TrimSpace(info.Email)
}
accessToken := strings.TrimSpace(tokenResp.AccessToken)
if accessToken == "" {
return nil, fmt.Errorf("antigravity: token exchange returned empty access token")
}
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
if errInfo != nil {
return nil, fmt.Errorf("antigravity: fetch user info failed: %w", errInfo)
}
email = strings.TrimSpace(email)
if email == "" {
return nil, fmt.Errorf("antigravity: empty email returned from user info")
}
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
projectID := ""
if tokenResp.AccessToken != "" {
fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
if accessToken != "" {
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else {
@@ -204,7 +195,7 @@ waitForCallback:
metadata["project_id"] = projectID
}
fileName := sanitizeAntigravityFileName(email)
fileName := antigravity.CredentialFileName(email)
label := email
if label == "" {
label = "antigravity"
@@ -231,7 +222,7 @@ type callbackResult struct {
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
if port <= 0 {
port = antigravityCallbackPort
port = antigravity.CallbackPort
}
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
@@ -267,309 +258,9 @@ func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbac
return srv, port, resultCh, nil
}
type antigravityTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
func exchangeAntigravityCode(ctx context.Context, code, redirectURI string, httpClient *http.Client) (*antigravityTokenResponse, error) {
data := url.Values{}
data.Set("code", code)
data.Set("client_id", antigravityClientID)
data.Set("client_secret", antigravityClientSecret)
data.Set("redirect_uri", redirectURI)
data.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := httpClient.Do(req)
if errDo != nil {
return nil, errDo
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange: close body error: %v", errClose)
}
}()
var token antigravityTokenResponse
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
return nil, errDecode
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode)
}
return &token, nil
}
type antigravityUserInfo struct {
Email string `json:"email"`
}
func fetchAntigravityUserInfo(ctx context.Context, accessToken string, httpClient *http.Client) (*antigravityUserInfo, error) {
if strings.TrimSpace(accessToken) == "" {
return &antigravityUserInfo{}, nil
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, errDo := httpClient.Do(req)
if errDo != nil {
return nil, errDo
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity userinfo: close body error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return &antigravityUserInfo{}, nil
}
var info antigravityUserInfo
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
return nil, errDecode
}
return &info, nil
}
func buildAntigravityAuthURL(redirectURI, state string) string {
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", antigravityClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(antigravityScopes, " "))
params.Set("state", state)
return "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
}
func sanitizeAntigravityFileName(email string) string {
if strings.TrimSpace(email) == "" {
return "antigravity.json"
}
replacer := strings.NewReplacer("@", "_", ".", "_")
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
}
// Antigravity API constants for project discovery
const (
antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com"
antigravityAPIVersion = "v1internal"
antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1"
antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)
// FetchAntigravityProjectID exposes project discovery for external callers.
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
return fetchAntigravityProjectID(ctx, accessToken, httpClient)
}
// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist.
// This uses the same approach as Gemini CLI to get the cloudaicompanionProject.
func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
// Call loadCodeAssist to get the project
loadReqBody := map[string]any{
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(loadReqBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", antigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
req.Header.Set("Client-Metadata", antigravityClientMetadata)
resp, errDo := httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var loadResp map[string]any
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
// Extract projectID from response
projectID := ""
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
projectID = strings.TrimSpace(id)
}
if projectID == "" {
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
if id, okID := projectMap["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
if projectID == "" {
tierID := "legacy-tier"
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
for _, rawTier := range tiers {
tier, okTier := rawTier.(map[string]any)
if !okTier {
continue
}
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
tierID = strings.TrimSpace(id)
break
}
}
}
}
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
if err != nil {
return "", err
}
return projectID, nil
}
return projectID, nil
}
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
// It returns an empty string when the operation times out or completes without a project ID.
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
if httpClient == nil {
httpClient = http.DefaultClient
}
fmt.Println("Antigravity: onboarding user...", tierID)
requestBody := map[string]any{
"tierId": tierID,
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(requestBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
reqCtx := ctx
var cancel context.CancelFunc
if reqCtx == nil {
reqCtx = context.Background()
}
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if errRequest != nil {
cancel()
return "", fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", antigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
req.Header.Set("Client-Metadata", antigravityClientMetadata)
resp, errDo := httpClient.Do(req)
if errDo != nil {
cancel()
return "", fmt.Errorf("execute request: %w", errDo)
}
bodyBytes, errRead := io.ReadAll(resp.Body)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("close body error: %v", errClose)
}
cancel()
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode == http.StatusOK {
var data map[string]any
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
if done, okDone := data["done"].(bool); okDone && done {
projectID := ""
if responseData, okResp := data["response"].(map[string]any); okResp {
switch projectValue := responseData["cloudaicompanionProject"].(type) {
case map[string]any:
if id, okID := projectValue["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
case string:
projectID = strings.TrimSpace(projectValue)
}
}
if projectID != "" {
log.Infof("Successfully fetched project_id: %s", projectID)
return projectID, nil
}
return "", fmt.Errorf("no project_id in response")
}
time.Sleep(2 * time.Second)
continue
}
responsePreview := strings.TrimSpace(string(bodyBytes))
if len(responsePreview) > 200 {
responsePreview = responsePreview[:200]
}
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responsePreview)
}
return "", nil
cfg := &config.Config{}
authSvc := antigravity.NewAntigravityAuth(cfg, httpClient)
return authSvc.FetchProjectID(ctx, accessToken)
}
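
A minimal caller-side sketch of the exported FetchAntigravityProjectID helper above (the wrapper function, its name, and the 30-second timeout are illustrative assumptions, not part of this diff):

// discoverAntigravityProject is a usage sketch only; accessToken is assumed to be a valid OAuth2 access token.
func discoverAntigravityProject(ctx context.Context, accessToken string) (string, error) {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	return FetchAntigravityProjectID(ctx, accessToken, &http.Client{Timeout: 30 * time.Second})
}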

View File

@@ -570,6 +570,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
@@ -597,6 +598,9 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
@@ -619,6 +623,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
@@ -646,6 +651,9 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
@@ -668,6 +676,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
@@ -694,6 +703,9 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errStream, &se) && se != nil {
@@ -729,167 +741,42 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
}
}
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if provider == "" {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return opts
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
if hasRequestedModelMetadata(opts.Metadata) {
return opts
}
if len(opts.Metadata) == 0 {
opts.Metadata = map[string]any{cliproxyexecutor.RequestedModelMetadataKey: requestedModel}
return opts
}
meta := make(map[string]any, len(opts.Metadata)+1)
for k, v := range opts.Metadata {
meta[k] = v
}
meta[cliproxyexecutor.RequestedModelMetadataKey] = requestedModel
opts.Metadata = meta
return opts
}
func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if provider == "" {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
func hasRequestedModelMetadata(meta map[string]any) bool {
if len(meta) == 0 {
return false
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
raw, ok := meta[cliproxyexecutor.RequestedModelMetadataKey]
if !ok || raw == nil {
return false
}
}
func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
if provider == "" {
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return nil, lastErr
}
return nil, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errStream, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(chunk.Err, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
out <- chunk
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, chunks)
return out, nil
switch v := raw.(type) {
case string:
return strings.TrimSpace(v) != ""
case []byte:
return strings.TrimSpace(string(v)) != ""
default:
return false
}
}
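
The two helpers above are deliberately copy-on-write: an already-set, non-empty requested_model is preserved, and a populated metadata map is cloned before the key is written so the caller's Options are never mutated. A hypothetical same-package test sketch of that behaviour (placeholder model names; not part of this diff):

func TestEnsureRequestedModelMetadataSketch(t *testing.T) {
	opts := cliproxyexecutor.Options{Metadata: map[string]any{"trace_id": "abc"}}
	first := ensureRequestedModelMetadata(opts, "client-model-a")
	if _, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]; ok {
		t.Fatal("caller metadata must not be mutated")
	}
	if got := first.Metadata[cliproxyexecutor.RequestedModelMetadataKey]; got != "client-model-a" {
		t.Fatalf("unexpected requested model: %v", got)
	}
	// A second call must not overwrite the value recorded for the original request.
	second := ensureRequestedModelMetadata(first, "client-model-b")
	if got := second.Metadata[cliproxyexecutor.RequestedModelMetadataKey]; got != "client-model-a" {
		t.Fatalf("existing requested model must be preserved, got %v", got)
	}
}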
@@ -1140,35 +1027,6 @@ func (m *Manager) normalizeProviders(providers []string) []string {
return result
}
// rotateProviders returns a rotated view of the providers list starting from the
// current offset for the model, and atomically increments the offset for the next call.
// This ensures concurrent requests get different starting providers.
func (m *Manager) rotateProviders(model string, providers []string) []string {
if len(providers) == 0 {
return nil
}
// Atomic read-and-increment: get current offset and advance cursor in one lock
m.mu.Lock()
offset := m.providerOffsets[model]
m.providerOffsets[model] = (offset + 1) % len(providers)
m.mu.Unlock()
if len(providers) > 0 {
offset %= len(providers)
}
if offset < 0 {
offset = 0
}
if offset == 0 {
return providers
}
rotated := make([]string, 0, len(providers))
rotated = append(rotated, providers[offset:]...)
rotated = append(rotated, providers[:offset]...)
return rotated
}
func (m *Manager) retrySettings() (int, time.Duration) {
if m == nil {
return 0, 0
@@ -1250,42 +1108,6 @@ func waitForCooldown(ctx context.Context, wait time.Duration) error {
}
}
func (m *Manager) executeProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (cliproxyexecutor.Response, error)) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
var lastErr error
for _, provider := range providers {
resp, errExec := fn(ctx, provider)
if errExec == nil {
return resp, nil
}
lastErr = errExec
}
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
}
func (m *Manager) executeStreamProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (<-chan cliproxyexecutor.StreamChunk, error)) (<-chan cliproxyexecutor.StreamChunk, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
var lastErr error
for _, provider := range providers {
chunks, errExec := fn(ctx, provider)
if errExec == nil {
return chunks, nil
}
lastErr = errExec
}
if lastErr != nil {
return nil, lastErr
}
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
// MarkResult records an execution result and notifies hooks.
func (m *Manager) MarkResult(ctx context.Context, result Result) {
if result.AuthID == "" {
@@ -1371,8 +1193,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
shouldSuspendModel = true
setModelQuota = true
case 408, 500, 502, 503, 504:
next := now.Add(1 * time.Minute)
state.NextRetryAfter = next
if quotaCooldownDisabled.Load() {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
state.NextRetryAfter = next
}
default:
state.NextRetryAfter = time.Time{}
}
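
quotaCooldownDisabled is read via Load(), which suggests a package-level atomic boolean toggled elsewhere; a minimal sketch of that pattern (hypothetical declaration and setter, not taken from this diff; requires "sync/atomic"):

// Sketch only: the real declaration and wiring live outside this diff.
var quotaCooldownDisabled atomic.Bool

// SetQuotaCooldownDisabled lets configuration code switch off the one-minute cooldown
// applied after transient upstream errors (408/5xx).
func SetQuotaCooldownDisabled(disabled bool) {
	quotaCooldownDisabled.Store(disabled)
}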
@@ -1623,7 +1449,11 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
auth.NextRetryAfter = next
case 408, 500, 502, 503, 504:
auth.StatusMessage = "transient upstream error"
auth.NextRetryAfter = now.Add(1 * time.Minute)
if quotaCooldownDisabled.Load() {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(1 * time.Minute)
}
default:
if auth.StatusMessage == "" {
auth.StatusMessage = "request failed"

View File

@@ -7,6 +7,9 @@ import (
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
const RequestedModelMetadataKey = "requested_model"
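
On the consumer side, an executor can read the client-requested model back out of Options.Metadata with a type assertion; a minimal sketch, assuming this package's Options type with its Metadata map (the helper name is hypothetical, not part of this diff):

// requestedModelFrom reports the model the client originally asked for, if the manager recorded one.
func requestedModelFrom(opts Options) (string, bool) {
	raw, ok := opts.Metadata[RequestedModelMetadataKey]
	if !ok {
		return "", false
	}
	model, okStr := raw.(string)
	return model, okStr && model != ""
}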
// Request encapsulates the translated payload that will be sent to a provider executor.
type Request struct {
// Model is the upstream model identifier after translation.