mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
323 lines
9.4 KiB
Go
323 lines
9.4 KiB
Go
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
|
package antigravity
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// TokenResponse is the JSON payload decoded from Google's OAuth token
// endpoint response.
type TokenResponse struct {
	// AccessToken is read from the "access_token" field.
	AccessToken string `json:"access_token"`
	// RefreshToken is read from the "refresh_token" field.
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is read from the "expires_in" field.
	// NOTE(review): presumably a lifetime in seconds, as in OAuth 2.0 — confirm.
	ExpiresIn int64 `json:"expires_in"`
	// TokenType is read from the "token_type" field.
	TokenType string `json:"token_type"`
}
|
|
|
|
// userInfo is the subset of the Google user profile payload that this
// package decodes; only the e-mail address is kept.
type userInfo struct {
	// Email is read from the "email" field.
	Email string `json:"email"`
}
|
|
|
|
// AntigravityAuth implements the Antigravity OAuth authentication flow.
type AntigravityAuth struct {
	// httpClient sends every outbound request issued by the methods of
	// this type.
	httpClient *http.Client
}
|
|
|
|
// NewAntigravityAuth creates a new Antigravity auth service.
|
|
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
|
|
if httpClient != nil {
|
|
return &AntigravityAuth{httpClient: httpClient}
|
|
}
|
|
if cfg == nil {
|
|
cfg = &config.Config{}
|
|
}
|
|
return &AntigravityAuth{
|
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
|
}
|
|
}
|
|
|
|
// BuildAuthURL generates the OAuth authorization URL.
|
|
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
|
|
if strings.TrimSpace(redirectURI) == "" {
|
|
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
|
|
}
|
|
params := url.Values{}
|
|
params.Set("access_type", "offline")
|
|
params.Set("client_id", ClientID)
|
|
params.Set("prompt", "consent")
|
|
params.Set("redirect_uri", redirectURI)
|
|
params.Set("response_type", "code")
|
|
params.Set("scope", strings.Join(Scopes, " "))
|
|
params.Set("state", state)
|
|
return AuthEndpoint + "?" + params.Encode()
|
|
}
|
|
|
|
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
|
|
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
|
|
data := url.Values{}
|
|
data.Set("code", code)
|
|
data.Set("client_id", ClientID)
|
|
data.Set("client_secret", ClientSecret)
|
|
data.Set("redirect_uri", redirectURI)
|
|
data.Set("grant_type", "authorization_code")
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, errDo := o.httpClient.Do(req)
|
|
if errDo != nil {
|
|
return nil, errDo
|
|
}
|
|
defer func() {
|
|
if errClose := resp.Body.Close(); errClose != nil {
|
|
log.Errorf("antigravity token exchange: close body error: %v", errClose)
|
|
}
|
|
}()
|
|
|
|
var token TokenResponse
|
|
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
|
|
return nil, errDecode
|
|
}
|
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode)
|
|
}
|
|
return &token, nil
|
|
}
|
|
|
|
// FetchUserInfo retrieves user email from Google
|
|
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
|
if strings.TrimSpace(accessToken) == "" {
|
|
return "", nil
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
|
|
resp, errDo := o.httpClient.Do(req)
|
|
if errDo != nil {
|
|
return "", errDo
|
|
}
|
|
defer func() {
|
|
if errClose := resp.Body.Close(); errClose != nil {
|
|
log.Errorf("antigravity userinfo: close body error: %v", errClose)
|
|
}
|
|
}()
|
|
|
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
return "", nil
|
|
}
|
|
var info userInfo
|
|
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
|
return "", errDecode
|
|
}
|
|
return info.Email, nil
|
|
}
|
|
|
|
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
|
|
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
|
|
loadReqBody := map[string]any{
|
|
"metadata": map[string]string{
|
|
"ideType": "ANTIGRAVITY",
|
|
"platform": "PLATFORM_UNSPECIFIED",
|
|
"pluginType": "GEMINI",
|
|
},
|
|
}
|
|
|
|
rawBody, errMarshal := json.Marshal(loadReqBody)
|
|
if errMarshal != nil {
|
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
|
}
|
|
|
|
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
|
if err != nil {
|
|
return "", fmt.Errorf("create request: %w", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("User-Agent", APIUserAgent)
|
|
req.Header.Set("X-Goog-Api-Client", APIClient)
|
|
req.Header.Set("Client-Metadata", ClientMetadata)
|
|
|
|
resp, errDo := o.httpClient.Do(req)
|
|
if errDo != nil {
|
|
return "", fmt.Errorf("execute request: %w", errDo)
|
|
}
|
|
defer func() {
|
|
if errClose := resp.Body.Close(); errClose != nil {
|
|
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
|
}
|
|
}()
|
|
|
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
|
if errRead != nil {
|
|
return "", fmt.Errorf("read response: %w", errRead)
|
|
}
|
|
|
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
|
}
|
|
|
|
var loadResp map[string]any
|
|
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
|
}
|
|
|
|
// Extract projectID from response
|
|
projectID := ""
|
|
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
|
projectID = strings.TrimSpace(id)
|
|
}
|
|
if projectID == "" {
|
|
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
|
if id, okID := projectMap["id"].(string); okID {
|
|
projectID = strings.TrimSpace(id)
|
|
}
|
|
}
|
|
}
|
|
|
|
if projectID == "" {
|
|
tierID := "legacy-tier"
|
|
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
|
for _, rawTier := range tiers {
|
|
tier, okTier := rawTier.(map[string]any)
|
|
if !okTier {
|
|
continue
|
|
}
|
|
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
|
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
|
tierID = strings.TrimSpace(id)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return projectID, nil
|
|
}
|
|
|
|
return projectID, nil
|
|
}
|
|
|
|
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
|
|
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
|
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
|
|
requestBody := map[string]any{
|
|
"tierId": tierID,
|
|
"metadata": map[string]string{
|
|
"ideType": "ANTIGRAVITY",
|
|
"platform": "PLATFORM_UNSPECIFIED",
|
|
"pluginType": "GEMINI",
|
|
},
|
|
}
|
|
|
|
rawBody, errMarshal := json.Marshal(requestBody)
|
|
if errMarshal != nil {
|
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
|
}
|
|
|
|
maxAttempts := 5
|
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
|
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
|
|
|
reqCtx := ctx
|
|
var cancel context.CancelFunc
|
|
if reqCtx == nil {
|
|
reqCtx = context.Background()
|
|
}
|
|
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
|
|
|
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
|
|
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
|
if errRequest != nil {
|
|
cancel()
|
|
return "", fmt.Errorf("create request: %w", errRequest)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("User-Agent", APIUserAgent)
|
|
req.Header.Set("X-Goog-Api-Client", APIClient)
|
|
req.Header.Set("Client-Metadata", ClientMetadata)
|
|
|
|
resp, errDo := o.httpClient.Do(req)
|
|
if errDo != nil {
|
|
cancel()
|
|
return "", fmt.Errorf("execute request: %w", errDo)
|
|
}
|
|
|
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
|
if errClose := resp.Body.Close(); errClose != nil {
|
|
log.Errorf("close body error: %v", errClose)
|
|
}
|
|
cancel()
|
|
|
|
if errRead != nil {
|
|
return "", fmt.Errorf("read response: %w", errRead)
|
|
}
|
|
|
|
if resp.StatusCode == http.StatusOK {
|
|
var data map[string]any
|
|
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
|
}
|
|
|
|
if done, okDone := data["done"].(bool); okDone && done {
|
|
projectID := ""
|
|
if responseData, okResp := data["response"].(map[string]any); okResp {
|
|
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
|
case map[string]any:
|
|
if id, okID := projectValue["id"].(string); okID {
|
|
projectID = strings.TrimSpace(id)
|
|
}
|
|
case string:
|
|
projectID = strings.TrimSpace(projectValue)
|
|
}
|
|
}
|
|
|
|
if projectID != "" {
|
|
log.Infof("Successfully fetched project_id: %s", projectID)
|
|
return projectID, nil
|
|
}
|
|
|
|
return "", fmt.Errorf("no project_id in response")
|
|
}
|
|
|
|
time.Sleep(2 * time.Second)
|
|
continue
|
|
}
|
|
|
|
responsePreview := strings.TrimSpace(string(bodyBytes))
|
|
if len(responsePreview) > 500 {
|
|
responsePreview = responsePreview[:500]
|
|
}
|
|
|
|
responseErr := responsePreview
|
|
if len(responseErr) > 200 {
|
|
responseErr = responseErr[:200]
|
|
}
|
|
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
|
}
|
|
|
|
return "", nil
|
|
}
|