Refactor user onboarding and token management

- Enhanced the `Client` initialization to include `TokenStorage` and configuration parameters.
- Replaced `SaveTokenToFile` with a `Client` method for better encapsulation.
- Improved onboarding flow with project ID verification and API enablement checks.
- Refactored token saving logic to ensure proper handling of directory creation and JSON encoding.
- Removed unused file-related code in `auth.go` for improved maintainability.
This commit is contained in:
Luis Pater
2025-07-04 07:53:07 +08:00
parent 79acea5976
commit 57ead9a4bc
4 changed files with 155 additions and 74 deletions

View File

@@ -6,12 +6,16 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
@@ -108,17 +112,41 @@ type Client struct {
ProjectID string
RequestMutex sync.Mutex
Email string
tokenStorage *auth.TokenStorage
cfg *config.Config
}
// NewClient creates a new CLI API client.
func NewClient(httpClient *http.Client) *Client {
func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config) *Client {
return &Client{
httpClient: httpClient,
httpClient: httpClient,
tokenStorage: ts,
cfg: cfg,
}
}
func (c *Client) SetProjectID(projectID string) {
c.tokenStorage.ProjectID = projectID
}
func (c *Client) SetIsAuto(auto bool) {
c.tokenStorage.Auto = auto
}
func (c *Client) SetIsChecked(checked bool) {
c.tokenStorage.Checked = checked
}
func (c *Client) IsChecked() bool {
return c.tokenStorage.Checked
}
func (c *Client) IsAuto() bool {
return c.tokenStorage.Auto
}
// SetupUser performs the initial user onboarding and setup.
func (c *Client) SetupUser(ctx context.Context, email, projectID string, auto bool) (string, error) {
func (c *Client) SetupUser(ctx context.Context, email, projectID string) (string, error) {
c.Email = email
log.Info("Performing user onboarding...")
@@ -172,29 +200,32 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string, auto bo
return projectID, fmt.Errorf("failed to start user onboarding, need define a project id")
}
var lroResp map[string]interface{}
err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
if err != nil {
return projectID, fmt.Errorf("failed to start user onboarding: %w", err)
}
for {
var lroResp map[string]interface{}
err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
if err != nil {
return projectID, fmt.Errorf("failed to start user onboarding: %w", err)
}
// a, _ := json.Marshal(&lroResp)
// log.Debug(string(a))
// a, _ = json.Marshal(&lroResp)
// log.Debug(string(a))
// 3. Poll Long-Running Operation (LRO)
if done, doneOk := lroResp["done"].(bool); doneOk && done {
if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
if projectID != "" && !auto {
c.ProjectID = projectID
log.Infof("Onboarding complete. Project ID: %s is being enforced. Maybe you need to enable 'Gemini for Google Cloud' once in Google Cloud Console.", c.ProjectID)
} else {
c.ProjectID = project["id"].(string)
// 3. Poll Long-Running Operation (LRO)
done, doneOk := lroResp["done"].(bool)
if doneOk && done {
if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
if projectID != "" {
c.ProjectID = projectID
} else {
c.ProjectID = project["id"].(string)
}
log.Infof("Onboarding complete. Using Project ID: %s", c.ProjectID)
return c.ProjectID, nil
}
return c.ProjectID, nil
} else {
log.Println("Onboarding in progress, waiting 5 seconds...")
time.Sleep(5 * time.Second)
}
}
return projectID, fmt.Errorf("failed to get operation name from onboarding response: %v", lroResp)
}
// makeAPIRequest handles making requests to the CLI API endpoints.
@@ -332,7 +363,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
byteRequestBody, _ := json.Marshal(requestBody)
// log.Debug(string(rawJson))
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
@@ -396,6 +427,57 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return dataChan, errChan
}
func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
ctx, cancel := context.WithCancel(context.Background())
defer func() {
c.RequestMutex.Unlock()
cancel()
}()
c.RequestMutex.Lock()
requestBody := `{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`
requestBody = fmt.Sprintf(requestBody, c.tokenStorage.ProjectID)
// log.Debug(requestBody)
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody))
if err != nil {
if err.StatusCode == 403 {
errJson := err.Error.Error()
codeResult := gjson.Get(errJson, "error.code")
if codeResult.Exists() && codeResult.Type == gjson.Number {
if codeResult.Int() == 403 {
activationUrlResult := gjson.Get(errJson, "error.details.0.metadata.activationUrl")
if activationUrlResult.Exists() {
log.Warnf(
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
activationUrlResult.String(),
os.Args[0],
c.tokenStorage.ProjectID,
)
}
}
}
return false, nil
}
return false, err.Error
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
}
if scannerErr := scanner.Err(); scannerErr != nil {
_ = stream.Close()
} else {
_ = stream.Close()
}
return true, nil
}
func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
@@ -426,6 +508,27 @@ func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
return &project, nil
}
func (c *Client) SaveTokenToFile() error {
if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("%s-%s.json", c.tokenStorage.Email, c.tokenStorage.ProjectID))
log.Infof("Saving credentials to %s", fileName)
f, err := os.Create(fileName)
if err != nil {
return fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(c.tokenStorage); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil
}
// getClientMetadata returns metadata about the client environment.
func getClientMetadata() map[string]string {
return map[string]string{
@@ -452,9 +555,9 @@ func getUserAgent() string {
// getPlatform returns the OS and architecture in the format expected by the API.
func getPlatform() string {
os := runtime.GOOS
goOS := runtime.GOOS
arch := runtime.GOARCH
switch os {
switch goOS {
case "darwin":
return fmt.Sprintf("DARWIN_%s", strings.ToUpper(arch))
case "linux":