From 57ead9a4bc095a4810a3c22ff082fc6feade33f9 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Fri, 4 Jul 2025 07:53:07 +0800 Subject: [PATCH] Refactor user onboarding and token management - Enhanced the `Client` initialization to include `TokenStorage` and configuration parameters. - Replaced `SaveTokenToFile` with a `Client` method for better encapsulation. - Improved onboarding flow with project ID verification and API enablement checks. - Refactored token saving logic to ensure proper handling of directory creation and JSON encoding. - Removed unused file-related code in `auth.go` for improved maintainability. --- internal/auth/auth.go | 48 ++---------- internal/client/client.go | 151 ++++++++++++++++++++++++++++++++------ internal/cmd/login.go | 24 +++++- internal/cmd/run.go | 6 +- 4 files changed, 155 insertions(+), 74 deletions(-) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 44092f28..979b403f 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -13,8 +13,6 @@ import ( "net" "net/http" "net/url" - "os" - "path/filepath" "time" "github.com/skratchdot/open-golang/open" @@ -40,6 +38,7 @@ type TokenStorage struct { ProjectID string `json:"project_id"` Email string `json:"email"` Auto bool `json:"auto"` + Checked bool `json:"checked"` } // GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens. @@ -96,7 +95,7 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C if err != nil { return nil, fmt.Errorf("failed to get token from web: %w", err) } - newTs, errSaveTokenToFile := saveTokenToFile(ctx, conf, token, ts.ProjectID, cfg.AuthDir) + newTs, errSaveTokenToFile := createTokenStorage(ctx, conf, token, ts.ProjectID) if errSaveTokenToFile != nil { log.Errorf("Warning: failed to save token to file: %v", err) return nil, errSaveTokenToFile @@ -111,8 +110,8 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C return conf.Client(ctx, token), nil } -// saveTokenToFile saves a token to the local credentials file. -func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID, authDir string) (*TokenStorage, error) { +// createTokenStorage creates a token storage. +func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) { httpClient := config.Client(ctx, token) req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if err != nil { @@ -160,46 +159,9 @@ func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.T Email: emailResult.String(), } - if err = os.MkdirAll(authDir, 0700); err != nil { - return nil, fmt.Errorf("failed to create directory: %w", err) - } - - if projectID != "" { - log.Infof("Saving credentials to %s", filepath.Join(authDir, fmt.Sprintf("%s-%s.json", emailResult.String(), projectID))) - - f, errCreate := os.Create(filepath.Join(authDir, fmt.Sprintf("%s-%s.json", emailResult.String(), projectID))) - if errCreate != nil { - return nil, fmt.Errorf("failed to create token file: %w", errCreate) - } - defer func() { - _ = f.Close() - }() - - if err = json.NewEncoder(f).Encode(ts); err != nil { - return nil, fmt.Errorf("failed to write token to file: %w", err) - } - } return &ts, nil } -func SaveTokenToFile(ts *TokenStorage, cfg *config.Config, auto bool) error { - ts.Auto = auto - fileName := filepath.Join(cfg.AuthDir, fmt.Sprintf("%s-%s.json", ts.Email, ts.ProjectID)) - log.Infof("Saving credentials to %s", fileName) - f, err := os.Create(fileName) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - _ = f.Close() - }() - - if err = json.NewEncoder(f).Encode(ts); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) - } - return nil -} - // getTokenFromWeb starts a local server to handle the OAuth2 flow. func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. @@ -235,7 +197,7 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, // Open the authorization URL in the user's browser. authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n\n", authURL) + log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n", authURL) err := open.Run(authURL) if err != nil { diff --git a/internal/client/client.go b/internal/client/client.go index 50194314..c42e787a 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -6,12 +6,16 @@ import ( "context" "encoding/json" "fmt" + "github.com/luispater/CLIProxyAPI/internal/auth" + "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "golang.org/x/oauth2" "io" "net/http" + "os" + "path/filepath" "runtime" "strings" "sync" @@ -108,17 +112,41 @@ type Client struct { ProjectID string RequestMutex sync.Mutex Email string + tokenStorage *auth.TokenStorage + cfg *config.Config } // NewClient creates a new CLI API client. -func NewClient(httpClient *http.Client) *Client { +func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config) *Client { return &Client{ - httpClient: httpClient, + httpClient: httpClient, + tokenStorage: ts, + cfg: cfg, } } +func (c *Client) SetProjectID(projectID string) { + c.tokenStorage.ProjectID = projectID +} + +func (c *Client) SetIsAuto(auto bool) { + c.tokenStorage.Auto = auto +} + +func (c *Client) SetIsChecked(checked bool) { + c.tokenStorage.Checked = checked +} + +func (c *Client) IsChecked() bool { + return c.tokenStorage.Checked +} + +func (c *Client) IsAuto() bool { + return c.tokenStorage.Auto +} + // SetupUser performs the initial user onboarding and setup. -func (c *Client) SetupUser(ctx context.Context, email, projectID string, auto bool) (string, error) { +func (c *Client) SetupUser(ctx context.Context, email, projectID string) (string, error) { c.Email = email log.Info("Performing user onboarding...") @@ -172,29 +200,32 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string, auto bo return projectID, fmt.Errorf("failed to start user onboarding, need define a project id") } - var lroResp map[string]interface{} - err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp) - if err != nil { - return projectID, fmt.Errorf("failed to start user onboarding: %w", err) - } + for { + var lroResp map[string]interface{} + err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp) + if err != nil { + return projectID, fmt.Errorf("failed to start user onboarding: %w", err) + } + // a, _ := json.Marshal(&lroResp) + // log.Debug(string(a)) - // a, _ = json.Marshal(&lroResp) - // log.Debug(string(a)) - - // 3. Poll Long-Running Operation (LRO) - if done, doneOk := lroResp["done"].(bool); doneOk && done { - if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk { - if projectID != "" && !auto { - c.ProjectID = projectID - log.Infof("Onboarding complete. Project ID: %s is being enforced. Maybe you need to enable 'Gemini for Google Cloud' once in Google Cloud Console.", c.ProjectID) - } else { - c.ProjectID = project["id"].(string) + // 3. Poll Long-Running Operation (LRO) + done, doneOk := lroResp["done"].(bool) + if doneOk && done { + if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk { + if projectID != "" { + c.ProjectID = projectID + } else { + c.ProjectID = project["id"].(string) + } log.Infof("Onboarding complete. Using Project ID: %s", c.ProjectID) + return c.ProjectID, nil } - return c.ProjectID, nil + } else { + log.Println("Onboarding in progress, waiting 5 seconds...") + time.Sleep(5 * time.Second) } } - return projectID, fmt.Errorf("failed to get operation name from onboarding response: %v", lroResp) } // makeAPIRequest handles making requests to the CLI API endpoints. @@ -332,7 +363,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st byteRequestBody, _ := json.Marshal(requestBody) - // log.Debug(string(rawJson)) + // log.Debug(string(byteRequestBody)) reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") if reasoningEffortResult.String() == "none" { @@ -396,6 +427,57 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st return dataChan, errChan } +func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + c.RequestMutex.Unlock() + cancel() + }() + c.RequestMutex.Lock() + + requestBody := `{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}` + requestBody = fmt.Sprintf(requestBody, c.tokenStorage.ProjectID) + // log.Debug(requestBody) + stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody)) + if err != nil { + if err.StatusCode == 403 { + errJson := err.Error.Error() + codeResult := gjson.Get(errJson, "error.code") + if codeResult.Exists() && codeResult.Type == gjson.Number { + if codeResult.Int() == 403 { + activationUrlResult := gjson.Get(errJson, "error.details.0.metadata.activationUrl") + if activationUrlResult.Exists() { + log.Warnf( + "\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s", + activationUrlResult.String(), + os.Args[0], + c.tokenStorage.ProjectID, + ) + } + } + } + return false, nil + } + return false, err.Error + } + + scanner := bufio.NewScanner(stream) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + } + + if scannerErr := scanner.Err(); scannerErr != nil { + _ = stream.Close() + } else { + _ = stream.Close() + } + + return true, nil +} + func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) { token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil) @@ -426,6 +508,27 @@ func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) { return &project, nil } +func (c *Client) SaveTokenToFile() error { + if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("%s-%s.json", c.tokenStorage.Email, c.tokenStorage.ProjectID)) + log.Infof("Saving credentials to %s", fileName) + f, err := os.Create(fileName) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + if err = json.NewEncoder(f).Encode(c.tokenStorage); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} + // getClientMetadata returns metadata about the client environment. func getClientMetadata() map[string]string { return map[string]string{ @@ -452,9 +555,9 @@ func getUserAgent() string { // getPlatform returns the OS and architecture in the format expected by the API. func getPlatform() string { - os := runtime.GOOS + goOS := runtime.GOOS arch := runtime.GOARCH - switch os { + switch goOS { case "darwin": return fmt.Sprintf("DARWIN_%s", strings.ToUpper(arch)) case "linux": diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 66a04a20..49809a4b 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -28,8 +28,8 @@ func DoLogin(cfg *config.Config, projectID string) { log.Info("Authentication successful.") // 3. Initialize CLI Client - cliClient := client.NewClient(httpClient) - projectID, err = cliClient.SetupUser(clientCtx, ts.Email, projectID, ts.Auto) + cliClient := client.NewClient(httpClient, &ts, cfg) + projectID, err = cliClient.SetupUser(clientCtx, ts.Email, projectID) if err != nil { if err.Error() == "failed to start user onboarding, need define a project id" { log.Error("failed to start user onboarding") @@ -53,10 +53,26 @@ func DoLogin(cfg *config.Config, projectID string) { } } else { auto := ts.ProjectID == "" - ts.ProjectID = projectID - err = auth.SaveTokenToFile(&ts, cfg, auto) + cliClient.SetProjectID(projectID) + cliClient.SetIsAuto(auto) + + if !cliClient.IsChecked() && !cliClient.IsAuto() { + isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled() + if checkErr != nil { + log.Fatalf("failed to check cloud api is enabled: %v", checkErr) + return + } + cliClient.SetIsChecked(isChecked) + } + + if !cliClient.IsChecked() && !cliClient.IsAuto() { + return + } + + err = cliClient.SaveTokenToFile() if err != nil { log.Fatal(err) + return } } } diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 52500037..133a47da 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -33,7 +33,7 @@ func StartService(cfg *config.Config) { } if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") { - log.Debugf(path) + log.Debugf("Loading token from: %s", path) f, errOpen := os.Open(path) if errOpen != nil { return errOpen @@ -56,8 +56,8 @@ func StartService(cfg *config.Config) { log.Info("Authentication successful.") // 3. Initialize CLI Client - cliClient := client.NewClient(httpClient) - if _, err = cliClient.SetupUser(clientCtx, ts.Email, ts.ProjectID, ts.Auto); err != nil { + cliClient := client.NewClient(httpClient, &ts, cfg) + if _, err = cliClient.SetupUser(clientCtx, ts.Email, ts.ProjectID); err != nil { if err.Error() == "failed to start user onboarding, need define a project id" { log.Error("failed to start user onboarding") project, errGetProjectList := cliClient.GetProjectList(clientCtx)