diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index b59acacf..210da57f 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -382,7 +382,7 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie // Call loadCodeAssist to get the project loadReqBody := map[string]any{ "metadata": map[string]string{ - "ideType": "IDE_UNSPECIFIED", + "ideType": "ANTIGRAVITY", "platform": "PLATFORM_UNSPECIFIED", "pluginType": "GEMINI", }, @@ -442,8 +442,134 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie } if projectID == "" { - return "", fmt.Errorf("no cloudaicompanionProject in response") + tierID := "legacy-tier" + if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { + for _, rawTier := range tiers { + tier, okTier := rawTier.(map[string]any) + if !okTier { + continue + } + if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { + if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { + tierID = strings.TrimSpace(id) + break + } + } + } + } + + projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient) + if err != nil { + return "", err + } + return projectID, nil } return projectID, nil } + +// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion. +// It returns an empty string when the operation times out or completes without a project ID. +func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) { + if httpClient == nil { + httpClient = http.DefaultClient + } + fmt.Println("Antigravity: onboarding user...", tierID) + requestBody := map[string]any{ + "tierId": tierID, + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + } + + rawBody, errMarshal := json.Marshal(requestBody) + if errMarshal != nil { + return "", fmt.Errorf("marshal request body: %w", errMarshal) + } + + maxAttempts := 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + log.Debugf("Polling attempt %d/%d", attempt, maxAttempts) + + reqCtx := ctx + var cancel context.CancelFunc + if reqCtx == nil { + reqCtx = context.Background() + } + reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second) + + endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion) + req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if errRequest != nil { + cancel() + return "", fmt.Errorf("create request: %w", errRequest) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", antigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) + req.Header.Set("Client-Metadata", antigravityClientMetadata) + + resp, errDo := httpClient.Do(req) + if errDo != nil { + cancel() + return "", fmt.Errorf("execute request: %w", errDo) + } + + bodyBytes, errRead := io.ReadAll(resp.Body) + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("close body error: %v", errClose) + } + cancel() + + if errRead != nil { + return "", fmt.Errorf("read response: %w", errRead) + } + + if resp.StatusCode == http.StatusOK { + var data map[string]any + if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil { + return "", fmt.Errorf("decode response: %w", errDecode) + } + + if done, okDone := data["done"].(bool); okDone && done { + projectID := "" + if responseData, okResp := data["response"].(map[string]any); okResp { + switch projectValue := responseData["cloudaicompanionProject"].(type) { + case map[string]any: + if id, okID := projectValue["id"].(string); okID { + projectID = strings.TrimSpace(id) + } + case string: + projectID = strings.TrimSpace(projectValue) + } + } + + if projectID != "" { + log.Infof("Successfully fetched project_id: %s", projectID) + return projectID, nil + } + + return "", fmt.Errorf("no project_id in response") + } + + time.Sleep(2 * time.Second) + continue + } + + responsePreview := strings.TrimSpace(string(bodyBytes)) + if len(responsePreview) > 500 { + responsePreview = responsePreview[:500] + } + + responseErr := responsePreview + if len(responseErr) > 200 { + responseErr = responseErr[:200] + } + return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr) + } + + return "", nil +}