mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 13:00:52 +08:00
Refactor authentication and service initialization code
- Moved login and service management logic to `internal/cmd` package (`login.go` and `run.go`). - Introduced `DoLogin` and `StartService` functions for modularity. - Enhanced error handling by using structured `ErrorMessage` in `Client`. - Improved token file saving process and added project-specific token identification. - Updated API handlers to handle more detailed error responses, including status codes.
This commit is contained in:
@@ -2,23 +2,14 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api"
|
"github.com/luispater/CLIProxyAPI/internal/cmd"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"io/fs"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogFormatter struct {
|
type LogFormatter struct {
|
||||||
@@ -95,147 +86,8 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if login {
|
if login {
|
||||||
var ts auth.TokenStorage
|
cmd.DoLogin(cfg, projectID)
|
||||||
if projectID != "" {
|
|
||||||
ts.ProjectID = projectID
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Initialize authenticated HTTP Client
|
|
||||||
clientCtx := context.Background()
|
|
||||||
|
|
||||||
log.Info("Initializing authentication...")
|
|
||||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
|
||||||
if errGetClient != nil {
|
|
||||||
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Info("Authentication successful.")
|
|
||||||
|
|
||||||
// 3. Initialize CLI Client
|
|
||||||
cliClient := client.NewClient(httpClient)
|
|
||||||
if err = cliClient.SetupUser(clientCtx, ts.Email, projectID); err != nil {
|
|
||||||
if err.Error() == "failed to start user onboarding, need define a project id" {
|
|
||||||
log.Error("failed to start user onboarding")
|
|
||||||
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
|
|
||||||
if errGetProjectList != nil {
|
|
||||||
log.Fatalf("failed to complete user setup: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Infof("Your account %s needs specify a project id.", ts.Email)
|
|
||||||
log.Info("========================================================================")
|
|
||||||
for i := 0; i < len(project.Projects); i++ {
|
|
||||||
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
|
|
||||||
log.Infof("Project Name: %s", project.Projects[i].Name)
|
|
||||||
log.Info("========================================================================")
|
|
||||||
}
|
|
||||||
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Log as a warning because in some cases, the CLI might still be usable
|
|
||||||
// or the user might want to retry setup later.
|
|
||||||
log.Fatalf("failed to complete user setup: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// Create API server configuration
|
cmd.StartService(cfg)
|
||||||
apiConfig := &api.ServerConfig{
|
|
||||||
Port: fmt.Sprintf("%d", cfg.Port),
|
|
||||||
Debug: cfg.Debug,
|
|
||||||
ApiKeys: cfg.ApiKeys,
|
|
||||||
}
|
|
||||||
|
|
||||||
cliClients := make([]*client.Client, 0)
|
|
||||||
err = filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
|
||||||
log.Debugf(path)
|
|
||||||
f, errOpen := os.Open(path)
|
|
||||||
if errOpen != nil {
|
|
||||||
return errOpen
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = f.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var ts auth.TokenStorage
|
|
||||||
if err = json.NewDecoder(f).Decode(&ts); err == nil {
|
|
||||||
// 2. Initialize authenticated HTTP Client
|
|
||||||
clientCtx := context.Background()
|
|
||||||
|
|
||||||
log.Info("Initializing authentication...")
|
|
||||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
|
||||||
if errGetClient != nil {
|
|
||||||
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
|
||||||
return errGetClient
|
|
||||||
}
|
|
||||||
log.Info("Authentication successful.")
|
|
||||||
|
|
||||||
// 3. Initialize CLI Client
|
|
||||||
cliClient := client.NewClient(httpClient)
|
|
||||||
if err = cliClient.SetupUser(clientCtx, ts.Email, ts.ProjectID); err != nil {
|
|
||||||
if err.Error() == "failed to start user onboarding, need define a project id" {
|
|
||||||
log.Error("failed to start user onboarding")
|
|
||||||
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
|
|
||||||
if errGetProjectList != nil {
|
|
||||||
log.Fatalf("failed to complete user setup: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Infof("Your account %s needs specify a project id.", ts.Email)
|
|
||||||
log.Info("========================================================================")
|
|
||||||
for i := 0; i < len(project.Projects); i++ {
|
|
||||||
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
|
|
||||||
log.Infof("Project Name: %s", project.Projects[i].Name)
|
|
||||||
log.Info("========================================================================")
|
|
||||||
}
|
|
||||||
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Log as a warning because in some cases, the CLI might still be usable
|
|
||||||
// or the user might want to retry setup later.
|
|
||||||
log.Fatalf("failed to complete user setup: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cliClients = append(cliClients, cliClient)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create API server
|
|
||||||
apiServer := api.NewServer(apiConfig, cliClients)
|
|
||||||
log.Infof("Starting API server on port %s", apiConfig.Port)
|
|
||||||
if err = apiServer.Start(); err != nil {
|
|
||||||
log.Fatalf("API server failed to start: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up graceful shutdown
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-sigChan:
|
|
||||||
log.Debugf("Received shutdown signal. Cleaning up...")
|
|
||||||
|
|
||||||
// Create shutdown context
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
_ = ctx // Mark ctx as used to avoid error, as apiServer.Stop(ctx) is commented out
|
|
||||||
|
|
||||||
// Stop API server
|
|
||||||
if err = apiServer.Stop(ctx); err != nil {
|
|
||||||
log.Debugf("Error stopping API server: %v", err)
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
log.Debugf("Cleanup completed. Exiting...")
|
|
||||||
os.Exit(0)
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -407,7 +407,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
|||||||
cliClient.RequestMutex.Lock()
|
cliClient.RequestMutex.Lock()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Request use account: %s", cliClient.Email)
|
log.Debugf("Request use account: %s, project id: %s", cliClient.Email, cliClient.ProjectID)
|
||||||
jsonTemplate := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
jsonTemplate := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
||||||
for {
|
for {
|
||||||
@@ -429,8 +429,8 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
|||||||
}
|
}
|
||||||
case err, okError := <-errChan:
|
case err, okError := <-errChan:
|
||||||
if okError {
|
if okError {
|
||||||
c.Status(http.StatusInternalServerError)
|
c.Status(err.StatusCode)
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error())
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
// c.JSON(http.StatusInternalServerError, ErrorResponse{
|
// c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
// Error: ErrorDetail{
|
// Error: ErrorDetail{
|
||||||
@@ -501,7 +501,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
cliClient.RequestMutex.Lock()
|
cliClient.RequestMutex.Lock()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Request use account: %s", cliClient.Email)
|
log.Debugf("Request use account: %s, project id: %s", cliClient.Email, cliClient.ProjectID)
|
||||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -526,8 +526,8 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
}
|
}
|
||||||
case err, okError := <-errChan:
|
case err, okError := <-errChan:
|
||||||
if okError {
|
if okError {
|
||||||
c.Status(http.StatusInternalServerError)
|
c.Status(err.StatusCode)
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error())
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
// c.JSON(http.StatusInternalServerError, ErrorResponse{
|
// c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
// Error: ErrorDetail{
|
// Error: ErrorDetail{
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ type TokenStorage struct {
|
|||||||
Token any `json:"token"`
|
Token any `json:"token"`
|
||||||
ProjectID string `json:"project_id"`
|
ProjectID string `json:"project_id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
Auto bool `json:"auto"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens.
|
// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens.
|
||||||
@@ -95,11 +96,12 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||||
}
|
}
|
||||||
ts, err = saveTokenToFile(ctx, conf, token, ts.ProjectID, cfg.AuthDir)
|
newTs, errSaveTokenToFile := saveTokenToFile(ctx, conf, token, ts.ProjectID, cfg.AuthDir)
|
||||||
if err != nil {
|
if errSaveTokenToFile != nil {
|
||||||
// Log the error but proceed, as we have a valid token for the session.
|
|
||||||
log.Errorf("Warning: failed to save token to file: %v", err)
|
log.Errorf("Warning: failed to save token to file: %v", err)
|
||||||
|
return nil, errSaveTokenToFile
|
||||||
}
|
}
|
||||||
|
*ts = *newTs
|
||||||
}
|
}
|
||||||
tsToken, _ := json.Marshal(ts.Token)
|
tsToken, _ := json.Marshal(ts.Token)
|
||||||
if err = json.Unmarshal(tsToken, &token); err != nil {
|
if err = json.Unmarshal(tsToken, &token); err != nil {
|
||||||
@@ -139,19 +141,6 @@ func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.T
|
|||||||
log.Info("Failed to get user email from token")
|
log.Info("Failed to get user email from token")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Saving credentials to %s", filepath.Join(authDir, fmt.Sprintf("%s.json", emailResult.String())))
|
|
||||||
if err = os.MkdirAll(authDir, 0700); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := os.Create(filepath.Join(authDir, fmt.Sprintf("%s.json", emailResult.String())))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create token file: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = f.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var ifToken map[string]any
|
var ifToken map[string]any
|
||||||
jsonData, _ := json.Marshal(token)
|
jsonData, _ := json.Marshal(token)
|
||||||
err = json.Unmarshal(jsonData, &ifToken)
|
err = json.Unmarshal(jsonData, &ifToken)
|
||||||
@@ -171,12 +160,46 @@ func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.T
|
|||||||
Email: emailResult.String(),
|
Email: emailResult.String(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
if err = os.MkdirAll(authDir, 0700); err != nil {
|
||||||
return nil, fmt.Errorf("failed to write token to file: %w", err)
|
return nil, fmt.Errorf("failed to create directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID != "" {
|
||||||
|
log.Infof("Saving credentials to %s", filepath.Join(authDir, fmt.Sprintf("%s-%s.json", emailResult.String(), projectID)))
|
||||||
|
|
||||||
|
f, errCreate := os.Create(filepath.Join(authDir, fmt.Sprintf("%s-%s.json", emailResult.String(), projectID)))
|
||||||
|
if errCreate != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create token file: %w", errCreate)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &ts, nil
|
return &ts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SaveTokenToFile(ts *TokenStorage, cfg *config.Config, auto bool) error {
|
||||||
|
ts.Auto = auto
|
||||||
|
fileName := filepath.Join(cfg.AuthDir, fmt.Sprintf("%s-%s.json", ts.Email, ts.ProjectID))
|
||||||
|
log.Infof("Saving credentials to %s", fileName)
|
||||||
|
f, err := os.Create(fileName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// getTokenFromWeb starts a local server to handle the OAuth2 flow.
|
// getTokenFromWeb starts a local server to handle the OAuth2 flow.
|
||||||
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
|
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
|
||||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||||
|
|||||||
@@ -25,6 +25,11 @@ const (
|
|||||||
pluginVersion = "1.0.0"
|
pluginVersion = "1.0.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ErrorMessage struct {
|
||||||
|
StatusCode int
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
type GCPProject struct {
|
type GCPProject struct {
|
||||||
Projects []GCPProjectProjects `json:"projects"`
|
Projects []GCPProjectProjects `json:"projects"`
|
||||||
}
|
}
|
||||||
@@ -100,7 +105,7 @@ type ToolDeclaration struct {
|
|||||||
// Client is the main client for interacting with the CLI API.
|
// Client is the main client for interacting with the CLI API.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
projectID string
|
ProjectID string
|
||||||
RequestMutex sync.Mutex
|
RequestMutex sync.Mutex
|
||||||
Email string
|
Email string
|
||||||
}
|
}
|
||||||
@@ -113,7 +118,7 @@ func NewClient(httpClient *http.Client) *Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetupUser performs the initial user onboarding and setup.
|
// SetupUser performs the initial user onboarding and setup.
|
||||||
func (c *Client) SetupUser(ctx context.Context, email, projectID string) error {
|
func (c *Client) SetupUser(ctx context.Context, email, projectID string, auto bool) (string, error) {
|
||||||
c.Email = email
|
c.Email = email
|
||||||
log.Info("Performing user onboarding...")
|
log.Info("Performing user onboarding...")
|
||||||
|
|
||||||
@@ -128,11 +133,14 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string) error {
|
|||||||
var loadAssistResp map[string]interface{}
|
var loadAssistResp map[string]interface{}
|
||||||
err := c.makeAPIRequest(ctx, "loadCodeAssist", "POST", loadAssistReqBody, &loadAssistResp)
|
err := c.makeAPIRequest(ctx, "loadCodeAssist", "POST", loadAssistReqBody, &loadAssistResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load code assist: %w", err)
|
return projectID, fmt.Errorf("failed to load code assist: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// a, _ := json.Marshal(&loadAssistResp)
|
// a, _ := json.Marshal(&loadAssistResp)
|
||||||
// log.Debug(string(a))
|
// log.Debug(string(a))
|
||||||
|
//
|
||||||
|
// a, _ = json.Marshal(loadAssistReqBody)
|
||||||
|
// log.Debug(string(a))
|
||||||
|
|
||||||
// 2. OnboardUser
|
// 2. OnboardUser
|
||||||
var onboardTierID = "legacy-tier"
|
var onboardTierID = "legacy-tier"
|
||||||
@@ -161,13 +169,13 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string) error {
|
|||||||
if onboardProjectID != "" {
|
if onboardProjectID != "" {
|
||||||
onboardReqBody["cloudaicompanionProject"] = onboardProjectID
|
onboardReqBody["cloudaicompanionProject"] = onboardProjectID
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("failed to start user onboarding, need define a project id")
|
return projectID, fmt.Errorf("failed to start user onboarding, need define a project id")
|
||||||
}
|
}
|
||||||
|
|
||||||
var lroResp map[string]interface{}
|
var lroResp map[string]interface{}
|
||||||
err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
|
err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to start user onboarding: %w", err)
|
return projectID, fmt.Errorf("failed to start user onboarding: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// a, _ = json.Marshal(&lroResp)
|
// a, _ = json.Marshal(&lroResp)
|
||||||
@@ -176,12 +184,17 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string) error {
|
|||||||
// 3. Poll Long-Running Operation (LRO)
|
// 3. Poll Long-Running Operation (LRO)
|
||||||
if done, doneOk := lroResp["done"].(bool); doneOk && done {
|
if done, doneOk := lroResp["done"].(bool); doneOk && done {
|
||||||
if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
|
if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
|
||||||
c.projectID = project["id"].(string)
|
if projectID != "" && !auto {
|
||||||
log.Infof("Onboarding complete. Using Project ID: %s", c.projectID)
|
c.ProjectID = projectID
|
||||||
return nil
|
log.Infof("Onboarding complete. Project ID: %s is being enforced. Maybe you need to enable 'Gemini for Google Cloud' once in Google Cloud Console.", c.ProjectID)
|
||||||
|
} else {
|
||||||
|
c.ProjectID = project["id"].(string)
|
||||||
|
log.Infof("Onboarding complete. Using Project ID: %s", c.ProjectID)
|
||||||
|
}
|
||||||
|
return c.ProjectID, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Errorf("failed to get operation name from onboarding response: %v", lroResp)
|
return projectID, fmt.Errorf("failed to get operation name from onboarding response: %v", lroResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeAPIRequest handles making requests to the CLI API endpoints.
|
// makeAPIRequest handles making requests to the CLI API endpoints.
|
||||||
@@ -240,7 +253,7 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// StreamAPIRequest handles making streaming requests to the CLI API endpoints.
|
// StreamAPIRequest handles making streaming requests to the CLI API endpoints.
|
||||||
func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body interface{}) (io.ReadCloser, error) {
|
func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body interface{}) (io.ReadCloser, *ErrorMessage) {
|
||||||
var jsonBody []byte
|
var jsonBody []byte
|
||||||
var err error
|
var err error
|
||||||
if byteBody, ok := body.([]byte); ok {
|
if byteBody, ok := body.([]byte); ok {
|
||||||
@@ -248,7 +261,7 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
|
|||||||
} else {
|
} else {
|
||||||
jsonBody, err = json.Marshal(body)
|
jsonBody, err = json.Marshal(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// log.Debug(string(jsonBody))
|
// log.Debug(string(jsonBody))
|
||||||
@@ -259,12 +272,12 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
|
|||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %w", err)}
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get token: %w", err)
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %w", err)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
@@ -276,7 +289,7 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
|
|||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %w", err)}
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
@@ -285,7 +298,7 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
|
|||||||
}()
|
}()
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
return nil, fmt.Errorf(string(bodyBytes))
|
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
|
||||||
// return nil, fmt.Errorf("api streaming request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
// return nil, fmt.Errorf("api streaming request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,9 +306,9 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageStream handles a single conversational turn, including tool calls.
|
// SendMessageStream handles a single conversational turn, including tool calls.
|
||||||
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan error) {
|
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
|
||||||
dataTag := []byte("data: ")
|
dataTag := []byte("data: ")
|
||||||
errChan := make(chan error)
|
errChan := make(chan *ErrorMessage)
|
||||||
dataChan := make(chan []byte)
|
dataChan := make(chan []byte)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(errChan)
|
defer close(errChan)
|
||||||
@@ -312,7 +325,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
|
|||||||
request.Tools = tools
|
request.Tools = tools
|
||||||
|
|
||||||
requestBody := map[string]interface{}{
|
requestBody := map[string]interface{}{
|
||||||
"project": c.projectID, // Assuming ProjectID is available
|
"project": c.ProjectID, // Assuming ProjectID is available
|
||||||
"request": request,
|
"request": request,
|
||||||
"model": model,
|
"model": model,
|
||||||
}
|
}
|
||||||
@@ -370,9 +383,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = scanner.Err(); err != nil {
|
if errScanner := scanner.Err(); errScanner != nil {
|
||||||
// log.Println(err)
|
// log.Println(err)
|
||||||
errChan <- err
|
errChan <- &ErrorMessage{500, errScanner}
|
||||||
_ = stream.Close()
|
_ = stream.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
62
internal/cmd/login.go
Normal file
62
internal/cmd/login.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DoLogin(cfg *config.Config, projectID string) {
|
||||||
|
var err error
|
||||||
|
var ts auth.TokenStorage
|
||||||
|
if projectID != "" {
|
||||||
|
ts.ProjectID = projectID
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Initialize authenticated HTTP Client
|
||||||
|
clientCtx := context.Background()
|
||||||
|
|
||||||
|
log.Info("Initializing authentication...")
|
||||||
|
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||||
|
if errGetClient != nil {
|
||||||
|
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Info("Authentication successful.")
|
||||||
|
|
||||||
|
// 3. Initialize CLI Client
|
||||||
|
cliClient := client.NewClient(httpClient)
|
||||||
|
projectID, err = cliClient.SetupUser(clientCtx, ts.Email, projectID, ts.Auto)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "failed to start user onboarding, need define a project id" {
|
||||||
|
log.Error("failed to start user onboarding")
|
||||||
|
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
|
||||||
|
if errGetProjectList != nil {
|
||||||
|
log.Fatalf("failed to complete user setup: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Infof("Your account %s needs specify a project id.", ts.Email)
|
||||||
|
log.Info("========================================================================")
|
||||||
|
for i := 0; i < len(project.Projects); i++ {
|
||||||
|
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
|
||||||
|
log.Infof("Project Name: %s", project.Projects[i].Name)
|
||||||
|
log.Info("========================================================================")
|
||||||
|
}
|
||||||
|
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Log as a warning because in some cases, the CLI might still be usable
|
||||||
|
// or the user might want to retry setup later.
|
||||||
|
log.Fatalf("failed to complete user setup: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto := ts.ProjectID == ""
|
||||||
|
ts.ProjectID = projectID
|
||||||
|
err = auth.SaveTokenToFile(&ts, cfg, auto)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
122
internal/cmd/run.go
Normal file
122
internal/cmd/run.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StartService(cfg *config.Config) {
|
||||||
|
// Create API server configuration
|
||||||
|
apiConfig := &api.ServerConfig{
|
||||||
|
Port: fmt.Sprintf("%d", cfg.Port),
|
||||||
|
Debug: cfg.Debug,
|
||||||
|
ApiKeys: cfg.ApiKeys,
|
||||||
|
}
|
||||||
|
|
||||||
|
cliClients := make([]*client.Client, 0)
|
||||||
|
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
||||||
|
log.Debugf(path)
|
||||||
|
f, errOpen := os.Open(path)
|
||||||
|
if errOpen != nil {
|
||||||
|
return errOpen
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var ts auth.TokenStorage
|
||||||
|
if err = json.NewDecoder(f).Decode(&ts); err == nil {
|
||||||
|
// 2. Initialize authenticated HTTP Client
|
||||||
|
clientCtx := context.Background()
|
||||||
|
|
||||||
|
log.Info("Initializing authentication...")
|
||||||
|
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||||
|
if errGetClient != nil {
|
||||||
|
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
||||||
|
return errGetClient
|
||||||
|
}
|
||||||
|
log.Info("Authentication successful.")
|
||||||
|
|
||||||
|
// 3. Initialize CLI Client
|
||||||
|
cliClient := client.NewClient(httpClient)
|
||||||
|
if _, err = cliClient.SetupUser(clientCtx, ts.Email, ts.ProjectID, ts.Auto); err != nil {
|
||||||
|
if err.Error() == "failed to start user onboarding, need define a project id" {
|
||||||
|
log.Error("failed to start user onboarding")
|
||||||
|
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
|
||||||
|
if errGetProjectList != nil {
|
||||||
|
log.Fatalf("failed to complete user setup: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Infof("Your account %s needs specify a project id.", ts.Email)
|
||||||
|
log.Info("========================================================================")
|
||||||
|
for i := 0; i < len(project.Projects); i++ {
|
||||||
|
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
|
||||||
|
log.Infof("Project Name: %s", project.Projects[i].Name)
|
||||||
|
log.Info("========================================================================")
|
||||||
|
}
|
||||||
|
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Log as a warning because in some cases, the CLI might still be usable
|
||||||
|
// or the user might want to retry setup later.
|
||||||
|
log.Fatalf("failed to complete user setup: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cliClients = append(cliClients, cliClient)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create API server
|
||||||
|
apiServer := api.NewServer(apiConfig, cliClients)
|
||||||
|
log.Infof("Starting API server on port %s", apiConfig.Port)
|
||||||
|
if err = apiServer.Start(); err != nil {
|
||||||
|
log.Fatalf("API server failed to start: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up graceful shutdown
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-sigChan:
|
||||||
|
log.Debugf("Received shutdown signal. Cleaning up...")
|
||||||
|
|
||||||
|
// Create shutdown context
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
_ = ctx // Mark ctx as used to avoid error, as apiServer.Stop(ctx) is commented out
|
||||||
|
|
||||||
|
// Stop API server
|
||||||
|
if err = apiServer.Stop(ctx); err != nil {
|
||||||
|
log.Debugf("Error stopping API server: %v", err)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
log.Debugf("Cleanup completed. Exiting...")
|
||||||
|
os.Exit(0)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user