mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-05 05:50:50 +08:00
Compare commits
54 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f53b952b2 | ||
|
|
f30ffd5f5e | ||
|
|
bc9a24d705 | ||
|
|
2c879f13ef | ||
|
|
07b4a08979 | ||
|
|
7f612bb069 | ||
|
|
5743b78694 | ||
|
|
2e6a2b655c | ||
|
|
cb47ac21bf | ||
|
|
a1394b4596 | ||
|
|
9e97948f03 | ||
|
|
46c6fb1e7a | ||
|
|
9f9fec5d4c | ||
|
|
e95be10485 | ||
|
|
f3d58fa0ce | ||
|
|
8c0eaa1f71 | ||
|
|
405df58f72 | ||
|
|
e7f13aa008 | ||
|
|
7cb6a9b89a | ||
|
|
9aa5344c29 | ||
|
|
8ba0ebbd2a | ||
|
|
c65407ab9f | ||
|
|
9e59685212 | ||
|
|
4a4dfaa910 | ||
|
|
0d6ecb0191 | ||
|
|
f16461bfe7 | ||
|
|
c32e2a8196 | ||
|
|
873d41582f | ||
|
|
6fb7d85558 | ||
|
|
d5e3e32d58 | ||
|
|
f353a54555 | ||
|
|
1d6e2e751d | ||
|
|
cc50b63422 | ||
|
|
15ae83a15b | ||
|
|
81b369aed9 | ||
|
|
ecc850bfb7 | ||
|
|
19b4ef33e0 | ||
|
|
7ca045d8b9 | ||
|
|
abfca6aab2 | ||
|
|
3c71c075db | ||
|
|
9c2992bfb2 | ||
|
|
269a1c5452 | ||
|
|
22ce65ac72 | ||
|
|
a2f8f59192 | ||
|
|
8c7c446f33 | ||
|
|
30a59168d7 | ||
|
|
c8884f5e25 | ||
|
|
d9c6317c84 | ||
|
|
d29ec95526 | ||
|
|
09970dc7af | ||
|
|
d81abd401c | ||
|
|
a6cba25bc1 | ||
|
|
2f6004d74a | ||
|
|
a1634909e8 |
@@ -134,6 +134,10 @@ VSCode extension for quick switching between Claude Code models, featuring integ
|
||||
|
||||
Windows desktop app built with Tauri + React for monitoring AI coding assistant quotas via CLIProxyAPI. Track usage across Gemini, Claude, OpenAI Codex, and Antigravity accounts with real-time dashboard, system tray integration, and one-click proxy control - no API keys needed.
|
||||
|
||||
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
||||
|
||||
A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
@@ -133,6 +133,10 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
|
||||
|
||||
Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI 监控 AI 编程助手配额。支持跨 Gemini、Claude、OpenAI Codex 和 Antigravity 账户的使用量追踪,提供实时仪表盘、系统托盘集成和一键代理控制,无需 API 密钥。
|
||||
|
||||
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
|
||||
|
||||
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
|
||||
2
go.mod
2
go.mod
@@ -21,6 +21,7 @@ require (
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/text v0.31.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -70,7 +71,6 @@ require (
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
@@ -21,6 +20,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
@@ -232,14 +232,6 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||
log.Infof("callback forwarder on port %d stopped", port)
|
||||
}
|
||||
|
||||
func sanitizeAntigravityFileName(email string) string {
|
||||
if strings.TrimSpace(email) == "" {
|
||||
return "antigravity.json"
|
||||
}
|
||||
replacer := strings.NewReplacer("@", "_", ".", "_")
|
||||
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
||||
}
|
||||
|
||||
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
||||
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
||||
return "", fmt.Errorf("server port is not configured")
|
||||
@@ -749,6 +741,72 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
return err
|
||||
}
|
||||
|
||||
// PatchAuthFileStatus toggles the disabled state of an auth file
|
||||
func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Disabled *bool `json:"disabled"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
if req.Disabled == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "disabled is required"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find auth by name or ID
|
||||
var targetAuth *coreauth.Auth
|
||||
if auth, ok := h.authManager.GetByID(name); ok {
|
||||
targetAuth = auth
|
||||
} else {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name {
|
||||
targetAuth = auth
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if targetAuth == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update disabled state
|
||||
targetAuth.Disabled = *req.Disabled
|
||||
if *req.Disabled {
|
||||
targetAuth.Status = coreauth.StatusDisabled
|
||||
targetAuth.StatusMessage = "disabled via management API"
|
||||
} else {
|
||||
targetAuth.Status = coreauth.StatusActive
|
||||
targetAuth.StatusMessage = ""
|
||||
}
|
||||
targetAuth.UpdatedAt = time.Now()
|
||||
|
||||
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
@@ -915,67 +973,14 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
rawCode := resultMap["code"]
|
||||
code := strings.Split(rawCode, "#")[0]
|
||||
|
||||
// Exchange code for tokens (replicate logic using updated redirect_uri)
|
||||
// Extract client_id from the modified auth URL
|
||||
clientID := ""
|
||||
if u2, errP := url.Parse(authURL); errP == nil {
|
||||
clientID = u2.Query().Get("client_id")
|
||||
}
|
||||
// Build request
|
||||
bodyMap := map[string]any{
|
||||
"code": code,
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": clientID,
|
||||
"redirect_uri": "http://localhost:54545/callback",
|
||||
"code_verifier": pkceCodes.CodeVerifier,
|
||||
}
|
||||
bodyJSON, _ := json.Marshal(bodyMap)
|
||||
|
||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||
// Exchange code for tokens using internal auth service
|
||||
bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes)
|
||||
if errExchange != nil {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange)
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("failed to close response body: %v", errClose)
|
||||
}
|
||||
}()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
var tResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Account struct {
|
||||
EmailAddress string `json:"email_address"`
|
||||
} `json:"account"`
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
bundle := &claude.ClaudeAuthBundle{
|
||||
TokenData: claude.ClaudeTokenData{
|
||||
AccessToken: tResp.AccessToken,
|
||||
RefreshToken: tResp.RefreshToken,
|
||||
Email: tResp.Account.EmailAddress,
|
||||
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
},
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
// Create token storage
|
||||
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
||||
@@ -1015,17 +1020,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
fmt.Println("Initializing Google authentication...")
|
||||
|
||||
// OAuth2 configuration (mirrors internal/auth/gemini)
|
||||
// OAuth2 configuration using exported constants from internal/auth/gemini
|
||||
conf := &oauth2.Config{
|
||||
ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com",
|
||||
ClientSecret: "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",
|
||||
RedirectURL: "http://localhost:8085/oauth2callback",
|
||||
Scopes: []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
},
|
||||
Endpoint: google.Endpoint,
|
||||
ClientID: geminiAuth.ClientID,
|
||||
ClientSecret: geminiAuth.ClientSecret,
|
||||
RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort),
|
||||
Scopes: geminiAuth.Scopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
// Build authorization URL and return it immediately
|
||||
@@ -1147,13 +1148,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||
ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
ifToken["client_secret"] = "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"
|
||||
ifToken["scopes"] = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
ifToken["client_id"] = geminiAuth.ClientID
|
||||
ifToken["client_secret"] = geminiAuth.ClientSecret
|
||||
ifToken["scopes"] = geminiAuth.Scopes
|
||||
ifToken["universe_domain"] = "googleapis.com"
|
||||
|
||||
ts := geminiAuth.GeminiTokenStorage{
|
||||
@@ -1340,73 +1337,25 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
log.Debug("Authorization code received, exchanging for tokens...")
|
||||
// Extract client_id from authURL
|
||||
clientID := ""
|
||||
if u2, errP := url.Parse(authURL); errP == nil {
|
||||
clientID = u2.Query().Get("client_id")
|
||||
}
|
||||
// Exchange code for tokens with redirect equal to mgmtRedirect
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {clientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {"http://localhost:1455/auth/callback"},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||
// Exchange code for tokens using internal auth service
|
||||
bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes)
|
||||
if errExchange != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange)
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
return
|
||||
}
|
||||
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
|
||||
email := ""
|
||||
accountID := ""
|
||||
|
||||
// Extract additional info for filename generation
|
||||
claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken)
|
||||
planType := ""
|
||||
if claims != nil {
|
||||
email = claims.GetUserEmail()
|
||||
accountID = claims.GetAccountID()
|
||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||
}
|
||||
hashAccountID := ""
|
||||
if accountID != "" {
|
||||
digest := sha256.Sum256([]byte(accountID))
|
||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||
}
|
||||
// Build bundle compatible with existing storage
|
||||
bundle := &codex.CodexAuthBundle{
|
||||
TokenData: codex.CodexTokenData{
|
||||
IDToken: tokenResp.IDToken,
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
AccountID: accountID,
|
||||
Email: email,
|
||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
},
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
if claims != nil {
|
||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||
if accountID := claims.GetAccountID(); accountID != "" {
|
||||
digest := sha256.Sum256([]byte(accountID))
|
||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||
}
|
||||
}
|
||||
|
||||
// Create token storage and persist
|
||||
@@ -1441,23 +1390,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
const (
|
||||
antigravityCallbackPort = 51121
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
)
|
||||
var antigravityScopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
fmt.Println("Initializing Antigravity authentication...")
|
||||
|
||||
authSvc := antigravity.NewAntigravityAuth(h.cfg, nil)
|
||||
|
||||
state, errState := misc.GenerateRandomState()
|
||||
if errState != nil {
|
||||
log.Errorf("Failed to generate state parameter: %v", errState)
|
||||
@@ -1465,17 +1403,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("client_id", antigravityClientID)
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", strings.Join(antigravityScopes, " "))
|
||||
params.Set("state", state)
|
||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort)
|
||||
authURL := authSvc.BuildAuthURL(state, redirectURI)
|
||||
|
||||
RegisterOAuthSession(state, "antigravity")
|
||||
|
||||
@@ -1489,7 +1418,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
||||
if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -1498,7 +1427,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
|
||||
defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder)
|
||||
}
|
||||
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
||||
@@ -1538,93 +1467,36 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||
form := url.Values{}
|
||||
form.Set("code", authCode)
|
||||
form.Set("client_id", antigravityClientID)
|
||||
form.Set("client_secret", antigravityClientSecret)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
||||
SetOAuthSessionError(state, "Failed to build token request")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute token request: %v", errDo)
|
||||
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI)
|
||||
if errToken != nil {
|
||||
log.Errorf("Failed to exchange token: %v", errToken)
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity token exchange close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
||||
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||
if accessToken == "" {
|
||||
log.Error("antigravity: token exchange returned empty access token")
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
|
||||
if errInfo != nil {
|
||||
log.Errorf("Failed to fetch user info: %v", errInfo)
|
||||
SetOAuthSessionError(state, "Failed to fetch user info")
|
||||
return
|
||||
}
|
||||
|
||||
email := ""
|
||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errInfoReq != nil {
|
||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
||||
SetOAuthSessionError(state, "Failed to build user info request")
|
||||
return
|
||||
}
|
||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||
|
||||
infoResp, errInfo := httpClient.Do(infoReq)
|
||||
if errInfo != nil {
|
||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := infoResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity user info close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
|
||||
var infoPayload struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
|
||||
email = strings.TrimSpace(infoPayload.Email)
|
||||
}
|
||||
} else {
|
||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
||||
return
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
log.Error("antigravity: user info returned empty email")
|
||||
SetOAuthSessionError(state, "Failed to fetch user info")
|
||||
return
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
||||
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
||||
if accessToken != "" {
|
||||
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
|
||||
if errProject != nil {
|
||||
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||
} else {
|
||||
@@ -1649,7 +1521,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
metadata["project_id"] = projectID
|
||||
}
|
||||
|
||||
fileName := sanitizeAntigravityFileName(email)
|
||||
fileName := antigravity.CredentialFileName(email)
|
||||
label := strings.TrimSpace(email)
|
||||
if label == "" {
|
||||
label = "antigravity"
|
||||
|
||||
@@ -610,6 +610,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
|
||||
344
internal/auth/antigravity/auth.go
Normal file
344
internal/auth/antigravity/auth.go
Normal file
@@ -0,0 +1,344 @@
|
||||
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TokenResponse represents OAuth token response from Google
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// userInfo represents Google user profile
|
||||
type userInfo struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// AntigravityAuth handles Antigravity OAuth authentication
|
||||
type AntigravityAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAntigravityAuth creates a new Antigravity auth service.
|
||||
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
|
||||
if httpClient != nil {
|
||||
return &AntigravityAuth{httpClient: httpClient}
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = &config.Config{}
|
||||
}
|
||||
return &AntigravityAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||
}
|
||||
}
|
||||
|
||||
// BuildAuthURL generates the OAuth authorization URL.
|
||||
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", strings.Join(Scopes, " "))
|
||||
params.Set("state", state)
|
||||
return AuthEndpoint + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
|
||||
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("code", code)
|
||||
data.Set("client_id", ClientID)
|
||||
data.Set("client_secret", ClientSecret)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
data.Set("grant_type", "authorization_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("antigravity token exchange: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, errDo := o.httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
if errRead != nil {
|
||||
return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
|
||||
}
|
||||
body := strings.TrimSpace(string(bodyBytes))
|
||||
if body == "" {
|
||||
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var token TokenResponse
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
|
||||
return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves user email from Google
|
||||
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
if accessToken == "" {
|
||||
return "", fmt.Errorf("antigravity userinfo: missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, errDo := o.httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity userinfo: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
|
||||
}
|
||||
body := strings.TrimSpace(string(bodyBytes))
|
||||
if body == "" {
|
||||
return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
|
||||
}
|
||||
return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
var info userInfo
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||
return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
|
||||
}
|
||||
email := strings.TrimSpace(info.Email)
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("antigravity userinfo: response missing email")
|
||||
}
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
|
||||
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
|
||||
loadReqBody := map[string]any{
|
||||
"metadata": map[string]string{
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(loadReqBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", APIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||
|
||||
resp, errDo := o.httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
|
||||
var loadResp map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
// Extract projectID from response
|
||||
projectID := ""
|
||||
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
if projectID == "" {
|
||||
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
||||
if id, okID := projectMap["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
tierID := "legacy-tier"
|
||||
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||
for _, rawTier := range tiers {
|
||||
tier, okTier := rawTier.(map[string]any)
|
||||
if !okTier {
|
||||
continue
|
||||
}
|
||||
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||
tierID = strings.TrimSpace(id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
|
||||
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
|
||||
requestBody := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": map[string]string{
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(requestBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
maxAttempts := 5
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||
|
||||
reqCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if reqCtx == nil {
|
||||
reqCtx = context.Background()
|
||||
}
|
||||
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
|
||||
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if errRequest != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", APIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||
|
||||
resp, errDo := o.httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("close body error: %v", errClose)
|
||||
}
|
||||
cancel()
|
||||
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var data map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
if done, okDone := data["done"].(bool); okDone && done {
|
||||
projectID := ""
|
||||
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||
case map[string]any:
|
||||
if id, okID := projectValue["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
case string:
|
||||
projectID = strings.TrimSpace(projectValue)
|
||||
}
|
||||
}
|
||||
|
||||
if projectID != "" {
|
||||
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no project_id in response")
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||
if len(responsePreview) > 500 {
|
||||
responsePreview = responsePreview[:500]
|
||||
}
|
||||
|
||||
responseErr := responsePreview
|
||||
if len(responseErr) > 200 {
|
||||
responseErr = responseErr[:200]
|
||||
}
|
||||
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
34
internal/auth/antigravity/constants.go
Normal file
34
internal/auth/antigravity/constants.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||
package antigravity
|
||||
|
||||
// OAuth client credentials and configuration
|
||||
const (
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
CallbackPort = 51121
|
||||
)
|
||||
|
||||
// Scopes defines the OAuth scopes required for Antigravity authentication
|
||||
var Scopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
// OAuth2 endpoints for Google authentication
|
||||
const (
|
||||
TokenEndpoint = "https://oauth2.googleapis.com/token"
|
||||
AuthEndpoint = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
|
||||
)
|
||||
|
||||
// Antigravity API configuration
|
||||
const (
|
||||
APIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
APIVersion = "v1internal"
|
||||
APIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
APIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||
ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
|
||||
)
|
||||
16
internal/auth/antigravity/filename.go
Normal file
16
internal/auth/antigravity/filename.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Antigravity credentials.
|
||||
// It uses the email as a suffix to disambiguate accounts.
|
||||
func CredentialFileName(email string) string {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
return "antigravity.json"
|
||||
}
|
||||
return fmt.Sprintf("antigravity-%s.json", email)
|
||||
}
|
||||
@@ -18,11 +18,12 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OAuth configuration constants for Claude/Anthropic
|
||||
const (
|
||||
anthropicAuthURL = "https://claude.ai/oauth/authorize"
|
||||
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
redirectURI = "http://localhost:54545/callback"
|
||||
AuthURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
RedirectURI = "http://localhost:54545/callback"
|
||||
)
|
||||
|
||||
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
||||
@@ -82,16 +83,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
||||
|
||||
params := url.Values{
|
||||
"code": {"true"},
|
||||
"client_id": {anthropicClientID},
|
||||
"client_id": {ClientID},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {redirectURI},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"scope": {"org:create_api_key user:profile user:inference"},
|
||||
"code_challenge": {pkceCodes.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
|
||||
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||
return authURL, state, nil
|
||||
}
|
||||
|
||||
@@ -137,8 +138,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
||||
"code": newCode,
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": anthropicClientID,
|
||||
"redirect_uri": redirectURI,
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"code_verifier": pkceCodes.CodeVerifier,
|
||||
}
|
||||
|
||||
@@ -154,7 +155,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
||||
|
||||
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
@@ -221,7 +222,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"client_id": anthropicClientID,
|
||||
"client_id": ClientID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
}
|
||||
@@ -231,7 +232,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,9 +4,6 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
|
||||
@@ -43,15 +40,7 @@ func normalizePlanTypeForFilename(planType string) string {
|
||||
}
|
||||
|
||||
for i, part := range parts {
|
||||
parts[i] = titleToken(part)
|
||||
parts[i] = strings.ToLower(strings.TrimSpace(part))
|
||||
}
|
||||
return strings.Join(parts, "-")
|
||||
}
|
||||
|
||||
func titleToken(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
return cases.Title(language.English).String(token)
|
||||
}
|
||||
|
||||
@@ -19,11 +19,12 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OAuth configuration constants for OpenAI Codex
|
||||
const (
|
||||
openaiAuthURL = "https://auth.openai.com/oauth/authorize"
|
||||
openaiTokenURL = "https://auth.openai.com/oauth/token"
|
||||
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
redirectURI = "http://localhost:1455/auth/callback"
|
||||
AuthURL = "https://auth.openai.com/oauth/authorize"
|
||||
TokenURL = "https://auth.openai.com/oauth/token"
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
RedirectURI = "http://localhost:1455/auth/callback"
|
||||
)
|
||||
|
||||
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
||||
@@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
}
|
||||
|
||||
params := url.Values{
|
||||
"client_id": {openaiClientID},
|
||||
"client_id": {ClientID},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {redirectURI},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"scope": {"openid email profile offline_access"},
|
||||
"state": {state},
|
||||
"code_challenge": {pkceCodes.CodeChallenge},
|
||||
@@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
|
||||
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||
return authURL, nil
|
||||
}
|
||||
|
||||
@@ -77,13 +78,13 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {openaiClientID},
|
||||
"client_id": {ClientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {redirectURI},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
@@ -163,13 +164,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"client_id": {openaiClientID},
|
||||
"client_id": {ClientID},
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"scope": {"openid profile email"},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
@@ -28,19 +28,19 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// OAuth configuration constants for Gemini
|
||||
const (
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
geminiDefaultCallbackPort = 8085
|
||||
ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
DefaultCallbackPort = 8085
|
||||
)
|
||||
|
||||
var (
|
||||
geminiOauthScopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
)
|
||||
// OAuth scopes for Gemini authentication
|
||||
var Scopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
||||
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
||||
@@ -74,7 +74,7 @@ func NewGeminiAuth() *GeminiAuth {
|
||||
// - *http.Client: An HTTP client configured with authentication
|
||||
// - error: An error if the client configuration fails, nil otherwise
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||
callbackPort := geminiDefaultCallbackPort
|
||||
callbackPort := DefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
@@ -112,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
|
||||
// Configure the OAuth2 client.
|
||||
conf := &oauth2.Config{
|
||||
ClientID: geminiOauthClientID,
|
||||
ClientSecret: geminiOauthClientSecret,
|
||||
ClientID: ClientID,
|
||||
ClientSecret: ClientSecret,
|
||||
RedirectURL: callbackURL, // This will be used by the local server.
|
||||
Scopes: geminiOauthScopes,
|
||||
Scopes: Scopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
@@ -198,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
}
|
||||
|
||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||
ifToken["client_id"] = geminiOauthClientID
|
||||
ifToken["client_secret"] = geminiOauthClientSecret
|
||||
ifToken["scopes"] = geminiOauthScopes
|
||||
ifToken["client_id"] = ClientID
|
||||
ifToken["client_secret"] = ClientSecret
|
||||
ifToken["scopes"] = Scopes
|
||||
ifToken["universe_domain"] = "googleapis.com"
|
||||
|
||||
ts := GeminiTokenStorage{
|
||||
@@ -226,7 +226,7 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||
// - error: An error if the token acquisition fails, nil otherwise
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||
callbackPort := geminiDefaultCallbackPort
|
||||
callbackPort := DefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
89
internal/cache/signature_cache.go
vendored
89
internal/cache/signature_cache.go
vendored
@@ -3,7 +3,6 @@ package cache
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -25,18 +24,18 @@ const (
|
||||
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
||||
MinValidSignatureLen = 50
|
||||
|
||||
// SessionCleanupInterval controls how often stale sessions are purged
|
||||
SessionCleanupInterval = 10 * time.Minute
|
||||
// CacheCleanupInterval controls how often stale entries are purged
|
||||
CacheCleanupInterval = 10 * time.Minute
|
||||
)
|
||||
|
||||
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
|
||||
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
|
||||
var signatureCache sync.Map
|
||||
|
||||
// sessionCleanupOnce ensures the background cleanup goroutine starts only once
|
||||
var sessionCleanupOnce sync.Once
|
||||
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
|
||||
var cacheCleanupOnce sync.Once
|
||||
|
||||
// sessionCache is the inner map type
|
||||
type sessionCache struct {
|
||||
// groupCache is the inner map type
|
||||
type groupCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]SignatureEntry
|
||||
}
|
||||
@@ -47,36 +46,36 @@ func hashText(text string) string {
|
||||
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
||||
}
|
||||
|
||||
// getOrCreateSession gets or creates a session cache
|
||||
func getOrCreateSession(sessionID string) *sessionCache {
|
||||
// getOrCreateGroupCache gets or creates a cache bucket for a model group
|
||||
func getOrCreateGroupCache(groupKey string) *groupCache {
|
||||
// Start background cleanup on first access
|
||||
sessionCleanupOnce.Do(startSessionCleanup)
|
||||
cacheCleanupOnce.Do(startCacheCleanup)
|
||||
|
||||
if val, ok := signatureCache.Load(sessionID); ok {
|
||||
return val.(*sessionCache)
|
||||
if val, ok := signatureCache.Load(groupKey); ok {
|
||||
return val.(*groupCache)
|
||||
}
|
||||
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
|
||||
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
|
||||
return actual.(*sessionCache)
|
||||
sc := &groupCache{entries: make(map[string]SignatureEntry)}
|
||||
actual, _ := signatureCache.LoadOrStore(groupKey, sc)
|
||||
return actual.(*groupCache)
|
||||
}
|
||||
|
||||
// startSessionCleanup launches a background goroutine that periodically
|
||||
// removes sessions where all entries have expired.
|
||||
func startSessionCleanup() {
|
||||
// startCacheCleanup launches a background goroutine that periodically
|
||||
// removes caches where all entries have expired.
|
||||
func startCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(SessionCleanupInterval)
|
||||
ticker := time.NewTicker(CacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredSessions()
|
||||
purgeExpiredCaches()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeExpiredSessions removes sessions with no valid (non-expired) entries.
|
||||
func purgeExpiredSessions() {
|
||||
// purgeExpiredCaches removes caches with no valid (non-expired) entries.
|
||||
func purgeExpiredCaches() {
|
||||
now := time.Now()
|
||||
signatureCache.Range(func(key, value any) bool {
|
||||
sc := value.(*sessionCache)
|
||||
sc := value.(*groupCache)
|
||||
sc.mu.Lock()
|
||||
// Remove expired entries
|
||||
for k, entry := range sc.entries {
|
||||
@@ -86,7 +85,7 @@ func purgeExpiredSessions() {
|
||||
}
|
||||
isEmpty := len(sc.entries) == 0
|
||||
sc.mu.Unlock()
|
||||
// Remove session if empty
|
||||
// Remove cache bucket if empty
|
||||
if isEmpty {
|
||||
signatureCache.Delete(key)
|
||||
}
|
||||
@@ -94,7 +93,7 @@ func purgeExpiredSessions() {
|
||||
})
|
||||
}
|
||||
|
||||
// CacheSignature stores a thinking signature for a given session and text.
|
||||
// CacheSignature stores a thinking signature for a given model group and text.
|
||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||
func CacheSignature(modelName, text, signature string) {
|
||||
if text == "" || signature == "" {
|
||||
@@ -104,9 +103,9 @@ func CacheSignature(modelName, text, signature string) {
|
||||
return
|
||||
}
|
||||
|
||||
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text)
|
||||
groupKey := GetModelGroup(modelName)
|
||||
textHash := hashText(text)
|
||||
sc := getOrCreateSession(textHash)
|
||||
sc := getOrCreateGroupCache(groupKey)
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
@@ -116,26 +115,25 @@ func CacheSignature(modelName, text, signature string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetCachedSignature retrieves a cached signature for a given session and text.
|
||||
// GetCachedSignature retrieves a cached signature for a given model group and text.
|
||||
// Returns empty string if not found or expired.
|
||||
func GetCachedSignature(modelName, text string) string {
|
||||
family := GetModelGroup(modelName)
|
||||
groupKey := GetModelGroup(modelName)
|
||||
|
||||
if text == "" {
|
||||
if family == "gemini" {
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text)
|
||||
val, ok := signatureCache.Load(hashText(text))
|
||||
val, ok := signatureCache.Load(groupKey)
|
||||
if !ok {
|
||||
if family == "gemini" {
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
sc := val.(*sessionCache)
|
||||
sc := val.(*groupCache)
|
||||
|
||||
textHash := hashText(text)
|
||||
|
||||
@@ -145,7 +143,7 @@ func GetCachedSignature(modelName, text string) string {
|
||||
entry, exists := sc.entries[textHash]
|
||||
if !exists {
|
||||
sc.mu.Unlock()
|
||||
if family == "gemini" {
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
@@ -153,7 +151,7 @@ func GetCachedSignature(modelName, text string) string {
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, textHash)
|
||||
sc.mu.Unlock()
|
||||
if family == "gemini" {
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
@@ -167,22 +165,17 @@ func GetCachedSignature(modelName, text string) string {
|
||||
return entry.Signature
|
||||
}
|
||||
|
||||
// ClearSignatureCache clears signature cache for a specific session or all sessions.
|
||||
func ClearSignatureCache(sessionID string) {
|
||||
if sessionID != "" {
|
||||
signatureCache.Range(func(key, _ any) bool {
|
||||
kStr, ok := key.(string)
|
||||
if ok && strings.HasSuffix(kStr, "#"+sessionID) {
|
||||
signatureCache.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else {
|
||||
// ClearSignatureCache clears signature cache for a specific model group or all groups.
|
||||
func ClearSignatureCache(modelName string) {
|
||||
if modelName == "" {
|
||||
signatureCache.Range(func(key, _ any) bool {
|
||||
signatureCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
groupKey := GetModelGroup(modelName)
|
||||
signatureCache.Delete(groupKey)
|
||||
}
|
||||
|
||||
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
||||
|
||||
103
internal/cache/signature_cache_test.go
vendored
103
internal/cache/signature_cache_test.go
vendored
@@ -5,6 +5,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const testModelName = "claude-sonnet-4-5"
|
||||
|
||||
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
@@ -12,30 +14,31 @@ func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
|
||||
// Store signature
|
||||
CacheSignature("test-model", text, signature)
|
||||
CacheSignature(testModelName, text, signature)
|
||||
|
||||
// Retrieve signature
|
||||
retrieved := GetCachedSignature("test-model", text)
|
||||
retrieved := GetCachedSignature(testModelName, text)
|
||||
if retrieved != signature {
|
||||
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_DifferentSessions(t *testing.T) {
|
||||
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "Same text in different sessions"
|
||||
text := "Same text across models"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature("test-model", text, sig1)
|
||||
CacheSignature("test-model", text, sig2)
|
||||
geminiModel := "gemini-3-pro-preview"
|
||||
CacheSignature(testModelName, text, sig1)
|
||||
CacheSignature(geminiModel, text, sig2)
|
||||
|
||||
if GetCachedSignature("test-model", text) != sig1 {
|
||||
t.Error("Session-a signature mismatch")
|
||||
if GetCachedSignature(testModelName, text) != sig1 {
|
||||
t.Error("Claude signature mismatch")
|
||||
}
|
||||
if GetCachedSignature("test-model", text) != sig2 {
|
||||
t.Error("Session-b signature mismatch")
|
||||
if GetCachedSignature(geminiModel, text) != sig2 {
|
||||
t.Error("Gemini signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,13 +46,13 @@ func TestCacheSignature_NotFound(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Non-existent session
|
||||
if got := GetCachedSignature("test-model", "some text"); got != "" {
|
||||
if got := GetCachedSignature(testModelName, "some text"); got != "" {
|
||||
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
||||
}
|
||||
|
||||
// Existing session but different text
|
||||
CacheSignature("test-model", "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature("test-model", "text-b"); got != "" {
|
||||
CacheSignature(testModelName, "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature(testModelName, "text-b"); got != "" {
|
||||
t.Errorf("Expected empty string for different text, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -58,12 +61,11 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// All empty/invalid inputs should be no-ops
|
||||
CacheSignature("test-model", "text", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("test-model", "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("test-model", "text", "")
|
||||
CacheSignature("test-model", "text", "short") // Too short
|
||||
CacheSignature(testModelName, "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature(testModelName, "text", "")
|
||||
CacheSignature(testModelName, "text", "short") // Too short
|
||||
|
||||
if got := GetCachedSignature("test-model", "text"); got != "" {
|
||||
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -74,27 +76,24 @@ func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
||||
text := "Some text"
|
||||
shortSig := "abc123" // Less than 50 chars
|
||||
|
||||
CacheSignature("test-model", text, shortSig)
|
||||
CacheSignature(testModelName, text, shortSig)
|
||||
|
||||
if got := GetCachedSignature("test-model", text); got != "" {
|
||||
if got := GetCachedSignature(testModelName, text); got != "" {
|
||||
t.Errorf("Short signature should be rejected, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSignatureCache_SpecificSession(t *testing.T) {
|
||||
func TestClearSignatureCache_ModelGroup(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("test-model", "text", sig)
|
||||
CacheSignature("test-model", "text", sig)
|
||||
CacheSignature(testModelName, "text", sig)
|
||||
CacheSignature(testModelName, "text-2", sig)
|
||||
|
||||
ClearSignatureCache("session-1")
|
||||
|
||||
if got := GetCachedSignature("test-model", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("test-model", "text"); got != sig {
|
||||
t.Error("session-2 should still exist")
|
||||
if got := GetCachedSignature(testModelName, "text"); got != sig {
|
||||
t.Error("signature should remain when clearing unknown session")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,35 +101,37 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("test-model", "text", sig)
|
||||
CacheSignature("test-model", "text", sig)
|
||||
CacheSignature(testModelName, "text", sig)
|
||||
CacheSignature(testModelName, "text-2", sig)
|
||||
|
||||
ClearSignatureCache("")
|
||||
|
||||
if got := GetCachedSignature("test-model", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||
t.Error("text should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("test-model", "text"); got != "" {
|
||||
t.Error("session-2 should be cleared")
|
||||
if got := GetCachedSignature(testModelName, "text-2"); got != "" {
|
||||
t.Error("text-2 should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasValidSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
signature string
|
||||
expected bool
|
||||
}{
|
||||
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
|
||||
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
|
||||
{"empty string", "", false},
|
||||
{"short signature", "abc", false},
|
||||
{"valid long signature", testModelName, "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||
{"exactly 50 chars", testModelName, "12345678901234567890123456789012345678901234567890", true},
|
||||
{"49 chars - invalid", testModelName, "1234567890123456789012345678901234567890123456789", false},
|
||||
{"empty string", testModelName, "", false},
|
||||
{"short signature", testModelName, "abc", false},
|
||||
{"gemini sentinel", "gemini-3-pro-preview", "skip_thought_signature_validator", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasValidSignature("claude-sonnet-4-5-thinking", tt.signature)
|
||||
result := HasValidSignature(tt.modelName, tt.signature)
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
||||
}
|
||||
@@ -147,13 +148,13 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature("test-model", text1, sig1)
|
||||
CacheSignature("test-model", text2, sig2)
|
||||
CacheSignature(testModelName, text1, sig1)
|
||||
CacheSignature(testModelName, text2, sig2)
|
||||
|
||||
if GetCachedSignature("test-model", text1) != sig1 {
|
||||
if GetCachedSignature(testModelName, text1) != sig1 {
|
||||
t.Error("text1 signature mismatch")
|
||||
}
|
||||
if GetCachedSignature("test-model", text2) != sig2 {
|
||||
if GetCachedSignature(testModelName, text2) != sig2 {
|
||||
t.Error("text2 signature mismatch")
|
||||
}
|
||||
}
|
||||
@@ -164,9 +165,9 @@ func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
||||
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
||||
|
||||
CacheSignature("test-model", text, sig)
|
||||
CacheSignature(testModelName, text, sig)
|
||||
|
||||
if got := GetCachedSignature("test-model", text); got != sig {
|
||||
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -178,10 +179,10 @@ func TestCacheSignature_Overwrite(t *testing.T) {
|
||||
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
||||
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
||||
|
||||
CacheSignature("test-model", text, sig1)
|
||||
CacheSignature("test-model", text, sig2) // Overwrite
|
||||
CacheSignature(testModelName, text, sig1)
|
||||
CacheSignature(testModelName, text, sig2) // Overwrite
|
||||
|
||||
if got := GetCachedSignature("test-model", text); got != sig2 {
|
||||
if got := GetCachedSignature(testModelName, text); got != sig2 {
|
||||
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
||||
}
|
||||
}
|
||||
@@ -196,10 +197,10 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
text := "text"
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature("test-model", text, sig)
|
||||
CacheSignature(testModelName, text, sig)
|
||||
|
||||
// Fresh entry should be retrievable
|
||||
if got := GetCachedSignature("test-model", text); got != sig {
|
||||
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
||||
}
|
||||
|
||||
|
||||
@@ -1033,10 +1033,10 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
"owned_by": model.OwnedBy,
|
||||
}
|
||||
if model.Created > 0 {
|
||||
result["created"] = model.Created
|
||||
result["created_at"] = model.Created
|
||||
}
|
||||
if model.Type != "" {
|
||||
result["type"] = model.Type
|
||||
result["type"] = "model"
|
||||
}
|
||||
if model.DisplayName != "" {
|
||||
result["display_name"] = model.DisplayName
|
||||
|
||||
@@ -398,7 +398,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
return nil, translatedPayload{}, err
|
||||
}
|
||||
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
||||
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||
|
||||
@@ -142,7 +142,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
return resp, err
|
||||
}
|
||||
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -261,7 +262,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
return resp, err
|
||||
}
|
||||
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -627,7 +629,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
return nil, err
|
||||
}
|
||||
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -994,7 +997,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
now := time.Now().Unix()
|
||||
modelConfig := registry.GetAntigravityModelConfig()
|
||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||
for originalName := range result.Map() {
|
||||
for originalName, modelData := range result.Map() {
|
||||
modelID := strings.TrimSpace(originalName)
|
||||
if modelID == "" {
|
||||
continue
|
||||
@@ -1004,12 +1007,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
continue
|
||||
}
|
||||
modelCfg := modelConfig[modelID]
|
||||
modelName := modelID
|
||||
|
||||
// Extract displayName from upstream response, fallback to modelID
|
||||
displayName := modelData.Get("displayName").String()
|
||||
if displayName == "" {
|
||||
displayName = modelID
|
||||
}
|
||||
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: modelID,
|
||||
Name: modelName,
|
||||
Description: modelID,
|
||||
DisplayName: modelID,
|
||||
Name: modelID,
|
||||
Description: displayName,
|
||||
DisplayName: displayName,
|
||||
Version: modelID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
@@ -1213,7 +1222,17 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
||||
|
||||
payload = []byte(strJSON)
|
||||
} else {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
// Clean tool schemas for Gemini to remove unsupported JSON Schema keywords
|
||||
// without adding empty-schema placeholders.
|
||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
|
||||
@@ -1230,6 +1249,12 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
} else {
|
||||
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
@@ -1405,26 +1430,10 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
// template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
||||
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
||||
template, _ = sjson.SetRaw(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()), funcDecl.Get("parametersJsonSchema").Raw)
|
||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters.$schema", key.Int(), funKey.Int()))
|
||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parametersJsonSchema", key.Int(), funKey.Int()))
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
|
||||
template, _ = sjson.Delete(template, "toolConfig")
|
||||
}
|
||||
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
return []byte(template)
|
||||
}
|
||||
|
||||
|
||||
@@ -114,7 +114,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
@@ -162,7 +163,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
@@ -245,7 +246,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
@@ -293,7 +295,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
@@ -731,6 +733,11 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
||||
|
||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||
// Skip built-in tools (web_search, code_execution, etc.) which have
|
||||
// a "type" field and require their name to remain unchanged.
|
||||
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
||||
return true
|
||||
}
|
||||
name := tool.Get("name").String()
|
||||
if name == "" || strings.HasPrefix(name, prefix) {
|
||||
return true
|
||||
|
||||
@@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||
out := applyClaudeToolPrefix(input, "proxy_")
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||
t.Fatalf("built-in tool name should not be prefixed: tools.0.name = %q, want %q", got, "web_search")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_my_custom_tool" {
|
||||
t.Fatalf("custom tool should be prefixed: tools.1.name = %q, want %q", got, "proxy_my_custom_tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
|
||||
@@ -101,7 +101,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
@@ -149,7 +150,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
@@ -213,7 +214,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
@@ -263,7 +265,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, readErr
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -129,7 +129,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -226,7 +227,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
if httpResp.StatusCode == 429 {
|
||||
if idx+1 < len(models) {
|
||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||
@@ -278,7 +279,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||
|
||||
projectID := resolveGeminiProjectID(auth)
|
||||
|
||||
@@ -358,7 +360,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), data...)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
if httpResp.StatusCode == 429 {
|
||||
if idx+1 < len(models) {
|
||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||
|
||||
@@ -126,7 +126,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := "generateContent"
|
||||
@@ -187,7 +188,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
@@ -228,7 +229,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
@@ -280,7 +282,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("gemini executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -400,7 +402,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
log.Debugf("request error, error status: %d, error body: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
||||
}
|
||||
|
||||
|
||||
@@ -325,7 +325,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
}
|
||||
|
||||
@@ -388,7 +389,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
@@ -438,7 +439,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, false)
|
||||
@@ -501,7 +503,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
@@ -541,7 +543,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
@@ -598,7 +601,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -664,7 +667,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
@@ -721,7 +725,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
@@ -834,7 +838,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
@@ -918,7 +922,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
|
||||
@@ -98,7 +98,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
|
||||
@@ -141,7 +142,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("iflow request error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
@@ -201,7 +202,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
body = ensureToolsArray(body)
|
||||
}
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
|
||||
@@ -242,7 +244,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
log.Errorf("iflow executor: close response body error: %v", errClose)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
log.Debugf("iflow streaming error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -12,7 +12,10 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -332,6 +335,12 @@ func summarizeErrorBody(contentType string, body []byte) string {
|
||||
}
|
||||
return "[html body omitted]"
|
||||
}
|
||||
|
||||
// Try to extract error message from JSON response
|
||||
if message := extractJSONErrorMessage(body); message != "" {
|
||||
return message
|
||||
}
|
||||
|
||||
return string(body)
|
||||
}
|
||||
|
||||
@@ -358,3 +367,25 @@ func extractHTMLTitle(body []byte) string {
|
||||
}
|
||||
return strings.Join(strings.Fields(title), " ")
|
||||
}
|
||||
|
||||
// extractJSONErrorMessage attempts to extract error.message from JSON error responses
|
||||
func extractJSONErrorMessage(body []byte) string {
|
||||
result := gjson.GetBytes(body, "error.message")
|
||||
if result.Exists() && result.String() != "" {
|
||||
return result.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// logWithRequestID returns a logrus Entry with request_id field populated from context.
|
||||
// If no request ID is found in context, it returns the standard logger.
|
||||
func logWithRequestID(ctx context.Context) *log.Entry {
|
||||
if ctx == nil {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
requestID := logging.GetRequestID(ctx)
|
||||
if requestID == "" {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
return log.WithField("request_id", requestID)
|
||||
}
|
||||
|
||||
@@ -90,7 +90,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -145,7 +146,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
@@ -185,7 +186,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -237,7 +239,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -12,8 +14,9 @@ import (
|
||||
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
||||
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||
// against the original payload when provided.
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte) []byte {
|
||||
// against the original payload when provided. requestedModel carries the client-visible
|
||||
// model name before alias resolution so payload rules can target aliases precisely.
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
@@ -22,10 +25,11 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
return payload
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return payload
|
||||
}
|
||||
candidates := payloadModelCandidates(cfg, model, protocol)
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
out := payload
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
@@ -163,65 +167,42 @@ func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) b
|
||||
return false
|
||||
}
|
||||
|
||||
func payloadModelCandidates(cfg *config.Config, model, protocol string) []string {
|
||||
func payloadModelCandidates(model, requestedModel string) []string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
candidates := []string{model}
|
||||
if cfg == nil {
|
||||
return candidates
|
||||
}
|
||||
aliases := payloadModelAliases(cfg, model, protocol)
|
||||
if len(aliases) == 0 {
|
||||
return candidates
|
||||
}
|
||||
seen := map[string]struct{}{strings.ToLower(model): struct{}{}}
|
||||
for _, alias := range aliases {
|
||||
alias = strings.TrimSpace(alias)
|
||||
if alias == "" {
|
||||
continue
|
||||
candidates := make([]string, 0, 3)
|
||||
seen := make(map[string]struct{}, 3)
|
||||
addCandidate := func(value string) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(alias)
|
||||
key := strings.ToLower(value)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
candidates = append(candidates, alias)
|
||||
candidates = append(candidates, value)
|
||||
}
|
||||
if model != "" {
|
||||
addCandidate(model)
|
||||
}
|
||||
if requestedModel != "" {
|
||||
parsed := thinking.ParseSuffix(requestedModel)
|
||||
base := strings.TrimSpace(parsed.ModelName)
|
||||
if base != "" {
|
||||
addCandidate(base)
|
||||
}
|
||||
if parsed.HasSuffix {
|
||||
addCandidate(requestedModel)
|
||||
}
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func payloadModelAliases(cfg *config.Config, model, protocol string) []string {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return nil
|
||||
}
|
||||
channel := strings.ToLower(strings.TrimSpace(protocol))
|
||||
if channel == "" {
|
||||
return nil
|
||||
}
|
||||
entries := cfg.OAuthModelAlias[channel]
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
aliases := make([]string, 0, 2)
|
||||
for _, entry := range entries {
|
||||
if !strings.EqualFold(strings.TrimSpace(entry.Name), model) {
|
||||
continue
|
||||
}
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if alias == "" {
|
||||
continue
|
||||
}
|
||||
aliases = append(aliases, alias)
|
||||
}
|
||||
return aliases
|
||||
}
|
||||
|
||||
// buildPayloadPath combines an optional root path with a relative parameter path.
|
||||
// When root is empty, the parameter path is used as-is. When root is non-empty,
|
||||
// the parameter path is treated as relative to root.
|
||||
@@ -258,6 +239,35 @@ func payloadRawValue(value any) ([]byte, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||
fallback = strings.TrimSpace(fallback)
|
||||
if len(opts.Metadata) == 0 {
|
||||
return fallback
|
||||
}
|
||||
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return fallback
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return fallback
|
||||
}
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return fallback
|
||||
}
|
||||
trimmed := strings.TrimSpace(string(v))
|
||||
if trimmed == "" {
|
||||
return fallback
|
||||
}
|
||||
return trimmed
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
||||
// Examples:
|
||||
//
|
||||
|
||||
@@ -91,7 +91,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -132,7 +133,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
@@ -184,7 +185,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -220,7 +222,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||
}
|
||||
|
||||
@@ -74,13 +74,13 @@ func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Valid signature must be at least 50 characters
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Let me think..."
|
||||
|
||||
// Pre-cache the signature (simulating a response from the same session)
|
||||
// The session ID is derived from the first user message hash
|
||||
// Since there's no user message in this test, we need to add one
|
||||
// Pre-cache the signature (simulating a previous response for the same thinking text)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
@@ -117,6 +117,8 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Unsigned thinking blocks should be removed entirely (not converted to text)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
@@ -238,6 +240,8 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Let me think..."
|
||||
|
||||
@@ -279,6 +283,8 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Case: text block followed by thinking block -> should be reordered to thinking first
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Planning..."
|
||||
@@ -487,6 +493,8 @@ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *t
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Last assistant message ends with signed thinking block - should be kept
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Valid thinking..."
|
||||
|
||||
@@ -139,7 +139,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
|
||||
if params.CurrentThinkingText.Len() > 0 {
|
||||
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
||||
// log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len())
|
||||
params.CurrentThinkingText.Reset()
|
||||
}
|
||||
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
// Signature Caching Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
|
||||
func TestConvertAntigravityResponseToClaude_ParamsInitialized(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Request with user message - should derive session ID
|
||||
// Request with user message - should initialize params
|
||||
requestJSON := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
|
||||
@@ -37,10 +37,12 @@ func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m)
|
||||
|
||||
// Verify session ID was set
|
||||
params := param.(*Params)
|
||||
if params.SessionID == "" {
|
||||
t.Error("SessionID should be derived from request")
|
||||
if !params.HasFirstResponse {
|
||||
t.Error("HasFirstResponse should be set after first chunk")
|
||||
}
|
||||
if params.CurrentThinkingText.Len() == 0 {
|
||||
t.Error("Thinking text should be accumulated")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,12 +132,8 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
||||
// Process thinking chunk
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m)
|
||||
params := param.(*Params)
|
||||
sessionID := params.SessionID
|
||||
thinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
if sessionID == "" {
|
||||
t.Fatal("SessionID should be set")
|
||||
}
|
||||
if thinkingText == "" {
|
||||
t.Fatal("Thinking text should be accumulated")
|
||||
}
|
||||
|
||||
@@ -62,40 +62,6 @@ func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *test
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) {
|
||||
// Thinking blocks should be removed entirely for Gemini
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(fmt.Sprintf(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"thought": true, "text": "Thinking...", "thoughtSignature": "%s"},
|
||||
{"text": "Here is my response"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`, validSignature))
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that thinking block is removed
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Only text part should remain
|
||||
if parts[0].Get("thought").Bool() {
|
||||
t.Error("Thinking block should be removed for Gemini")
|
||||
}
|
||||
if parts[0].Get("text").String() != "Here is my response" {
|
||||
t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
||||
// Multiple functionCalls should all get skip_thought_signature_validator
|
||||
inputJSON := []byte(`{
|
||||
|
||||
@@ -305,12 +305,12 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
}
|
||||
|
||||
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
|
||||
// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
|
||||
tools := gjson.GetBytes(rawJSON, "tools")
|
||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||
toolNode := []byte(`{}`)
|
||||
hasTool := false
|
||||
functionToolNode := []byte(`{}`)
|
||||
hasFunction := false
|
||||
googleSearchNodes := make([][]byte, 0)
|
||||
for _, t := range tools.Array() {
|
||||
if t.Get("type").String() == "function" {
|
||||
fn := t.Get("function")
|
||||
@@ -349,31 +349,37 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
}
|
||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
toolNode = tmp
|
||||
functionToolNode = tmp
|
||||
hasFunction = true
|
||||
hasTool = true
|
||||
}
|
||||
}
|
||||
if gs := t.Get("google_search"); gs.Exists() {
|
||||
googleToolNode := []byte(`{}`)
|
||||
var errSet error
|
||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
||||
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||
continue
|
||||
}
|
||||
hasTool = true
|
||||
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||
}
|
||||
}
|
||||
if hasTool {
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
|
||||
if hasFunction || len(googleSearchNodes) > 0 {
|
||||
toolsNode := []byte("[]")
|
||||
if hasFunction {
|
||||
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||
}
|
||||
for _, googleNode := range googleSearchNodes {
|
||||
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -98,9 +98,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Temperature setting for controlling response randomness
|
||||
if temp := genConfig.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
}
|
||||
// Top P setting for nucleus sampling
|
||||
if topP := genConfig.Get("topP"); topP.Exists() {
|
||||
} else if topP := genConfig.Get("topP"); topP.Exists() {
|
||||
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
}
|
||||
// Stop sequences configuration for custom termination conditions
|
||||
|
||||
@@ -110,10 +110,8 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Temperature setting for controlling response randomness
|
||||
if temp := root.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
}
|
||||
|
||||
// Top P setting for nucleus sampling
|
||||
if topP := root.Get("top_p"); topP.Exists() {
|
||||
} else if topP := root.Get("top_p"); topP.Exists() {
|
||||
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
}
|
||||
|
||||
|
||||
@@ -283,12 +283,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
}
|
||||
}
|
||||
|
||||
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
|
||||
// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
|
||||
tools := gjson.GetBytes(rawJSON, "tools")
|
||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||
toolNode := []byte(`{}`)
|
||||
hasTool := false
|
||||
functionToolNode := []byte(`{}`)
|
||||
hasFunction := false
|
||||
googleSearchNodes := make([][]byte, 0)
|
||||
for _, t := range tools.Array() {
|
||||
if t.Get("type").String() == "function" {
|
||||
fn := t.Get("function")
|
||||
@@ -327,31 +327,37 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
}
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
}
|
||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
toolNode = tmp
|
||||
functionToolNode = tmp
|
||||
hasFunction = true
|
||||
hasTool = true
|
||||
}
|
||||
}
|
||||
if gs := t.Get("google_search"); gs.Exists() {
|
||||
googleToolNode := []byte(`{}`)
|
||||
var errSet error
|
||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
||||
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||
continue
|
||||
}
|
||||
hasTool = true
|
||||
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||
}
|
||||
}
|
||||
if hasTool {
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
|
||||
if hasFunction || len(googleSearchNodes) > 0 {
|
||||
toolsNode := []byte("[]")
|
||||
if hasFunction {
|
||||
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||
}
|
||||
for _, googleNode := range googleSearchNodes {
|
||||
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -289,12 +289,12 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
|
||||
// tools -> tools[0].functionDeclarations + tools[0].googleSearch passthrough
|
||||
// tools -> tools[].functionDeclarations + tools[].googleSearch passthrough
|
||||
tools := gjson.GetBytes(rawJSON, "tools")
|
||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||
toolNode := []byte(`{}`)
|
||||
hasTool := false
|
||||
functionToolNode := []byte(`{}`)
|
||||
hasFunction := false
|
||||
googleSearchNodes := make([][]byte, 0)
|
||||
for _, t := range tools.Array() {
|
||||
if t.Get("type").String() == "function" {
|
||||
fn := t.Get("function")
|
||||
@@ -333,31 +333,37 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||
if !hasFunction {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
||||
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||
}
|
||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||
continue
|
||||
}
|
||||
toolNode = tmp
|
||||
functionToolNode = tmp
|
||||
hasFunction = true
|
||||
hasTool = true
|
||||
}
|
||||
}
|
||||
if gs := t.Get("google_search"); gs.Exists() {
|
||||
googleToolNode := []byte(`{}`)
|
||||
var errSet error
|
||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
||||
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||
if errSet != nil {
|
||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||
continue
|
||||
}
|
||||
hasTool = true
|
||||
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||
}
|
||||
}
|
||||
if hasTool {
|
||||
out, _ = sjson.SetRawBytes(out, "tools", []byte("[]"))
|
||||
out, _ = sjson.SetRawBytes(out, "tools.0", toolNode)
|
||||
if hasFunction || len(googleSearchNodes) > 0 {
|
||||
toolsNode := []byte("[]")
|
||||
if hasFunction {
|
||||
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||
}
|
||||
for _, googleNode := range googleSearchNodes {
|
||||
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "tools", toolsNode)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -89,12 +89,14 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
// Handle system message first
|
||||
systemMsgJSON := `{"role":"system","content":[]}`
|
||||
hasSystemContent := false
|
||||
if system := root.Get("system"); system.Exists() {
|
||||
if system.Type == gjson.String {
|
||||
if system.String() != "" {
|
||||
oldSystem := `{"type":"text","text":""}`
|
||||
oldSystem, _ = sjson.Set(oldSystem, "text", system.String())
|
||||
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", oldSystem)
|
||||
hasSystemContent = true
|
||||
}
|
||||
} else if system.Type == gjson.JSON {
|
||||
if system.IsArray() {
|
||||
@@ -102,12 +104,16 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
if contentItem, ok := convertClaudeContentPart(systemResults[i]); ok {
|
||||
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", contentItem)
|
||||
hasSystemContent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
|
||||
// Only add system message if it has content
|
||||
if hasSystemContent {
|
||||
messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
|
||||
}
|
||||
|
||||
// Process Anthropic messages
|
||||
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
|
||||
|
||||
@@ -181,11 +181,11 @@ func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
// Find the relevant message (skip system message at index 0)
|
||||
// Find the relevant message
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
if len(messages) < 2 {
|
||||
if len(messages) < 1 {
|
||||
if tt.wantHasReasoningContent || tt.wantHasContent {
|
||||
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages))
|
||||
t.Fatalf("Expected at least 1 message, got %d", len(messages))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -272,15 +272,15 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T)
|
||||
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
|
||||
// Should have: user + assistant (thinking-only) + user = 3 messages
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("Expected 3 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
// Check the assistant message (index 2) has reasoning_content
|
||||
assistantMsg := messages[2]
|
||||
// Check the assistant message (index 1) has reasoning_content
|
||||
assistantMsg := messages[1]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
t.Errorf("Expected message[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
if !assistantMsg.Get("reasoning_content").Exists() {
|
||||
@@ -292,6 +292,104 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_SystemMessageScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantHasSys bool
|
||||
wantSysText string
|
||||
}{
|
||||
{
|
||||
name: "No system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasSys: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": "",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasSys: false,
|
||||
},
|
||||
{
|
||||
name: "String system field",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": "Be helpful",
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasSys: true,
|
||||
wantSysText: "Be helpful",
|
||||
},
|
||||
{
|
||||
name: "Array system field with text",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": [{"type": "text", "text": "Array system"}],
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasSys: true,
|
||||
wantSysText: "Array system",
|
||||
},
|
||||
{
|
||||
name: "Array system field with multiple text blocks",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"system": [
|
||||
{"type": "text", "text": "Block 1"},
|
||||
{"type": "text", "text": "Block 2"}
|
||||
],
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
}`,
|
||||
wantHasSys: true,
|
||||
wantSysText: "Block 2", // We will update the test logic to check all blocks or specifically the second one
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
hasSys := false
|
||||
var sysMsg gjson.Result
|
||||
if len(messages) > 0 && messages[0].Get("role").String() == "system" {
|
||||
hasSys = true
|
||||
sysMsg = messages[0]
|
||||
}
|
||||
|
||||
if hasSys != tt.wantHasSys {
|
||||
t.Errorf("got hasSystem = %v, want %v", hasSys, tt.wantHasSys)
|
||||
}
|
||||
|
||||
if tt.wantHasSys {
|
||||
// Check content - it could be string or array in OpenAI
|
||||
content := sysMsg.Get("content")
|
||||
var gotText string
|
||||
if content.IsArray() {
|
||||
arr := content.Array()
|
||||
if len(arr) > 0 {
|
||||
// Get the last element's text for validation
|
||||
gotText = arr[len(arr)-1].Get("text").String()
|
||||
}
|
||||
} else {
|
||||
gotText = content.String()
|
||||
}
|
||||
|
||||
if tt.wantSysText != "" && gotText != tt.wantSysText {
|
||||
t.Errorf("got system text = %q, want %q", gotText, tt.wantSysText)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
@@ -318,39 +416,35 @@ func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
|
||||
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after)
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
// Correct order: assistant(tool_calls) + tool(result) + user(before+after)
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[0].Get("role").String() != "system" {
|
||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
||||
}
|
||||
|
||||
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
|
||||
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
|
||||
if messages[0].Get("role").String() != "assistant" || !messages[0].Get("tool_calls").Exists() {
|
||||
t.Fatalf("Expected messages[0] to be assistant tool_calls, got %s: %s", messages[0].Get("role").String(), messages[0].Raw)
|
||||
}
|
||||
|
||||
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
|
||||
if messages[2].Get("role").String() != "tool" {
|
||||
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String())
|
||||
if messages[1].Get("role").String() != "tool" {
|
||||
t.Fatalf("Expected messages[1] to be tool (must follow tool_calls), got %s", messages[1].Get("role").String())
|
||||
}
|
||||
if got := messages[2].Get("tool_call_id").String(); got != "call_1" {
|
||||
if got := messages[1].Get("tool_call_id").String(); got != "call_1" {
|
||||
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
|
||||
}
|
||||
if got := messages[2].Get("content").String(); got != "tool ok" {
|
||||
if got := messages[1].Get("content").String(); got != "tool ok" {
|
||||
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
|
||||
}
|
||||
|
||||
// User message comes after tool message
|
||||
if messages[3].Get("role").String() != "user" {
|
||||
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String())
|
||||
if messages[2].Get("role").String() != "user" {
|
||||
t.Fatalf("Expected messages[2] to be user, got %s", messages[2].Get("role").String())
|
||||
}
|
||||
// User message should contain both "before" and "after" text
|
||||
if got := messages[3].Get("content.0.text").String(); got != "before" {
|
||||
if got := messages[2].Get("content.0.text").String(); got != "before" {
|
||||
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
|
||||
}
|
||||
if got := messages[3].Get("content.1.text").String(); got != "after" {
|
||||
if got := messages[2].Get("content.1.text").String(); got != "after" {
|
||||
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
|
||||
}
|
||||
}
|
||||
@@ -378,16 +472,16 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// system + assistant(tool_calls) + tool(result)
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
// assistant(tool_calls) + tool(result)
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[2].Get("role").String() != "tool" {
|
||||
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String())
|
||||
if messages[1].Get("role").String() != "tool" {
|
||||
t.Fatalf("Expected messages[1] to be tool, got %s", messages[1].Get("role").String())
|
||||
}
|
||||
|
||||
toolContent := messages[2].Get("content").String()
|
||||
toolContent := messages[1].Get("content").String()
|
||||
parsed := gjson.Parse(toolContent)
|
||||
if parsed.Get("foo").String() != "bar" {
|
||||
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
|
||||
@@ -414,18 +508,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// New behavior: content + tool_calls unified in single assistant message
|
||||
// Expect: system + assistant(content[pre,post] + tool_calls)
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
// Expect: assistant(content[pre,post] + tool_calls)
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
if messages[0].Get("role").String() != "system" {
|
||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
||||
}
|
||||
|
||||
assistantMsg := messages[1]
|
||||
assistantMsg := messages[0]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
// Should have both content and tool_calls in same message
|
||||
@@ -470,14 +560,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
|
||||
// New behavior: all content, thinking, and tool_calls unified in single assistant message
|
||||
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
|
||||
if len(messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
// Expect: assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||
}
|
||||
|
||||
assistantMsg := messages[1]
|
||||
assistantMsg := messages[0]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
// Should have content with both pre and post
|
||||
|
||||
@@ -12,10 +12,23 @@ import (
|
||||
|
||||
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
|
||||
const placeholderReasonDescription = "Brief explanation of why you are calling this tool"
|
||||
|
||||
// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
|
||||
// It handles unsupported keywords, type flattening, and schema simplification while preserving
|
||||
// semantic information as description hints.
|
||||
func CleanJSONSchemaForAntigravity(jsonStr string) string {
|
||||
return cleanJSONSchema(jsonStr, true)
|
||||
}
|
||||
|
||||
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini tool calling.
|
||||
// It removes unsupported keywords and simplifies schemas, without adding empty-schema placeholders.
|
||||
func CleanJSONSchemaForGemini(jsonStr string) string {
|
||||
return cleanJSONSchema(jsonStr, false)
|
||||
}
|
||||
|
||||
// cleanJSONSchema performs the core cleaning operations on the JSON schema.
|
||||
func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
|
||||
// Phase 1: Convert and add hints
|
||||
jsonStr = convertRefsToHints(jsonStr)
|
||||
jsonStr = convertConstToEnum(jsonStr)
|
||||
@@ -31,10 +44,94 @@ func CleanJSONSchemaForAntigravity(jsonStr string) string {
|
||||
|
||||
// Phase 3: Cleanup
|
||||
jsonStr = removeUnsupportedKeywords(jsonStr)
|
||||
if !addPlaceholder {
|
||||
// Gemini schema cleanup: remove nullable/title and placeholder-only fields.
|
||||
jsonStr = removeKeywords(jsonStr, []string{"nullable", "title"})
|
||||
jsonStr = removePlaceholderFields(jsonStr)
|
||||
}
|
||||
jsonStr = cleanupRequiredFields(jsonStr)
|
||||
|
||||
// Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
|
||||
jsonStr = addEmptySchemaPlaceholder(jsonStr)
|
||||
if addPlaceholder {
|
||||
jsonStr = addEmptySchemaPlaceholder(jsonStr)
|
||||
}
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// removeKeywords removes all occurrences of specified keywords from the JSON schema.
|
||||
func removeKeywords(jsonStr string, keywords []string) string {
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// removePlaceholderFields removes placeholder-only properties ("_" and "reason") and their required entries.
|
||||
func removePlaceholderFields(jsonStr string) string {
|
||||
// Remove "_" placeholder properties.
|
||||
paths := findPaths(jsonStr, "_")
|
||||
sortByDepth(paths)
|
||||
for _, p := range paths {
|
||||
if !strings.HasSuffix(p, ".properties._") {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
parentPath := trimSuffix(p, ".properties._")
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
req := gjson.Get(jsonStr, reqPath)
|
||||
if req.IsArray() {
|
||||
var filtered []string
|
||||
for _, r := range req.Array() {
|
||||
if r.String() != "_" {
|
||||
filtered = append(filtered, r.String())
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove placeholder-only "reason" objects.
|
||||
reasonPaths := findPaths(jsonStr, "reason")
|
||||
sortByDepth(reasonPaths)
|
||||
for _, p := range reasonPaths {
|
||||
if !strings.HasSuffix(p, ".properties.reason") {
|
||||
continue
|
||||
}
|
||||
parentPath := trimSuffix(p, ".properties.reason")
|
||||
props := gjson.Get(jsonStr, joinPath(parentPath, "properties"))
|
||||
if !props.IsObject() || len(props.Map()) != 1 {
|
||||
continue
|
||||
}
|
||||
desc := gjson.Get(jsonStr, p+".description").String()
|
||||
if desc != placeholderReasonDescription {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
req := gjson.Get(jsonStr, reqPath)
|
||||
if req.IsArray() {
|
||||
var filtered []string
|
||||
for _, r := range req.Array() {
|
||||
if r.String() != "reason" {
|
||||
filtered = append(filtered, r.String())
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
@@ -409,7 +506,7 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
|
||||
// Add placeholder "reason" property
|
||||
reasonPath := joinPath(propsPath, "reason")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", placeholderReasonDescription)
|
||||
|
||||
// Add to required array
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
|
||||
|
||||
@@ -128,8 +128,23 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
|
||||
// Parameters:
|
||||
// - c: The Gin context for the request.
|
||||
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
|
||||
models := h.Models()
|
||||
firstID := ""
|
||||
lastID := ""
|
||||
if len(models) > 0 {
|
||||
if id, ok := models[0]["id"].(string); ok {
|
||||
firstID = id
|
||||
}
|
||||
if id, ok := models[len(models)-1]["id"].(string); ok {
|
||||
lastID = id
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": h.Models(),
|
||||
"data": models,
|
||||
"has_more": false,
|
||||
"first_id": firstID,
|
||||
"last_id": lastID,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -60,8 +60,12 @@ func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
|
||||
if !strings.HasPrefix(name, "models/") {
|
||||
normalizedModel["name"] = "models/" + name
|
||||
}
|
||||
normalizedModel["displayName"] = name
|
||||
normalizedModel["description"] = name
|
||||
if displayName, _ := normalizedModel["displayName"].(string); displayName == "" {
|
||||
normalizedModel["displayName"] = name
|
||||
}
|
||||
if description, _ := normalizedModel["description"].(string); description == "" {
|
||||
normalizedModel["description"] = name
|
||||
}
|
||||
}
|
||||
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
|
||||
normalizedModel["supportedGenerationMethods"] = defaultMethods
|
||||
|
||||
@@ -385,6 +385,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -423,6 +424,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -464,6 +466,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return nil, errChan
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
|
||||
@@ -2,15 +2,13 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
@@ -19,20 +17,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
antigravityCallbackPort = 51121
|
||||
)
|
||||
|
||||
var antigravityScopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
// AntigravityAuthenticator implements OAuth login for the antigravity provider.
|
||||
type AntigravityAuthenticator struct{}
|
||||
|
||||
@@ -60,12 +44,12 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := antigravityCallbackPort
|
||||
callbackPort := antigravity.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
||||
authSvc := antigravity.NewAntigravityAuth(cfg, nil)
|
||||
|
||||
state, err := misc.GenerateRandomState()
|
||||
if err != nil {
|
||||
@@ -83,7 +67,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
}()
|
||||
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port)
|
||||
authURL := buildAntigravityAuthURL(redirectURI, state)
|
||||
authURL := authSvc.BuildAuthURL(state, redirectURI)
|
||||
|
||||
if !opts.NoBrowser {
|
||||
fmt.Println("Opening browser for antigravity authentication")
|
||||
@@ -164,22 +148,29 @@ waitForCallback:
|
||||
return nil, fmt.Errorf("antigravity: missing authorization code")
|
||||
}
|
||||
|
||||
tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI, httpClient)
|
||||
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, cbRes.Code, redirectURI)
|
||||
if errToken != nil {
|
||||
return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken)
|
||||
}
|
||||
|
||||
email := ""
|
||||
if tokenResp.AccessToken != "" {
|
||||
if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" {
|
||||
email = strings.TrimSpace(info.Email)
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("antigravity: token exchange returned empty access token")
|
||||
}
|
||||
|
||||
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
|
||||
if errInfo != nil {
|
||||
return nil, fmt.Errorf("antigravity: fetch user info failed: %w", errInfo)
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
return nil, fmt.Errorf("antigravity: empty email returned from user info")
|
||||
}
|
||||
|
||||
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
|
||||
projectID := ""
|
||||
if tokenResp.AccessToken != "" {
|
||||
fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
||||
if accessToken != "" {
|
||||
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
|
||||
if errProject != nil {
|
||||
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||
} else {
|
||||
@@ -204,7 +195,7 @@ waitForCallback:
|
||||
metadata["project_id"] = projectID
|
||||
}
|
||||
|
||||
fileName := sanitizeAntigravityFileName(email)
|
||||
fileName := antigravity.CredentialFileName(email)
|
||||
label := email
|
||||
if label == "" {
|
||||
label = "antigravity"
|
||||
@@ -231,7 +222,7 @@ type callbackResult struct {
|
||||
|
||||
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
||||
if port <= 0 {
|
||||
port = antigravityCallbackPort
|
||||
port = antigravity.CallbackPort
|
||||
}
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
@@ -267,309 +258,9 @@ func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbac
|
||||
return srv, port, resultCh, nil
|
||||
}
|
||||
|
||||
type antigravityTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
func exchangeAntigravityCode(ctx context.Context, code, redirectURI string, httpClient *http.Client) (*antigravityTokenResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("code", code)
|
||||
data.Set("client_id", antigravityClientID)
|
||||
data.Set("client_secret", antigravityClientSecret)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
data.Set("grant_type", "authorization_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return nil, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
var token antigravityTokenResponse
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
|
||||
return nil, errDecode
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode)
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
type antigravityUserInfo struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func fetchAntigravityUserInfo(ctx context.Context, accessToken string, httpClient *http.Client) (*antigravityUserInfo, error) {
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return &antigravityUserInfo{}, nil
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return nil, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity userinfo: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return &antigravityUserInfo{}, nil
|
||||
}
|
||||
var info antigravityUserInfo
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||
return nil, errDecode
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
func buildAntigravityAuthURL(redirectURI, state string) string {
|
||||
params := url.Values{}
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("client_id", antigravityClientID)
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", strings.Join(antigravityScopes, " "))
|
||||
params.Set("state", state)
|
||||
return "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
||||
}
|
||||
|
||||
func sanitizeAntigravityFileName(email string) string {
|
||||
if strings.TrimSpace(email) == "" {
|
||||
return "antigravity.json"
|
||||
}
|
||||
replacer := strings.NewReplacer("@", "_", ".", "_")
|
||||
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
||||
}
|
||||
|
||||
// Antigravity API constants for project discovery
|
||||
const (
|
||||
antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityAPIVersion = "v1internal"
|
||||
antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||
antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
|
||||
)
|
||||
|
||||
// FetchAntigravityProjectID exposes project discovery for external callers.
|
||||
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
||||
return fetchAntigravityProjectID(ctx, accessToken, httpClient)
|
||||
}
|
||||
|
||||
// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist.
|
||||
// This uses the same approach as Gemini CLI to get the cloudaicompanionProject.
|
||||
func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
||||
// Call loadCodeAssist to get the project
|
||||
loadReqBody := map[string]any{
|
||||
"metadata": map[string]string{
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(loadReqBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
||||
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
|
||||
var loadResp map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
// Extract projectID from response
|
||||
projectID := ""
|
||||
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
if projectID == "" {
|
||||
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
||||
if id, okID := projectMap["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
tierID := "legacy-tier"
|
||||
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||
for _, rawTier := range tiers {
|
||||
tier, okTier := rawTier.(map[string]any)
|
||||
if !okTier {
|
||||
continue
|
||||
}
|
||||
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||
tierID = strings.TrimSpace(id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
|
||||
// It returns an empty string when the operation times out or completes without a project ID.
|
||||
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
fmt.Println("Antigravity: onboarding user...", tierID)
|
||||
requestBody := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": map[string]string{
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(requestBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
maxAttempts := 5
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||
|
||||
reqCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if reqCtx == nil {
|
||||
reqCtx = context.Background()
|
||||
}
|
||||
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
|
||||
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if errRequest != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
||||
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("close body error: %v", errClose)
|
||||
}
|
||||
cancel()
|
||||
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var data map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
if done, okDone := data["done"].(bool); okDone && done {
|
||||
projectID := ""
|
||||
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||
case map[string]any:
|
||||
if id, okID := projectValue["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
case string:
|
||||
projectID = strings.TrimSpace(projectValue)
|
||||
}
|
||||
}
|
||||
|
||||
if projectID != "" {
|
||||
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no project_id in response")
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||
if len(responsePreview) > 500 {
|
||||
responsePreview = responsePreview[:500]
|
||||
}
|
||||
|
||||
responseErr := responsePreview
|
||||
if len(responseErr) > 200 {
|
||||
responseErr = responseErr[:200]
|
||||
}
|
||||
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
cfg := &config.Config{}
|
||||
authSvc := antigravity.NewAntigravityAuth(cfg, httpClient)
|
||||
return authSvc.FetchProjectID(ctx, accessToken)
|
||||
}
|
||||
|
||||
@@ -570,6 +570,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
@@ -597,6 +598,9 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
@@ -619,6 +623,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
@@ -646,6 +651,9 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
@@ -668,6 +676,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
@@ -694,6 +703,9 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
@@ -729,167 +741,42 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return opts
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
if hasRequestedModelMetadata(opts.Metadata) {
|
||||
return opts
|
||||
}
|
||||
if len(opts.Metadata) == 0 {
|
||||
opts.Metadata = map[string]any{cliproxyexecutor.RequestedModelMetadataKey: requestedModel}
|
||||
return opts
|
||||
}
|
||||
meta := make(map[string]any, len(opts.Metadata)+1)
|
||||
for k, v := range opts.Metadata {
|
||||
meta[k] = v
|
||||
}
|
||||
meta[cliproxyexecutor.RequestedModelMetadataKey] = requestedModel
|
||||
opts.Metadata = meta
|
||||
return opts
|
||||
}
|
||||
|
||||
func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
func hasRequestedModelMetadata(meta map[string]any) bool {
|
||||
if len(meta) == 0 {
|
||||
return false
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
raw, ok := meta[cliproxyexecutor.RequestedModelMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if provider == "" {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v) != ""
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v)) != ""
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1140,35 +1027,6 @@ func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// rotateProviders returns a rotated view of the providers list starting from the
|
||||
// current offset for the model, and atomically increments the offset for the next call.
|
||||
// This ensures concurrent requests get different starting providers.
|
||||
func (m *Manager) rotateProviders(model string, providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Atomic read-and-increment: get current offset and advance cursor in one lock
|
||||
m.mu.Lock()
|
||||
offset := m.providerOffsets[model]
|
||||
m.providerOffsets[model] = (offset + 1) % len(providers)
|
||||
m.mu.Unlock()
|
||||
|
||||
if len(providers) > 0 {
|
||||
offset %= len(providers)
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if offset == 0 {
|
||||
return providers
|
||||
}
|
||||
rotated := make([]string, 0, len(providers))
|
||||
rotated = append(rotated, providers[offset:]...)
|
||||
rotated = append(rotated, providers[:offset]...)
|
||||
return rotated
|
||||
}
|
||||
|
||||
func (m *Manager) retrySettings() (int, time.Duration) {
|
||||
if m == nil {
|
||||
return 0, 0
|
||||
@@ -1250,42 +1108,6 @@ func waitForCooldown(ctx context.Context, wait time.Duration) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (cliproxyexecutor.Response, error)) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
resp, errExec := fn(ctx, provider)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (<-chan cliproxyexecutor.StreamChunk, error)) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
chunks, errExec := fn(ctx, provider)
|
||||
if errExec == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// MarkResult records an execution result and notifies hooks.
|
||||
func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
if result.AuthID == "" {
|
||||
@@ -1371,8 +1193,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
shouldSuspendModel = true
|
||||
setModelQuota = true
|
||||
case 408, 500, 502, 503, 504:
|
||||
next := now.Add(1 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
if quotaCooldownDisabled.Load() {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(1 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
}
|
||||
default:
|
||||
state.NextRetryAfter = time.Time{}
|
||||
}
|
||||
@@ -1623,7 +1449,11 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
||||
auth.NextRetryAfter = next
|
||||
case 408, 500, 502, 503, 504:
|
||||
auth.StatusMessage = "transient upstream error"
|
||||
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
||||
if quotaCooldownDisabled.Load() {
|
||||
auth.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
||||
}
|
||||
default:
|
||||
if auth.StatusMessage == "" {
|
||||
auth.StatusMessage = "request failed"
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
|
||||
const RequestedModelMetadataKey = "requested_model"
|
||||
|
||||
// Request encapsulates the translated payload that will be sent to a provider executor.
|
||||
type Request struct {
|
||||
// Model is the upstream model identifier after translation.
|
||||
|
||||
Reference in New Issue
Block a user