feat (auth): CLI OAuth supports pasting callback URLs to complete login

- Added callback URL resolution and terminal prompt logic
  - Codex/Claude/iFlow/Antigravity/Gemini login supports callback URL or local callback completion
  - Update Gemini login option signature and manager call
  - CLI default prompt function is compatible with null input to continue waiting
This commit is contained in:
Supra4E8C
2025-12-20 18:25:55 +08:00
parent 10f8c795ac
commit 93414f1baa
14 changed files with 302 additions and 33 deletions

View File

@@ -1093,7 +1093,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
// Initialize authenticated HTTP client via GeminiAuth to honor proxy settings // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
gemAuth := geminiAuth.NewGeminiAuth() gemAuth := geminiAuth.NewGeminiAuth()
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{
NoBrowser: true,
})
if errGetClient != nil { if errGetClient != nil {
log.Errorf("failed to get authenticated client: %v", errGetClient) log.Errorf("failed to get authenticated client: %v", errGetClient)
SetOAuthSessionError(state, "Failed to get authenticated client") SetOAuthSessionError(state, "Failed to get authenticated client")

View File

@@ -18,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -46,6 +47,12 @@ var (
type GeminiAuth struct { type GeminiAuth struct {
} }
// WebLoginOptions customizes the interactive OAuth flow.
type WebLoginOptions struct {
NoBrowser bool
Prompt func(string) (string, error)
}
// NewGeminiAuth creates a new instance of GeminiAuth. // NewGeminiAuth creates a new instance of GeminiAuth.
func NewGeminiAuth() *GeminiAuth { func NewGeminiAuth() *GeminiAuth {
return &GeminiAuth{} return &GeminiAuth{}
@@ -59,12 +66,12 @@ func NewGeminiAuth() *GeminiAuth {
// - ctx: The context for the HTTP client // - ctx: The context for the HTTP client
// - ts: The Gemini token storage containing authentication tokens // - ts: The Gemini token storage containing authentication tokens
// - cfg: The configuration containing proxy settings // - cfg: The configuration containing proxy settings
// - noBrowser: Optional parameter to disable browser opening // - opts: Optional parameters to customize browser and prompt behavior
// //
// Returns: // Returns:
// - *http.Client: An HTTP client configured with authentication // - *http.Client: An HTTP client configured with authentication
// - error: An error if the client configuration fails, nil otherwise // - error: An error if the client configuration fails, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) { func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
// Configure proxy settings for the HTTP client if a proxy URL is provided. // Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL) proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil { if err == nil {
@@ -109,7 +116,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
// If no token is found in storage, initiate the web-based OAuth flow. // If no token is found in storage, initiate the web-based OAuth flow.
if ts.Token == nil { if ts.Token == nil {
fmt.Printf("Could not load token from file, starting OAuth flow.\n") fmt.Printf("Could not load token from file, starting OAuth flow.\n")
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...) token, err = g.getTokenFromWeb(ctx, conf, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get token from web: %w", err) return nil, fmt.Errorf("failed to get token from web: %w", err)
} }
@@ -205,12 +212,12 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// Parameters: // Parameters:
// - ctx: The context for the HTTP client // - ctx: The context for the HTTP client
// - config: The OAuth2 configuration // - config: The OAuth2 configuration
// - noBrowser: Optional parameter to disable browser opening // - opts: Optional parameters to customize browser and prompt behavior
// //
// Returns: // Returns:
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise // - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) { func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function. // Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string) codeChan := make(chan string)
errChan := make(chan error) errChan := make(chan error)
@@ -250,7 +257,12 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Open the authorization URL in the user's browser. // Open the authorization URL in the user's browser.
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
if len(noBrowser) == 1 && !noBrowser[0] { noBrowser := false
if opts != nil {
noBrowser = opts.NoBrowser
}
if !noBrowser {
fmt.Println("Opening browser for authentication...") fmt.Println("Opening browser for authentication...")
// Check if browser is available // Check if browser is available
@@ -281,11 +293,47 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Wait for the authorization code or an error. // Wait for the authorization code or an error.
var authCode string var authCode string
manualCodeChan := make(chan string, 1)
manualErrChan := make(chan error, 1)
if opts != nil && opts.Prompt != nil {
go func() {
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil {
manualErrChan <- err
return
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
manualErrChan <- err
return
}
if parsed == nil {
return
}
if parsed.Error != "" {
manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error)
return
}
if parsed.Code == "" {
manualErrChan <- fmt.Errorf("code not found in callback")
return
}
manualCodeChan <- parsed.Code
}()
} else {
manualCodeChan = nil
manualErrChan = nil
}
select { select {
case code := <-codeChan: case code := <-codeChan:
authCode = code authCode = code
case err := <-errChan: case err := <-errChan:
return nil, err return nil, err
case code := <-manualCodeChan:
authCode = code
case err := <-manualErrChan:
return nil, err
case <-time.After(5 * time.Minute): // Timeout case <-time.After(5 * time.Minute): // Timeout
return nil, fmt.Errorf("oauth flow timed out") return nil, fmt.Errorf("oauth flow timed out")
} }

View File

@@ -24,12 +24,17 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)

View File

@@ -15,11 +15,16 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts) record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)

View File

@@ -20,13 +20,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
promptFn := options.Prompt promptFn := options.Prompt
if promptFn == nil { if promptFn == nil {
promptFn = func(prompt string) (string, error) { promptFn = defaultProjectPrompt()
fmt.Println()
fmt.Println(prompt)
var value string
_, err := fmt.Scanln(&value)
return value, err
}
} }
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{

View File

@@ -55,11 +55,17 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
ctx := context.Background() ctx := context.Background()
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
options.Prompt = promptFn
}
loginOpts := &sdkAuth.LoginOptions{ loginOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
ProjectID: strings.TrimSpace(projectID), ProjectID: strings.TrimSpace(projectID),
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
authenticator := sdkAuth.NewGeminiAuthenticator() authenticator := sdkAuth.NewGeminiAuthenticator()
@@ -76,7 +82,10 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
} }
geminiAuth := gemini.NewGeminiAuth() geminiAuth := gemini.NewGeminiAuth()
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser) httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
NoBrowser: options.NoBrowser,
Prompt: promptFn,
})
if errClient != nil { if errClient != nil {
log.Errorf("Gemini authentication failed: %v", errClient) log.Errorf("Gemini authentication failed: %v", errClient)
return return
@@ -90,11 +99,6 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
return return
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn) selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil { if errSelection != nil {

View File

@@ -35,12 +35,17 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)

View File

@@ -4,6 +4,8 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/url"
"strings"
) )
// GenerateRandomState generates a cryptographically secure random state parameter // GenerateRandomState generates a cryptographically secure random state parameter
@@ -19,3 +21,81 @@ func GenerateRandomState() (string, error) {
} }
return hex.EncodeToString(bytes), nil return hex.EncodeToString(bytes), nil
} }
// OAuthCallback captures the parsed OAuth callback parameters.
type OAuthCallback struct {
Code string
State string
Error string
ErrorDescription string
}
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
// It returns nil when the input is empty.
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
trimmed := strings.TrimSpace(input)
if trimmed == "" {
return nil, nil
}
candidate := trimmed
if !strings.Contains(candidate, "://") {
if strings.HasPrefix(candidate, "?") {
candidate = "http://localhost" + candidate
} else if strings.Contains(candidate, "=") {
candidate = "http://localhost/?" + candidate
} else {
return nil, fmt.Errorf("invalid callback URL")
}
}
parsedURL, err := url.Parse(candidate)
if err != nil {
return nil, err
}
query := parsedURL.Query()
code := strings.TrimSpace(query.Get("code"))
state := strings.TrimSpace(query.Get("state"))
errCode := strings.TrimSpace(query.Get("error"))
errDesc := strings.TrimSpace(query.Get("error_description"))
if parsedURL.Fragment != "" {
if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil {
if code == "" {
code = strings.TrimSpace(fragQuery.Get("code"))
}
if state == "" {
state = strings.TrimSpace(fragQuery.Get("state"))
}
if errCode == "" {
errCode = strings.TrimSpace(fragQuery.Get("error"))
}
if errDesc == "" {
errDesc = strings.TrimSpace(fragQuery.Get("error_description"))
}
}
}
if code != "" && state == "" && strings.Contains(code, "#") {
parts := strings.SplitN(code, "#", 2)
code = parts[0]
state = parts[1]
}
if errCode == "" && errDesc != "" {
errCode = errDesc
errDesc = ""
}
if code == "" && errCode == "" {
return nil, fmt.Errorf("callback URL missing code")
}
return &OAuthCallback{
Code: code,
State: state,
Error: errCode,
ErrorDescription: errDesc,
}, nil
}

