mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Refactor user onboarding and token management
- Enhanced the `Client` initialization to include `TokenStorage` and configuration parameters. - Replaced `SaveTokenToFile` with a `Client` method for better encapsulation. - Improved onboarding flow with project ID verification and API enablement checks. - Refactored token saving logic to ensure proper handling of directory creation and JSON encoding. - Removed unused file-related code in `auth.go` for improved maintainability.
This commit is contained in:
@@ -13,8 +13,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
@@ -40,6 +38,7 @@ type TokenStorage struct {
|
||||
ProjectID string `json:"project_id"`
|
||||
Email string `json:"email"`
|
||||
Auto bool `json:"auto"`
|
||||
Checked bool `json:"checked"`
|
||||
}
|
||||
|
||||
// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens.
|
||||
@@ -96,7 +95,7 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||
}
|
||||
newTs, errSaveTokenToFile := saveTokenToFile(ctx, conf, token, ts.ProjectID, cfg.AuthDir)
|
||||
newTs, errSaveTokenToFile := createTokenStorage(ctx, conf, token, ts.ProjectID)
|
||||
if errSaveTokenToFile != nil {
|
||||
log.Errorf("Warning: failed to save token to file: %v", err)
|
||||
return nil, errSaveTokenToFile
|
||||
@@ -111,8 +110,8 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
||||
return conf.Client(ctx, token), nil
|
||||
}
|
||||
|
||||
// saveTokenToFile saves a token to the local credentials file.
|
||||
func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID, authDir string) (*TokenStorage, error) {
|
||||
// createTokenStorage creates a token storage.
|
||||
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
|
||||
httpClient := config.Client(ctx, token)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if err != nil {
|
||||
@@ -160,46 +159,9 @@ func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.T
|
||||
Email: emailResult.String(),
|
||||
}
|
||||
|
||||
if err = os.MkdirAll(authDir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
if projectID != "" {
|
||||
log.Infof("Saving credentials to %s", filepath.Join(authDir, fmt.Sprintf("%s-%s.json", emailResult.String(), projectID)))
|
||||
|
||||
f, errCreate := os.Create(filepath.Join(authDir, fmt.Sprintf("%s-%s.json", emailResult.String(), projectID)))
|
||||
if errCreate != nil {
|
||||
return nil, fmt.Errorf("failed to create token file: %w", errCreate)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return nil, fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
}
|
||||
return &ts, nil
|
||||
}
|
||||
|
||||
func SaveTokenToFile(ts *TokenStorage, cfg *config.Config, auto bool) error {
|
||||
ts.Auto = auto
|
||||
fileName := filepath.Join(cfg.AuthDir, fmt.Sprintf("%s-%s.json", ts.Email, ts.ProjectID))
|
||||
log.Infof("Saving credentials to %s", fileName)
|
||||
f, err := os.Create(fileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTokenFromWeb starts a local server to handle the OAuth2 flow.
|
||||
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
@@ -235,7 +197,7 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
|
||||
|
||||
// Open the authorization URL in the user's browser.
|
||||
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n\n", authURL)
|
||||
log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n", authURL)
|
||||
|
||||
err := open.Run(authURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,12 +6,16 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/oauth2"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -108,17 +112,41 @@ type Client struct {
|
||||
ProjectID string
|
||||
RequestMutex sync.Mutex
|
||||
Email string
|
||||
tokenStorage *auth.TokenStorage
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewClient creates a new CLI API client.
|
||||
func NewClient(httpClient *http.Client) *Client {
|
||||
func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config) *Client {
|
||||
return &Client{
|
||||
httpClient: httpClient,
|
||||
httpClient: httpClient,
|
||||
tokenStorage: ts,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SetProjectID(projectID string) {
|
||||
c.tokenStorage.ProjectID = projectID
|
||||
}
|
||||
|
||||
func (c *Client) SetIsAuto(auto bool) {
|
||||
c.tokenStorage.Auto = auto
|
||||
}
|
||||
|
||||
func (c *Client) SetIsChecked(checked bool) {
|
||||
c.tokenStorage.Checked = checked
|
||||
}
|
||||
|
||||
func (c *Client) IsChecked() bool {
|
||||
return c.tokenStorage.Checked
|
||||
}
|
||||
|
||||
func (c *Client) IsAuto() bool {
|
||||
return c.tokenStorage.Auto
|
||||
}
|
||||
|
||||
// SetupUser performs the initial user onboarding and setup.
|
||||
func (c *Client) SetupUser(ctx context.Context, email, projectID string, auto bool) (string, error) {
|
||||
func (c *Client) SetupUser(ctx context.Context, email, projectID string) (string, error) {
|
||||
c.Email = email
|
||||
log.Info("Performing user onboarding...")
|
||||
|
||||
@@ -172,29 +200,32 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string, auto bo
|
||||
return projectID, fmt.Errorf("failed to start user onboarding, need define a project id")
|
||||
}
|
||||
|
||||
var lroResp map[string]interface{}
|
||||
err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
|
||||
if err != nil {
|
||||
return projectID, fmt.Errorf("failed to start user onboarding: %w", err)
|
||||
}
|
||||
for {
|
||||
var lroResp map[string]interface{}
|
||||
err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
|
||||
if err != nil {
|
||||
return projectID, fmt.Errorf("failed to start user onboarding: %w", err)
|
||||
}
|
||||
// a, _ := json.Marshal(&lroResp)
|
||||
// log.Debug(string(a))
|
||||
|
||||
// a, _ = json.Marshal(&lroResp)
|
||||
// log.Debug(string(a))
|
||||
|
||||
// 3. Poll Long-Running Operation (LRO)
|
||||
if done, doneOk := lroResp["done"].(bool); doneOk && done {
|
||||
if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
|
||||
if projectID != "" && !auto {
|
||||
c.ProjectID = projectID
|
||||
log.Infof("Onboarding complete. Project ID: %s is being enforced. Maybe you need to enable 'Gemini for Google Cloud' once in Google Cloud Console.", c.ProjectID)
|
||||
} else {
|
||||
c.ProjectID = project["id"].(string)
|
||||
// 3. Poll Long-Running Operation (LRO)
|
||||
done, doneOk := lroResp["done"].(bool)
|
||||
if doneOk && done {
|
||||
if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
|
||||
if projectID != "" {
|
||||
c.ProjectID = projectID
|
||||
} else {
|
||||
c.ProjectID = project["id"].(string)
|
||||
}
|
||||
log.Infof("Onboarding complete. Using Project ID: %s", c.ProjectID)
|
||||
return c.ProjectID, nil
|
||||
}
|
||||
return c.ProjectID, nil
|
||||
} else {
|
||||
log.Println("Onboarding in progress, waiting 5 seconds...")
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
return projectID, fmt.Errorf("failed to get operation name from onboarding response: %v", lroResp)
|
||||
}
|
||||
|
||||
// makeAPIRequest handles making requests to the CLI API endpoints.
|
||||
@@ -332,7 +363,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
|
||||
|
||||
byteRequestBody, _ := json.Marshal(requestBody)
|
||||
|
||||
// log.Debug(string(rawJson))
|
||||
// log.Debug(string(byteRequestBody))
|
||||
|
||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||
if reasoningEffortResult.String() == "none" {
|
||||
@@ -396,6 +427,57 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
c.RequestMutex.Unlock()
|
||||
cancel()
|
||||
}()
|
||||
c.RequestMutex.Lock()
|
||||
|
||||
requestBody := `{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`
|
||||
requestBody = fmt.Sprintf(requestBody, c.tokenStorage.ProjectID)
|
||||
// log.Debug(requestBody)
|
||||
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody))
|
||||
if err != nil {
|
||||
if err.StatusCode == 403 {
|
||||
errJson := err.Error.Error()
|
||||
codeResult := gjson.Get(errJson, "error.code")
|
||||
if codeResult.Exists() && codeResult.Type == gjson.Number {
|
||||
if codeResult.Int() == 403 {
|
||||
activationUrlResult := gjson.Get(errJson, "error.details.0.metadata.activationUrl")
|
||||
if activationUrlResult.Exists() {
|
||||
log.Warnf(
|
||||
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
|
||||
activationUrlResult.String(),
|
||||
os.Args[0],
|
||||
c.tokenStorage.ProjectID,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
return false, err.Error
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if scannerErr := scanner.Err(); scannerErr != nil {
|
||||
_ = stream.Close()
|
||||
} else {
|
||||
_ = stream.Close()
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
|
||||
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
||||
@@ -426,6 +508,27 @@ func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
|
||||
return &project, nil
|
||||
}
|
||||
|
||||
func (c *Client) SaveTokenToFile() error {
|
||||
if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("%s-%s.json", c.tokenStorage.Email, c.tokenStorage.ProjectID))
|
||||
log.Infof("Saving credentials to %s", fileName)
|
||||
f, err := os.Create(fileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(c.tokenStorage); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getClientMetadata returns metadata about the client environment.
|
||||
func getClientMetadata() map[string]string {
|
||||
return map[string]string{
|
||||
@@ -452,9 +555,9 @@ func getUserAgent() string {
|
||||
|
||||
// getPlatform returns the OS and architecture in the format expected by the API.
|
||||
func getPlatform() string {
|
||||
os := runtime.GOOS
|
||||
goOS := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
switch os {
|
||||
switch goOS {
|
||||
case "darwin":
|
||||
return fmt.Sprintf("DARWIN_%s", strings.ToUpper(arch))
|
||||
case "linux":
|
||||
|
||||
@@ -28,8 +28,8 @@ func DoLogin(cfg *config.Config, projectID string) {
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
// 3. Initialize CLI Client
|
||||
cliClient := client.NewClient(httpClient)
|
||||
projectID, err = cliClient.SetupUser(clientCtx, ts.Email, projectID, ts.Auto)
|
||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
||||
projectID, err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
|
||||
if err != nil {
|
||||
if err.Error() == "failed to start user onboarding, need define a project id" {
|
||||
log.Error("failed to start user onboarding")
|
||||
@@ -53,10 +53,26 @@ func DoLogin(cfg *config.Config, projectID string) {
|
||||
}
|
||||
} else {
|
||||
auto := ts.ProjectID == ""
|
||||
ts.ProjectID = projectID
|
||||
err = auth.SaveTokenToFile(&ts, cfg, auto)
|
||||
cliClient.SetProjectID(projectID)
|
||||
cliClient.SetIsAuto(auto)
|
||||
|
||||
if !cliClient.IsChecked() && !cliClient.IsAuto() {
|
||||
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
|
||||
if checkErr != nil {
|
||||
log.Fatalf("failed to check cloud api is enabled: %v", checkErr)
|
||||
return
|
||||
}
|
||||
cliClient.SetIsChecked(isChecked)
|
||||
}
|
||||
|
||||
if !cliClient.IsChecked() && !cliClient.IsAuto() {
|
||||
return
|
||||
}
|
||||
|
||||
err = cliClient.SaveTokenToFile()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func StartService(cfg *config.Config) {
|
||||
}
|
||||
|
||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
||||
log.Debugf(path)
|
||||
log.Debugf("Loading token from: %s", path)
|
||||
f, errOpen := os.Open(path)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
@@ -56,8 +56,8 @@ func StartService(cfg *config.Config) {
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
// 3. Initialize CLI Client
|
||||
cliClient := client.NewClient(httpClient)
|
||||
if _, err = cliClient.SetupUser(clientCtx, ts.Email, ts.ProjectID, ts.Auto); err != nil {
|
||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
||||
if _, err = cliClient.SetupUser(clientCtx, ts.Email, ts.ProjectID); err != nil {
|
||||
if err.Error() == "failed to start user onboarding, need define a project id" {
|
||||
log.Error("failed to start user onboarding")
|
||||
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
|
||||
|
||||
Reference in New Issue
Block a user