Enhance OAuth handling for Anthropic, Codex, Gemini, and Qwen tokens

- Transitioned OAuth callback handling from temporary servers to predefined persistent endpoints.
- Simplified token retrieval by replacing in-memory handling with state-file-based persistence.
- Introduced unified `oauthStatus` map for tracking flow progress and errors.
- Added new `/auth/*/callback` routes, streamlining code and state management for OAuth flows.
- Improved error handling and logging in token exchange and callback flows.
This commit is contained in:
Luis Pater
2025-09-10 02:34:22 +08:00
parent 156e3b017d
commit 0449fefa60
2 changed files with 318 additions and 172 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -18,12 +19,17 @@ import (
"github.com/luispater/CLIProxyAPI/internal/auth/qwen" "github.com/luispater/CLIProxyAPI/internal/auth/qwen"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/misc" "github.com/luispater/CLIProxyAPI/internal/misc"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
) )
var (
oauthStatus = make(map[string]string)
)
// List auth files // List auth files
func (h *Handler) ListAuthFiles(c *gin.Context) { func (h *Handler) ListAuthFiles(c *gin.Context) {
entries, err := os.ReadDir(h.cfg.AuthDir) entries, err := os.ReadDir(h.cfg.AuthDir)
@@ -183,93 +189,143 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Initialize Claude auth service // Initialize Claude auth service
anthropicAuth := claude.NewClaudeAuth(h.cfg) anthropicAuth := claude.NewClaudeAuth(h.cfg)
// Generate authorization URL // Generate authorization URL (then override redirect_uri to reuse server port)
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes) authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
if err != nil { if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err) log.Fatalf("Failed to generate authorization URL: %v", err)
return return
} }
// Override redirect_uri in authorization URL to current server port
go func() { go func() {
// Initialize OAuth server // Helper: wait for callback file
oauthServer := claude.NewOAuthServer(54545) waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
deadline := time.Now().Add(timeout)
for {
if time.Now().After(deadline) {
oauthStatus[state] = "Timeout waiting for OAuth callback"
return nil, fmt.Errorf("timeout waiting for OAuth callback")
}
data, errRead := os.ReadFile(path)
if errRead == nil {
var m map[string]string
_ = json.Unmarshal(data, &m)
_ = os.Remove(path)
return m, nil
}
time.Sleep(500 * time.Millisecond)
}
}
// Start OAuth callback server log.Info("Waiting for authentication callback...")
if err = oauthServer.Start(); err != nil { // Wait up to 5 minutes
if strings.Contains(err.Error(), "already in use") { resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
authErr := claude.NewAuthenticationError(claude.ErrPortInUse, err) if errWait != nil {
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
log.Error(claude.GetUserFriendlyMessage(authErr)) log.Error(claude.GetUserFriendlyMessage(authErr))
return return
} }
authErr := claude.NewAuthenticationError(claude.ErrServerStartFailed, err) if errStr := resultMap["error"]; errStr != "" {
log.Fatalf("Failed to start OAuth callback server: %v", authErr) oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Error(claude.GetUserFriendlyMessage(oauthErr))
oauthStatus[state] = "Bad request"
return
}
if resultMap["state"] != state {
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
log.Error(claude.GetUserFriendlyMessage(authErr))
oauthStatus[state] = "State code error"
return
}
// Parse code (Claude may append state after '#')
rawCode := resultMap["code"]
code := strings.Split(rawCode, "#")[0]
// Exchange code for tokens (replicate logic using updated redirect_uri)
// Extract client_id from the modified auth URL
clientID := ""
if u2, errP := url.Parse(authURL); errP == nil {
clientID = u2.Query().Get("client_id")
}
// Build request
bodyMap := map[string]any{
"code": code,
"state": state,
"grant_type": "authorization_code",
"client_id": clientID,
"redirect_uri": "http://localhost:54545/callback",
"code_verifier": pkceCodes.CodeVerifier,
}
bodyJSON, _ := json.Marshal(bodyMap)
httpClient := util.SetProxy(h.cfg, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
oauthStatus[state] = "Failed to exchange authorization code for tokens"
return return
} }
defer func() { defer func() {
if err = oauthServer.Stop(ctx); err != nil { if errClose := resp.Body.Close(); errClose != nil {
log.Warnf("Failed to stop OAuth server: %v", err) log.Errorf("failed to close response body: %v", errClose)
} }
}() }()
respBody, _ := io.ReadAll(resp.Body)
log.Info("Waiting for authentication callback...") if resp.StatusCode != http.StatusOK {
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
// Wait for OAuth callback oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
result, errWaitForCallback := oauthServer.WaitForCallback(5 * time.Minute)
if errWaitForCallback != nil {
if strings.Contains(errWaitForCallback.Error(), "timeout") {
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWaitForCallback)
log.Error(claude.GetUserFriendlyMessage(authErr))
} else {
log.Errorf("Authentication failed: %v", errWaitForCallback)
}
return return
} }
var tResp struct {
if result.Error != "" { AccessToken string `json:"access_token"`
oauthErr := claude.NewOAuthError(result.Error, "", http.StatusBadRequest) RefreshToken string `json:"refresh_token"`
log.Error(claude.GetUserFriendlyMessage(oauthErr)) ExpiresIn int `json:"expires_in"`
Account struct {
EmailAddress string `json:"email_address"`
} `json:"account"`
}
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
log.Errorf("failed to parse token response: %v", errU)
oauthStatus[state] = "Failed to parse token response"
return return
} }
bundle := &claude.ClaudeAuthBundle{
// Validate state parameter TokenData: claude.ClaudeTokenData{
if result.State != state { AccessToken: tResp.AccessToken,
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, result.State)) RefreshToken: tResp.RefreshToken,
log.Error(claude.GetUserFriendlyMessage(authErr)) Email: tResp.Account.EmailAddress,
return Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
} },
LastRefresh: time.Now().Format(time.RFC3339),
log.Debug("Authorization code received, exchanging for tokens...")
// Exchange authorization code for tokens
authBundle, errExchangeCodeForTokens := anthropicAuth.ExchangeCodeForTokens(ctx, result.Code, state, pkceCodes)
if errExchangeCodeForTokens != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchangeCodeForTokens)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
log.Debug("This may be due to network issues or invalid authorization code")
return
} }
// Create token storage // Create token storage
tokenStorage := anthropicAuth.CreateTokenStorage(authBundle) tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
// Initialize Claude client // Initialize Claude client
anthropicClient := client.NewClaudeClient(h.cfg, tokenStorage) anthropicClient := client.NewClaudeClient(h.cfg, tokenStorage)
// Save token storage // Save token storage
if errWaitForCallback = anthropicClient.SaveTokenToFile(); errWaitForCallback != nil { if errSave := anthropicClient.SaveTokenToFile(); errSave != nil {
log.Fatalf("Failed to save authentication tokens: %v", errWaitForCallback) log.Fatalf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens"
return return
} }
log.Info("Authentication successful!") log.Info("Authentication successful!")
if authBundle.APIKey != "" { if bundle.APIKey != "" {
log.Info("API key obtained and saved") log.Info("API key obtained and saved")
} }
log.Info("You can now use Claude services through this CLI") log.Info("You can now use Claude services through this CLI")
delete(oauthStatus, state)
}() }()
c.JSON(200, gin.H{"status": "ok", "url": authURL}) oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
@@ -294,67 +350,46 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
} }
// Build authorization URL and return it immediately // Build authorization URL and return it immediately
authURL := conf.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
go func() { go func() {
codeChan := make(chan string) // Wait for callback file written by server route
errChan := make(chan error) waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
mux := http.NewServeMux()
server := &http.Server{Addr: ":8085", Handler: mux}
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
return
}
code := r.URL.Query().Get("code")
if code == "" {
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
errChan <- fmt.Errorf("code not found in callback")
return
}
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
codeChan <- code
})
go func() {
if errListen := server.ListenAndServe(); errListen != nil && errListen != http.ErrServerClosed {
log.Fatalf("ListenAndServe(): %v", errListen)
}
}()
log.Info("Waiting for authentication callback...") log.Info("Waiting for authentication callback...")
deadline := time.Now().Add(5 * time.Minute)
var authCode string var authCode string
select { for {
case code := <-codeChan: if time.Now().After(deadline) {
authCode = code
case errCallback := <-errChan:
log.Errorf("Authentication failed: %v", errCallback)
// Attempt graceful shutdown
if errShutdown := server.Shutdown(ctx); errShutdown != nil {
log.Warnf("Failed to shut down server: %v", errShutdown)
}
return
case <-time.After(5 * time.Minute):
log.Error("oauth flow timed out") log.Error("oauth flow timed out")
if errShutdown := server.Shutdown(ctx); errShutdown != nil { oauthStatus[state] = "OAuth flow timed out"
log.Warnf("Failed to shut down server after timeout: %v", errShutdown)
}
return return
} }
if data, errR := os.ReadFile(waitFile); errR == nil {
// Shutdown the callback server after receiving the code var m map[string]string
if errShutdown := server.Shutdown(ctx); errShutdown != nil { _ = json.Unmarshal(data, &m)
log.Warnf("Failed to shut down server: %v", errShutdown) _ = os.Remove(waitFile)
if errStr := m["error"]; errStr != "" {
log.Errorf("Authentication failed: %s", errStr)
oauthStatus[state] = "Authentication failed"
return
}
authCode = m["code"]
if authCode == "" {
log.Errorf("Authentication failed: code not found")
oauthStatus[state] = "Authentication failed: code not found"
return
}
break
}
time.Sleep(500 * time.Millisecond)
} }
// Exchange authorization code for token // Exchange authorization code for token
token, err := conf.Exchange(ctx, authCode) token, err := conf.Exchange(ctx, authCode)
if err != nil { if err != nil {
log.Errorf("Failed to exchange token: %v", err) log.Errorf("Failed to exchange token: %v", err)
oauthStatus[state] = "Failed to exchange token"
return return
} }
@@ -363,6 +398,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errNewRequest != nil { if errNewRequest != nil {
log.Errorf("Could not get user info: %v", errNewRequest) log.Errorf("Could not get user info: %v", errNewRequest)
oauthStatus[state] = "Could not get user info"
return return
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -371,6 +407,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
resp, errDo := httpClient.Do(req) resp, errDo := httpClient.Do(req)
if errDo != nil { if errDo != nil {
log.Errorf("Failed to execute request: %v", errDo) log.Errorf("Failed to execute request: %v", errDo)
oauthStatus[state] = "Failed to execute request"
return return
} }
defer func() { defer func() {
@@ -382,6 +419,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
return return
} }
@@ -390,6 +428,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
log.Infof("Authenticated user email: %s", email) log.Infof("Authenticated user email: %s", email)
} else { } else {
log.Info("Failed to get user email from token") log.Info("Failed to get user email from token")
oauthStatus[state] = "Failed to get user email from token"
} }
// Marshal/unmarshal oauth2.Token to generic map and enrich fields // Marshal/unmarshal oauth2.Token to generic map and enrich fields
@@ -397,6 +436,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
jsonData, _ := json.Marshal(token) jsonData, _ := json.Marshal(token)
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
log.Errorf("Failed to unmarshal token: %v", errUnmarshal) log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
oauthStatus[state] = "Failed to unmarshal token"
return return
} }
@@ -421,6 +461,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
httpClient2, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) httpClient2, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
if errGetClient != nil { if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient) log.Fatalf("failed to get authenticated client: %v", errGetClient)
oauthStatus[state] = "Failed to get authenticated client"
return return
} }
log.Info("Authentication successful.") log.Info("Authentication successful.")
@@ -432,9 +473,11 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
if err = cliClient.SetupUser(ctx, ts.Email, projectID); err != nil { if err = cliClient.SetupUser(ctx, ts.Email, projectID); err != nil {
if err.Error() == "failed to start user onboarding, need define a project id" { if err.Error() == "failed to start user onboarding, need define a project id" {
log.Error("Failed to start user onboarding: A project ID is required.") log.Error("Failed to start user onboarding: A project ID is required.")
oauthStatus[state] = "Failed to start user onboarding: A project ID is required"
project, errGetProjectList := cliClient.GetProjectList(ctx) project, errGetProjectList := cliClient.GetProjectList(ctx)
if errGetProjectList != nil { if errGetProjectList != nil {
log.Fatalf("Failed to get project list: %v", err) log.Fatalf("Failed to get project list: %v", err)
oauthStatus[state] = "Failed to get project list"
} else { } else {
log.Infof("Your account %s needs to specify a project ID.", ts.Email) log.Infof("Your account %s needs to specify a project ID.", ts.Email)
log.Info("========================================================================") log.Info("========================================================================")
@@ -447,6 +490,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
} }
} else { } else {
log.Fatalf("Failed to complete user setup: %v", err) log.Fatalf("Failed to complete user setup: %v", err)
oauthStatus[state] = "Failed to complete user setup"
} }
return return
} }
@@ -458,24 +502,29 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled() isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
if checkErr != nil { if checkErr != nil {
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr) log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr)
oauthStatus[state] = "Failed to check if Cloud AI API is enabled"
return return
} }
cliClient.SetIsChecked(isChecked) cliClient.SetIsChecked(isChecked)
if !isChecked { if !isChecked {
log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.") log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.")
oauthStatus[state] = "Failed to check if Cloud AI API is enabled"
return return
} }
} }
if err = cliClient.SaveTokenToFile(); err != nil { if err = cliClient.SaveTokenToFile(); err != nil {
log.Fatalf("Failed to save token to file: %v", err) log.Fatalf("Failed to save token to file: %v", err)
oauthStatus[state] = "Failed to save token to file"
return return
} }
delete(oauthStatus, state)
log.Info("You can now use Gemini CLI services through this CLI") log.Info("You can now use Gemini CLI services through this CLI")
}() }()
c.JSON(200, gin.H{"status": "ok", "url": authURL}) oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
func (h *Handler) RequestCodexToken(c *gin.Context) { func (h *Handler) RequestCodexToken(c *gin.Context) {
@@ -508,89 +557,125 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
} }
go func() { go func() {
// Initialize OAuth server // Wait for callback file
oauthServer := codex.NewOAuthServer(1455) waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
// Start OAuth callback server var code string
if err = oauthServer.Start(); err != nil { for {
if strings.Contains(err.Error(), "already in use") { if time.Now().After(deadline) {
authErr := codex.NewAuthenticationError(codex.ErrPortInUse, err) authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
log.Error(codex.GetUserFriendlyMessage(authErr)) log.Error(codex.GetUserFriendlyMessage(authErr))
oauthStatus[state] = "Timeout waiting for OAuth callback"
return return
} }
authErr := codex.NewAuthenticationError(codex.ErrServerStartFailed, err) if data, errR := os.ReadFile(waitFile); errR == nil {
log.Fatalf("Failed to start OAuth callback server: %v", authErr) var m map[string]string
return _ = json.Unmarshal(data, &m)
} _ = os.Remove(waitFile)
defer func() { if errStr := m["error"]; errStr != "" {
if err = oauthServer.Stop(ctx); err != nil { oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Warnf("Failed to stop OAuth server: %v", err)
}
}()
log.Info("Waiting for authentication callback...")
// Wait for OAuth callback
result, errWaitForCallback := oauthServer.WaitForCallback(5 * time.Minute)
if errWaitForCallback != nil {
if strings.Contains(errWaitForCallback.Error(), "timeout") {
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, errWaitForCallback)
log.Error(codex.GetUserFriendlyMessage(authErr))
} else {
log.Errorf("Authentication failed: %v", errWaitForCallback)
}
return
}
if result.Error != "" {
oauthErr := codex.NewOAuthError(result.Error, "", http.StatusBadRequest)
log.Error(codex.GetUserFriendlyMessage(oauthErr)) log.Error(codex.GetUserFriendlyMessage(oauthErr))
oauthStatus[state] = "Bad Request"
return return
} }
if m["state"] != state {
// Validate state parameter authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
if result.State != state { oauthStatus[state] = "State code error"
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, result.State))
log.Error(codex.GetUserFriendlyMessage(authErr)) log.Error(codex.GetUserFriendlyMessage(authErr))
return return
} }
code = m["code"]
break
}
time.Sleep(500 * time.Millisecond)
}
log.Debug("Authorization code received, exchanging for tokens...") log.Debug("Authorization code received, exchanging for tokens...")
// Extract client_id from authURL
// Exchange authorization code for tokens clientID := ""
authBundle, errExchangeCodeForTokens := openaiAuth.ExchangeCodeForTokens(ctx, result.Code, pkceCodes) if u2, errP := url.Parse(authURL); errP == nil {
if errExchangeCodeForTokens != nil { clientID = u2.Query().Get("client_id")
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchangeCodeForTokens) }
// Exchange code for tokens with redirect equal to mgmtRedirect
form := url.Values{
"grant_type": {"authorization_code"},
"client_id": {clientID},
"code": {code},
"redirect_uri": {"http://localhost:1455/auth/callback"},
"code_verifier": {pkceCodes.CodeVerifier},
}
httpClient := util.SetProxy(h.cfg, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
oauthStatus[state] = "Failed to exchange authorization code for tokens"
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
log.Debug("This may be due to network issues or invalid authorization code")
return return
} }
defer func() { _ = resp.Body.Close() }()
// Create token storage respBody, _ := io.ReadAll(resp.Body)
tokenStorage := openaiAuth.CreateTokenStorage(authBundle) if resp.StatusCode != http.StatusOK {
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
// Initialize Codex client log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
openaiClient, errWaitForCallback := client.NewCodexClient(h.cfg, tokenStorage)
if errWaitForCallback != nil {
log.Fatalf("Failed to initialize Codex client: %v", errWaitForCallback)
return return
} }
var tokenResp struct {
// Save token storage AccessToken string `json:"access_token"`
if errWaitForCallback = openaiClient.SaveTokenToFile(); errWaitForCallback != nil { RefreshToken string `json:"refresh_token"`
log.Fatalf("Failed to save authentication tokens: %v", errWaitForCallback) IDToken string `json:"id_token"`
ExpiresIn int `json:"expires_in"`
}
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
oauthStatus[state] = "Failed to parse token response"
log.Errorf("failed to parse token response: %v", errU)
return return
} }
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
email := ""
accountID := ""
if claims != nil {
email = claims.GetUserEmail()
accountID = claims.GetAccountID()
}
// Build bundle compatible with existing storage
bundle := &codex.CodexAuthBundle{
TokenData: codex.CodexTokenData{
IDToken: tokenResp.IDToken,
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
AccountID: accountID,
Email: email,
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
}
// Create token storage and persist
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
openaiClient, errInit := client.NewCodexClient(h.cfg, tokenStorage)
if errInit != nil {
oauthStatus[state] = "Failed to initialize Codex client"
log.Fatalf("Failed to initialize Codex client: %v", errInit)
return
}
if errSave := openaiClient.SaveTokenToFile(); errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens"
log.Fatalf("Failed to save authentication tokens: %v", errSave)
return
}
log.Info("Authentication successful!") log.Info("Authentication successful!")
if authBundle.APIKey != "" { if bundle.APIKey != "" {
log.Info("API key obtained and saved") log.Info("API key obtained and saved")
} }
log.Info("You can now use Codex services through this CLI") log.Info("You can now use Codex services through this CLI")
delete(oauthStatus, state)
}() }()
c.JSON(200, gin.H{"status": "ok", "url": authURL}) oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
func (h *Handler) RequestQwenToken(c *gin.Context) { func (h *Handler) RequestQwenToken(c *gin.Context) {
@@ -598,6 +683,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
log.Info("Initializing Qwen authentication...") log.Info("Initializing Qwen authentication...")
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
// Initialize Qwen auth service // Initialize Qwen auth service
qwenAuth := qwen.NewQwenAuth(h.cfg) qwenAuth := qwen.NewQwenAuth(h.cfg)
@@ -613,8 +699,9 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
log.Info("Waiting for authentication...") log.Info("Waiting for authentication...")
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
if errPollForToken != nil { if errPollForToken != nil {
oauthStatus[state] = "Authentication failed"
fmt.Printf("Authentication failed: %v\n", errPollForToken) fmt.Printf("Authentication failed: %v\n", errPollForToken)
os.Exit(1) return
} }
// Create token storage // Create token storage
@@ -628,12 +715,30 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
// Save token storage // Save token storage
if err = qwenClient.SaveTokenToFile(); err != nil { if err = qwenClient.SaveTokenToFile(); err != nil {
log.Fatalf("Failed to save authentication tokens: %v", err) log.Fatalf("Failed to save authentication tokens: %v", err)
oauthStatus[state] = "Failed to save authentication tokens"
return return
} }
log.Info("Authentication successful!") log.Info("Authentication successful!")
log.Info("You can now use Qwen services through this CLI") log.Info("You can now use Qwen services through this CLI")
delete(oauthStatus, state)
}() }()
c.JSON(200, gin.H{"status": "ok", "url": authURL}) oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) GetAuthStatus(c *gin.Context) {
state := c.Query("state")
if err, ok := oauthStatus[state]; ok {
if err != "" {
c.JSON(200, gin.H{"status": "error", "error": err})
} else {
c.JSON(200, gin.H{"status": "wait"})
return
}
} else {
c.JSON(200, gin.H{"status": "ok"})
}
delete(oauthStatus, state)
} }

View File

@@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"os"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -145,6 +146,46 @@ func (s *Server) setupRoutes() {
}) })
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
// OAuth callback endpoints (reuse main server port)
// These endpoints receive provider redirects and persist
// the short-lived code/state for the waiting goroutine.
s.engine.GET("/anthropic/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
// Persist to a temporary file keyed by state
if state != "" {
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
})
s.engine.GET("/codex/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if state != "" {
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
})
s.engine.GET("/google/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if state != "" {
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
})
// Management API routes (delegated to management handlers) // Management API routes (delegated to management handlers)
// New logic: if remote-management-key is empty, do not expose any management endpoint (404). // New logic: if remote-management-key is empty, do not expose any management endpoint (404).
if s.cfg.RemoteManagement.SecretKey != "" { if s.cfg.RemoteManagement.SecretKey != "" {
@@ -216,7 +257,7 @@ func (s *Server) setupRoutes() {
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken) mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken) mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
} }
} }
} }