From 9e5968521256f54ebfee25d804854109da68627e Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Fri, 23 Jan 2026 21:35:37 +0800 Subject: [PATCH] refactor(auth): implement Antigravity AuthService in internal/auth --- internal/auth/antigravity/auth.go | 313 ++++++++++++++++++++++++++++++ 1 file changed, 313 insertions(+) create mode 100644 internal/auth/antigravity/auth.go diff --git a/internal/auth/antigravity/auth.go b/internal/auth/antigravity/auth.go new file mode 100644 index 00000000..80dab17d --- /dev/null +++ b/internal/auth/antigravity/auth.go @@ -0,0 +1,313 @@ +// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider. +package antigravity + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +// TokenResponse represents OAuth token response from Google +type TokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_token"` + TokenType string `json:"token_type"` +} + +// userInfo represents Google user profile +type userInfo struct { + Email string `json:"email"` +} + +// AntigravityAuth handles Antigravity OAuth authentication +type AntigravityAuth struct { + httpClient *http.Client +} + +// NewAntigravityAuth creates a new Antigravity auth service +func NewAntigravityAuth(cfg *config.Config) *AntigravityAuth { + return &AntigravityAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}), + } +} + +// BuildAuthURL generates the OAuth authorization URL +func (o *AntigravityAuth) BuildAuthURL(state string) string { + params := url.Values{} + params.Set("access_type", "offline") + params.Set("client_id", ClientID) + params.Set("prompt", "consent") + params.Set("redirect_uri", fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)) + params.Set("response_type", "code") + params.Set("scope", strings.Join(Scopes, " ")) + params.Set("state", state) + return AuthEndpoint + "?" + params.Encode() +} + +// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens +func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) { + data := url.Values{} + data.Set("code", code) + data.Set("client_id", ClientID) + data.Set("client_secret", ClientSecret) + data.Set("redirect_uri", redirectURI) + data.Set("grant_type", "authorization_code") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + return nil, errDo + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity token exchange: close body error: %v", errClose) + } + }() + + var token TokenResponse + if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil { + return nil, errDecode + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode) + } + return &token, nil +} + +// FetchUserInfo retrieves user email from Google +func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) { + if strings.TrimSpace(accessToken) == "" { + return "", nil + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + return "", errDo + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity userinfo: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return "", nil + } + var info userInfo + if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil { + return "", errDecode + } + return info.Email, nil +} + +// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist +func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) { + loadReqBody := map[string]any{ + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + } + + rawBody, errMarshal := json.Marshal(loadReqBody) + if errMarshal != nil { + return "", fmt.Errorf("marshal request body: %w", errMarshal) + } + + endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", APIUserAgent) + req.Header.Set("X-Goog-Api-Client", APIClient) + req.Header.Set("Client-Metadata", ClientMetadata) + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + return "", fmt.Errorf("execute request: %w", errDo) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose) + } + }() + + bodyBytes, errRead := io.ReadAll(resp.Body) + if errRead != nil { + return "", fmt.Errorf("read response: %w", errRead) + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) + } + + var loadResp map[string]any + if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil { + return "", fmt.Errorf("decode response: %w", errDecode) + } + + // Extract projectID from response + projectID := "" + if id, ok := loadResp["cloudaicompanionProject"].(string); ok { + projectID = strings.TrimSpace(id) + } + if projectID == "" { + if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok { + if id, okID := projectMap["id"].(string); okID { + projectID = strings.TrimSpace(id) + } + } + } + + if projectID == "" { + tierID := "legacy-tier" + if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { + for _, rawTier := range tiers { + tier, okTier := rawTier.(map[string]any) + if !okTier { + continue + } + if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { + if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { + tierID = strings.TrimSpace(id) + break + } + } + } + } + + projectID, err = o.OnboardUser(ctx, accessToken, tierID) + if err != nil { + return "", err + } + return projectID, nil + } + + return projectID, nil +} + +// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion +func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) { + log.Infof("Antigravity: onboarding user with tier: %s", tierID) + requestBody := map[string]any{ + "tierId": tierID, + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + } + + rawBody, errMarshal := json.Marshal(requestBody) + if errMarshal != nil { + return "", fmt.Errorf("marshal request body: %w", errMarshal) + } + + maxAttempts := 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + log.Debugf("Polling attempt %d/%d", attempt, maxAttempts) + + reqCtx := ctx + var cancel context.CancelFunc + if reqCtx == nil { + reqCtx = context.Background() + } + reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second) + + endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion) + req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if errRequest != nil { + cancel() + return "", fmt.Errorf("create request: %w", errRequest) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", APIUserAgent) + req.Header.Set("X-Goog-Api-Client", APIClient) + req.Header.Set("Client-Metadata", ClientMetadata) + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + cancel() + return "", fmt.Errorf("execute request: %w", errDo) + } + + bodyBytes, errRead := io.ReadAll(resp.Body) + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("close body error: %v", errClose) + } + cancel() + + if errRead != nil { + return "", fmt.Errorf("read response: %w", errRead) + } + + if resp.StatusCode == http.StatusOK { + var data map[string]any + if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil { + return "", fmt.Errorf("decode response: %w", errDecode) + } + + if done, okDone := data["done"].(bool); okDone && done { + projectID := "" + if responseData, okResp := data["response"].(map[string]any); okResp { + switch projectValue := responseData["cloudaicompanionProject"].(type) { + case map[string]any: + if id, okID := projectValue["id"].(string); okID { + projectID = strings.TrimSpace(id) + } + case string: + projectID = strings.TrimSpace(projectValue) + } + } + + if projectID != "" { + log.Infof("Successfully fetched project_id: %s", projectID) + return projectID, nil + } + + return "", fmt.Errorf("no project_id in response") + } + + time.Sleep(2 * time.Second) + continue + } + + responsePreview := strings.TrimSpace(string(bodyBytes)) + if len(responsePreview) > 500 { + responsePreview = responsePreview[:500] + } + + responseErr := responsePreview + if len(responseErr) > 200 { + responseErr = responseErr[:200] + } + return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr) + } + + return "", nil +}