mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Enhance OAuth handling for Anthropic, Codex, Gemini, and Qwen tokens
- Transitioned OAuth callback handling from temporary servers to predefined persistent endpoints. - Simplified token retrieval by replacing in-memory handling with state-file-based persistence. - Introduced unified `oauthStatus` map for tracking flow progress and errors. - Added new `/auth/*/callback` routes, streamlining code and state management for OAuth flows. - Improved error handling and logging in token exchange and callback flows.
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -18,12 +19,17 @@ import (
|
|||||||
"github.com/luispater/CLIProxyAPI/internal/auth/qwen"
|
"github.com/luispater/CLIProxyAPI/internal/auth/qwen"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/misc"
|
"github.com/luispater/CLIProxyAPI/internal/misc"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
oauthStatus = make(map[string]string)
|
||||||
|
)
|
||||||
|
|
||||||
// List auth files
|
// List auth files
|
||||||
func (h *Handler) ListAuthFiles(c *gin.Context) {
|
func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||||
entries, err := os.ReadDir(h.cfg.AuthDir)
|
entries, err := os.ReadDir(h.cfg.AuthDir)
|
||||||
@@ -183,93 +189,143 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
// Initialize Claude auth service
|
// Initialize Claude auth service
|
||||||
anthropicAuth := claude.NewClaudeAuth(h.cfg)
|
anthropicAuth := claude.NewClaudeAuth(h.cfg)
|
||||||
|
|
||||||
// Generate authorization URL
|
// Generate authorization URL (then override redirect_uri to reuse server port)
|
||||||
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
|
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to generate authorization URL: %v", err)
|
log.Fatalf("Failed to generate authorization URL: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Override redirect_uri in authorization URL to current server port
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Initialize OAuth server
|
// Helper: wait for callback file
|
||||||
oauthServer := claude.NewOAuthServer(54545)
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
|
||||||
|
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
|
||||||
// Start OAuth callback server
|
deadline := time.Now().Add(timeout)
|
||||||
if err = oauthServer.Start(); err != nil {
|
for {
|
||||||
if strings.Contains(err.Error(), "already in use") {
|
if time.Now().After(deadline) {
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrPortInUse, err)
|
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||||
return
|
}
|
||||||
|
data, errRead := os.ReadFile(path)
|
||||||
|
if errRead == nil {
|
||||||
|
var m map[string]string
|
||||||
|
_ = json.Unmarshal(data, &m)
|
||||||
|
_ = os.Remove(path)
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
}
|
}
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrServerStartFailed, err)
|
|
||||||
log.Fatalf("Failed to start OAuth callback server: %v", authErr)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if err = oauthServer.Stop(ctx); err != nil {
|
|
||||||
log.Warnf("Failed to stop OAuth server: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
log.Info("Waiting for authentication callback...")
|
log.Info("Waiting for authentication callback...")
|
||||||
|
// Wait up to 5 minutes
|
||||||
// Wait for OAuth callback
|
resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
|
||||||
result, errWaitForCallback := oauthServer.WaitForCallback(5 * time.Minute)
|
if errWait != nil {
|
||||||
if errWaitForCallback != nil {
|
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
|
||||||
if strings.Contains(errWaitForCallback.Error(), "timeout") {
|
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWaitForCallback)
|
|
||||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
|
||||||
} else {
|
|
||||||
log.Errorf("Authentication failed: %v", errWaitForCallback)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Error != "" {
|
|
||||||
oauthErr := claude.NewOAuthError(result.Error, "", http.StatusBadRequest)
|
|
||||||
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate state parameter
|
|
||||||
if result.State != state {
|
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, result.State))
|
|
||||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if errStr := resultMap["error"]; errStr != "" {
|
||||||
log.Debug("Authorization code received, exchanging for tokens...")
|
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||||
|
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
||||||
// Exchange authorization code for tokens
|
oauthStatus[state] = "Bad request"
|
||||||
authBundle, errExchangeCodeForTokens := anthropicAuth.ExchangeCodeForTokens(ctx, result.Code, state, pkceCodes)
|
return
|
||||||
if errExchangeCodeForTokens != nil {
|
}
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchangeCodeForTokens)
|
if resultMap["state"] != state {
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
||||||
log.Debug("This may be due to network issues or invalid authorization code")
|
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||||
|
oauthStatus[state] = "State code error"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create token storage
|
// Parse code (Claude may append state after '#')
|
||||||
tokenStorage := anthropicAuth.CreateTokenStorage(authBundle)
|
rawCode := resultMap["code"]
|
||||||
|
code := strings.Split(rawCode, "#")[0]
|
||||||
|
|
||||||
|
// Exchange code for tokens (replicate logic using updated redirect_uri)
|
||||||
|
// Extract client_id from the modified auth URL
|
||||||
|
clientID := ""
|
||||||
|
if u2, errP := url.Parse(authURL); errP == nil {
|
||||||
|
clientID = u2.Query().Get("client_id")
|
||||||
|
}
|
||||||
|
// Build request
|
||||||
|
bodyMap := map[string]any{
|
||||||
|
"code": code,
|
||||||
|
"state": state,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": clientID,
|
||||||
|
"redirect_uri": "http://localhost:54545/callback",
|
||||||
|
"code_verifier": pkceCodes.CodeVerifier,
|
||||||
|
}
|
||||||
|
bodyJSON, _ := json.Marshal(bodyMap)
|
||||||
|
|
||||||
|
httpClient := util.SetProxy(h.cfg, &http.Client{})
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
resp, errDo := httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||||
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
|
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("failed to close response body: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var tResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
Account struct {
|
||||||
|
EmailAddress string `json:"email_address"`
|
||||||
|
} `json:"account"`
|
||||||
|
}
|
||||||
|
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||||
|
log.Errorf("failed to parse token response: %v", errU)
|
||||||
|
oauthStatus[state] = "Failed to parse token response"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bundle := &claude.ClaudeAuthBundle{
|
||||||
|
TokenData: claude.ClaudeTokenData{
|
||||||
|
AccessToken: tResp.AccessToken,
|
||||||
|
RefreshToken: tResp.RefreshToken,
|
||||||
|
Email: tResp.Account.EmailAddress,
|
||||||
|
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
LastRefresh: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token storage
|
||||||
|
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
||||||
// Initialize Claude client
|
// Initialize Claude client
|
||||||
anthropicClient := client.NewClaudeClient(h.cfg, tokenStorage)
|
anthropicClient := client.NewClaudeClient(h.cfg, tokenStorage)
|
||||||
|
|
||||||
// Save token storage
|
// Save token storage
|
||||||
if errWaitForCallback = anthropicClient.SaveTokenToFile(); errWaitForCallback != nil {
|
if errSave := anthropicClient.SaveTokenToFile(); errSave != nil {
|
||||||
log.Fatalf("Failed to save authentication tokens: %v", errWaitForCallback)
|
log.Fatalf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
oauthStatus[state] = "Failed to save authentication tokens"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("Authentication successful!")
|
log.Info("Authentication successful!")
|
||||||
if authBundle.APIKey != "" {
|
if bundle.APIKey != "" {
|
||||||
log.Info("API key obtained and saved")
|
log.Info("API key obtained and saved")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("You can now use Claude services through this CLI")
|
log.Info("You can now use Claude services through this CLI")
|
||||||
|
delete(oauthStatus, state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL})
|
oauthStatus[state] = ""
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||||
@@ -294,67 +350,46 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build authorization URL and return it immediately
|
// Build authorization URL and return it immediately
|
||||||
authURL := conf.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||||
|
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
codeChan := make(chan string)
|
// Wait for callback file written by server route
|
||||||
errChan := make(chan error)
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
server := &http.Server{Addr: ":8085", Handler: mux}
|
|
||||||
|
|
||||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if err := r.URL.Query().Get("error"); err != "" {
|
|
||||||
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
|
||||||
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
code := r.URL.Query().Get("code")
|
|
||||||
if code == "" {
|
|
||||||
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
|
|
||||||
errChan <- fmt.Errorf("code not found in callback")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
|
||||||
codeChan <- code
|
|
||||||
})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if errListen := server.ListenAndServe(); errListen != nil && errListen != http.ErrServerClosed {
|
|
||||||
log.Fatalf("ListenAndServe(): %v", errListen)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
log.Info("Waiting for authentication callback...")
|
log.Info("Waiting for authentication callback...")
|
||||||
|
deadline := time.Now().Add(5 * time.Minute)
|
||||||
var authCode string
|
var authCode string
|
||||||
select {
|
for {
|
||||||
case code := <-codeChan:
|
if time.Now().After(deadline) {
|
||||||
authCode = code
|
log.Error("oauth flow timed out")
|
||||||
case errCallback := <-errChan:
|
oauthStatus[state] = "OAuth flow timed out"
|
||||||
log.Errorf("Authentication failed: %v", errCallback)
|
return
|
||||||
// Attempt graceful shutdown
|
|
||||||
if errShutdown := server.Shutdown(ctx); errShutdown != nil {
|
|
||||||
log.Warnf("Failed to shut down server: %v", errShutdown)
|
|
||||||
}
|
}
|
||||||
return
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
case <-time.After(5 * time.Minute):
|
var m map[string]string
|
||||||
log.Error("oauth flow timed out")
|
_ = json.Unmarshal(data, &m)
|
||||||
if errShutdown := server.Shutdown(ctx); errShutdown != nil {
|
_ = os.Remove(waitFile)
|
||||||
log.Warnf("Failed to shut down server after timeout: %v", errShutdown)
|
if errStr := m["error"]; errStr != "" {
|
||||||
|
log.Errorf("Authentication failed: %s", errStr)
|
||||||
|
oauthStatus[state] = "Authentication failed"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authCode = m["code"]
|
||||||
|
if authCode == "" {
|
||||||
|
log.Errorf("Authentication failed: code not found")
|
||||||
|
oauthStatus[state] = "Authentication failed: code not found"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
return
|
time.Sleep(500 * time.Millisecond)
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown the callback server after receiving the code
|
|
||||||
if errShutdown := server.Shutdown(ctx); errShutdown != nil {
|
|
||||||
log.Warnf("Failed to shut down server: %v", errShutdown)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange authorization code for token
|
// Exchange authorization code for token
|
||||||
token, err := conf.Exchange(ctx, authCode)
|
token, err := conf.Exchange(ctx, authCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to exchange token: %v", err)
|
log.Errorf("Failed to exchange token: %v", err)
|
||||||
|
oauthStatus[state] = "Failed to exchange token"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,6 +398,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||||
if errNewRequest != nil {
|
if errNewRequest != nil {
|
||||||
log.Errorf("Could not get user info: %v", errNewRequest)
|
log.Errorf("Could not get user info: %v", errNewRequest)
|
||||||
|
oauthStatus[state] = "Could not get user info"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -371,6 +407,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
log.Errorf("Failed to execute request: %v", errDo)
|
log.Errorf("Failed to execute request: %v", errDo)
|
||||||
|
oauthStatus[state] = "Failed to execute request"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -382,6 +419,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -390,6 +428,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
log.Infof("Authenticated user email: %s", email)
|
log.Infof("Authenticated user email: %s", email)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Failed to get user email from token")
|
log.Info("Failed to get user email from token")
|
||||||
|
oauthStatus[state] = "Failed to get user email from token"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
||||||
@@ -397,6 +436,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
jsonData, _ := json.Marshal(token)
|
jsonData, _ := json.Marshal(token)
|
||||||
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
||||||
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
||||||
|
oauthStatus[state] = "Failed to unmarshal token"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -421,6 +461,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
httpClient2, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
httpClient2, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
||||||
if errGetClient != nil {
|
if errGetClient != nil {
|
||||||
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
||||||
|
oauthStatus[state] = "Failed to get authenticated client"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Info("Authentication successful.")
|
log.Info("Authentication successful.")
|
||||||
@@ -432,9 +473,11 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
if err = cliClient.SetupUser(ctx, ts.Email, projectID); err != nil {
|
if err = cliClient.SetupUser(ctx, ts.Email, projectID); err != nil {
|
||||||
if err.Error() == "failed to start user onboarding, need define a project id" {
|
if err.Error() == "failed to start user onboarding, need define a project id" {
|
||||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||||
|
oauthStatus[state] = "Failed to start user onboarding: A project ID is required"
|
||||||
project, errGetProjectList := cliClient.GetProjectList(ctx)
|
project, errGetProjectList := cliClient.GetProjectList(ctx)
|
||||||
if errGetProjectList != nil {
|
if errGetProjectList != nil {
|
||||||
log.Fatalf("Failed to get project list: %v", err)
|
log.Fatalf("Failed to get project list: %v", err)
|
||||||
|
oauthStatus[state] = "Failed to get project list"
|
||||||
} else {
|
} else {
|
||||||
log.Infof("Your account %s needs to specify a project ID.", ts.Email)
|
log.Infof("Your account %s needs to specify a project ID.", ts.Email)
|
||||||
log.Info("========================================================================")
|
log.Info("========================================================================")
|
||||||
@@ -447,6 +490,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Fatalf("Failed to complete user setup: %v", err)
|
log.Fatalf("Failed to complete user setup: %v", err)
|
||||||
|
oauthStatus[state] = "Failed to complete user setup"
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -458,24 +502,29 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
|
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
|
||||||
if checkErr != nil {
|
if checkErr != nil {
|
||||||
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr)
|
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr)
|
||||||
|
oauthStatus[state] = "Failed to check if Cloud AI API is enabled"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cliClient.SetIsChecked(isChecked)
|
cliClient.SetIsChecked(isChecked)
|
||||||
if !isChecked {
|
if !isChecked {
|
||||||
log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.")
|
log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.")
|
||||||
|
oauthStatus[state] = "Failed to check if Cloud AI API is enabled"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = cliClient.SaveTokenToFile(); err != nil {
|
if err = cliClient.SaveTokenToFile(); err != nil {
|
||||||
log.Fatalf("Failed to save token to file: %v", err)
|
log.Fatalf("Failed to save token to file: %v", err)
|
||||||
|
oauthStatus[state] = "Failed to save token to file"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
delete(oauthStatus, state)
|
||||||
log.Info("You can now use Gemini CLI services through this CLI")
|
log.Info("You can now use Gemini CLI services through this CLI")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL})
|
oauthStatus[state] = ""
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||||
@@ -508,89 +557,125 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Initialize OAuth server
|
// Wait for callback file
|
||||||
oauthServer := codex.NewOAuthServer(1455)
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
|
||||||
|
deadline := time.Now().Add(5 * time.Minute)
|
||||||
// Start OAuth callback server
|
var code string
|
||||||
if err = oauthServer.Start(); err != nil {
|
for {
|
||||||
if strings.Contains(err.Error(), "already in use") {
|
if time.Now().After(deadline) {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrPortInUse, err)
|
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
||||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrServerStartFailed, err)
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
log.Fatalf("Failed to start OAuth callback server: %v", authErr)
|
var m map[string]string
|
||||||
return
|
_ = json.Unmarshal(data, &m)
|
||||||
}
|
_ = os.Remove(waitFile)
|
||||||
defer func() {
|
if errStr := m["error"]; errStr != "" {
|
||||||
if err = oauthServer.Stop(ctx); err != nil {
|
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||||
log.Warnf("Failed to stop OAuth server: %v", err)
|
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||||
|
oauthStatus[state] = "Bad Request"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m["state"] != state {
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
||||||
|
oauthStatus[state] = "State code error"
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code = m["code"]
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}()
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
log.Info("Waiting for authentication callback...")
|
|
||||||
|
|
||||||
// Wait for OAuth callback
|
|
||||||
result, errWaitForCallback := oauthServer.WaitForCallback(5 * time.Minute)
|
|
||||||
if errWaitForCallback != nil {
|
|
||||||
if strings.Contains(errWaitForCallback.Error(), "timeout") {
|
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, errWaitForCallback)
|
|
||||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
|
||||||
} else {
|
|
||||||
log.Errorf("Authentication failed: %v", errWaitForCallback)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Error != "" {
|
|
||||||
oauthErr := codex.NewOAuthError(result.Error, "", http.StatusBadRequest)
|
|
||||||
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate state parameter
|
|
||||||
if result.State != state {
|
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, result.State))
|
|
||||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("Authorization code received, exchanging for tokens...")
|
log.Debug("Authorization code received, exchanging for tokens...")
|
||||||
|
// Extract client_id from authURL
|
||||||
// Exchange authorization code for tokens
|
clientID := ""
|
||||||
authBundle, errExchangeCodeForTokens := openaiAuth.ExchangeCodeForTokens(ctx, result.Code, pkceCodes)
|
if u2, errP := url.Parse(authURL); errP == nil {
|
||||||
if errExchangeCodeForTokens != nil {
|
clientID = u2.Query().Get("client_id")
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchangeCodeForTokens)
|
}
|
||||||
|
// Exchange code for tokens with redirect equal to mgmtRedirect
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"client_id": {clientID},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {"http://localhost:1455/auth/callback"},
|
||||||
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
|
}
|
||||||
|
httpClient := util.SetProxy(h.cfg, &http.Client{})
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
resp, errDo := httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||||
|
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
log.Debug("This may be due to network issues or invalid authorization code")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
// Create token storage
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
tokenStorage := openaiAuth.CreateTokenStorage(authBundle)
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
||||||
// Initialize Codex client
|
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||||
openaiClient, errWaitForCallback := client.NewCodexClient(h.cfg, tokenStorage)
|
|
||||||
if errWaitForCallback != nil {
|
|
||||||
log.Fatalf("Failed to initialize Codex client: %v", errWaitForCallback)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var tokenResp struct {
|
||||||
// Save token storage
|
AccessToken string `json:"access_token"`
|
||||||
if errWaitForCallback = openaiClient.SaveTokenToFile(); errWaitForCallback != nil {
|
RefreshToken string `json:"refresh_token"`
|
||||||
log.Fatalf("Failed to save authentication tokens: %v", errWaitForCallback)
|
IDToken string `json:"id_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||||
|
oauthStatus[state] = "Failed to parse token response"
|
||||||
|
log.Errorf("failed to parse token response: %v", errU)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
|
||||||
|
email := ""
|
||||||
|
accountID := ""
|
||||||
|
if claims != nil {
|
||||||
|
email = claims.GetUserEmail()
|
||||||
|
accountID = claims.GetAccountID()
|
||||||
|
}
|
||||||
|
// Build bundle compatible with existing storage
|
||||||
|
bundle := &codex.CodexAuthBundle{
|
||||||
|
TokenData: codex.CodexTokenData{
|
||||||
|
IDToken: tokenResp.IDToken,
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
AccountID: accountID,
|
||||||
|
Email: email,
|
||||||
|
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
LastRefresh: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token storage and persist
|
||||||
|
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
|
||||||
|
openaiClient, errInit := client.NewCodexClient(h.cfg, tokenStorage)
|
||||||
|
if errInit != nil {
|
||||||
|
oauthStatus[state] = "Failed to initialize Codex client"
|
||||||
|
log.Fatalf("Failed to initialize Codex client: %v", errInit)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errSave := openaiClient.SaveTokenToFile(); errSave != nil {
|
||||||
|
oauthStatus[state] = "Failed to save authentication tokens"
|
||||||
|
log.Fatalf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
return
|
||||||
|
}
|
||||||
log.Info("Authentication successful!")
|
log.Info("Authentication successful!")
|
||||||
if authBundle.APIKey != "" {
|
if bundle.APIKey != "" {
|
||||||
log.Info("API key obtained and saved")
|
log.Info("API key obtained and saved")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("You can now use Codex services through this CLI")
|
log.Info("You can now use Codex services through this CLI")
|
||||||
|
delete(oauthStatus, state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL})
|
oauthStatus[state] = ""
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||||
@@ -598,6 +683,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
|
|
||||||
log.Info("Initializing Qwen authentication...")
|
log.Info("Initializing Qwen authentication...")
|
||||||
|
|
||||||
|
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||||
// Initialize Qwen auth service
|
// Initialize Qwen auth service
|
||||||
qwenAuth := qwen.NewQwenAuth(h.cfg)
|
qwenAuth := qwen.NewQwenAuth(h.cfg)
|
||||||
|
|
||||||
@@ -613,8 +699,9 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
log.Info("Waiting for authentication...")
|
log.Info("Waiting for authentication...")
|
||||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||||
if errPollForToken != nil {
|
if errPollForToken != nil {
|
||||||
|
oauthStatus[state] = "Authentication failed"
|
||||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||||
os.Exit(1)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create token storage
|
// Create token storage
|
||||||
@@ -628,12 +715,30 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
// Save token storage
|
// Save token storage
|
||||||
if err = qwenClient.SaveTokenToFile(); err != nil {
|
if err = qwenClient.SaveTokenToFile(); err != nil {
|
||||||
log.Fatalf("Failed to save authentication tokens: %v", err)
|
log.Fatalf("Failed to save authentication tokens: %v", err)
|
||||||
|
oauthStatus[state] = "Failed to save authentication tokens"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("Authentication successful!")
|
log.Info("Authentication successful!")
|
||||||
log.Info("You can now use Qwen services through this CLI")
|
log.Info("You can now use Qwen services through this CLI")
|
||||||
|
delete(oauthStatus, state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL})
|
oauthStatus[state] = ""
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||||
|
state := c.Query("state")
|
||||||
|
if err, ok := oauthStatus[state]; ok {
|
||||||
|
if err != "" {
|
||||||
|
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||||
|
} else {
|
||||||
|
c.JSON(200, gin.H{"status": "wait"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
|
}
|
||||||
|
delete(oauthStatus, state)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -145,6 +146,46 @@ func (s *Server) setupRoutes() {
|
|||||||
})
|
})
|
||||||
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
||||||
|
|
||||||
|
// OAuth callback endpoints (reuse main server port)
|
||||||
|
// These endpoints receive provider redirects and persist
|
||||||
|
// the short-lived code/state for the waiting goroutine.
|
||||||
|
s.engine.GET("/anthropic/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
// Persist to a temporary file keyed by state
|
||||||
|
if state != "" {
|
||||||
|
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
|
||||||
|
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||||
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/codex/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if state != "" {
|
||||||
|
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
|
||||||
|
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||||
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/google/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if state != "" {
|
||||||
|
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
|
||||||
|
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||||
|
})
|
||||||
|
|
||||||
// Management API routes (delegated to management handlers)
|
// Management API routes (delegated to management handlers)
|
||||||
// New logic: if remote-management-key is empty, do not expose any management endpoint (404).
|
// New logic: if remote-management-key is empty, do not expose any management endpoint (404).
|
||||||
if s.cfg.RemoteManagement.SecretKey != "" {
|
if s.cfg.RemoteManagement.SecretKey != "" {
|
||||||
@@ -216,7 +257,7 @@ func (s *Server) setupRoutes() {
|
|||||||
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
||||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||||
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user