View File

@@ -99,9 +99,18 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
fmt.Println("Waiting for antigravity authentication callback...") fmt.Println("Waiting for antigravity authentication callback...")
var cbRes callbackResult var cbRes callbackResult
manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "antigravity")
select { select {
case res := <-cbChan: case res := <-cbChan:
cbRes = res cbRes = res
case manual := <-manualCh:
cbRes = callbackResult{
Code: manual.Code,
State: manual.State,
Error: manual.Error,
}
case err = <-manualErrCh:
return nil, err
case <-time.After(5 * time.Minute): case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("antigravity: authentication timed out") return nil, fmt.Errorf("antigravity: authentication timed out")
} }

View File

@@ -98,16 +98,41 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
fmt.Println("Waiting for Claude authentication callback...") fmt.Println("Waiting for Claude authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *claude.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Claude")
manualDescription := ""
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *claude.OAuthResult
select {
case result = <-callbackCh:
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") { if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
} }
return nil, err return nil, err
case manual := <-manualCh:
manualDescription = manual.ErrorDescription
result = &claude.OAuthResult{
Code: manual.Code,
State: manual.State,
Error: manual.Error,
}
case err = <-manualErrCh:
return nil, err
} }
if result.Error != "" { if result.Error != "" {
return nil, claude.NewOAuthError(result.Error, "", http.StatusBadRequest) return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
} }
if result.State != state { if result.State != state {

View File

@@ -97,16 +97,41 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
fmt.Println("Waiting for Codex authentication callback...") fmt.Println("Waiting for Codex authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *codex.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Codex")
manualDescription := ""
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *codex.OAuthResult
select {
case result = <-callbackCh:
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") { if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
} }
return nil, err return nil, err
case manual := <-manualCh:
manualDescription = manual.ErrorDescription
result = &codex.OAuthResult{
Code: manual.Code,
State: manual.State,
Error: manual.Error,
}
case err = <-manualErrCh:
return nil, err
} }
if result.Error != "" { if result.Error != "" {
return nil, codex.NewOAuthError(result.Error, "", http.StatusBadRequest) return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
} }
if result.State != state { if result.State != state {

View File

@@ -44,7 +44,10 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
} }
geminiAuth := gemini.NewGeminiAuth() geminiAuth := gemini.NewGeminiAuth()
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, opts.NoBrowser) _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
NoBrowser: opts.NoBrowser,
Prompt: opts.Prompt,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("gemini authentication failed: %w", err) return nil, fmt.Errorf("gemini authentication failed: %w", err)
} }

View File

@@ -84,9 +84,32 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
fmt.Println("Waiting for iFlow authentication callback...") fmt.Println("Waiting for iFlow authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *iflow.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "iFlow")
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *iflow.OAuthResult
select {
case result = <-callbackCh:
case err = <-callbackErrCh:
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
case manual := <-manualCh:
result = &iflow.OAuthResult{
Code: manual.Code,
State: manual.State,
Error: manual.Error,
}
case err = <-manualErrCh:
return nil, err
} }
if result.Error != "" { if result.Error != "" {
return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error)

View File

@@ -0,0 +1,41 @@
package auth
import (
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
func promptForOAuthCallback(prompt func(string) (string, error), provider string) (<-chan *misc.OAuthCallback, <-chan error) {
if prompt == nil {
return nil, nil
}
resultCh := make(chan *misc.OAuthCallback, 1)
errCh := make(chan error, 1)
go func() {
label := provider
if label == "" {
label = "OAuth"
}
input, err := prompt(fmt.Sprintf("Paste the %s callback URL (or press Enter to keep waiting): ", label))
if err != nil {
errCh <- err
return
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
errCh <- err
return
}
if parsed == nil {
return
}
resultCh <- parsed
}()
return resultCh, errCh
}