From 93414f1baa7d3734defb5b6937576b1474916a7e Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sat, 20 Dec 2025 18:25:55 +0800 Subject: [PATCH 1/6] feat (auth): CLI OAuth supports pasting callback URLs to complete login - Added callback URL resolution and terminal prompt logic - Codex/Claude/iFlow/Antigravity/Gemini login supports callback URL or local callback completion - Update Gemini login option signature and manager call - CLI default prompt function is compatible with null input to continue waiting --- .../api/handlers/management/auth_files.go | 4 +- internal/auth/gemini/gemini_auth.go | 60 ++++++++++++-- internal/cmd/anthropic_login.go | 7 +- internal/cmd/antigravity_login.go | 7 +- internal/cmd/iflow_login.go | 8 +- internal/cmd/login.go | 18 +++-- internal/cmd/openai_login.go | 7 +- internal/misc/oauth.go | 80 +++++++++++++++++++ sdk/auth/antigravity.go | 9 +++ sdk/auth/claude.go | 31 ++++++- sdk/auth/codex.go | 31 ++++++- sdk/auth/gemini.go | 5 +- sdk/auth/iflow.go | 27 ++++++- sdk/auth/oauth_callback.go | 41 ++++++++++ 14 files changed, 302 insertions(+), 33 deletions(-) create mode 100644 sdk/auth/oauth_callback.go diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index bf5a5b9c..4f42bd7a 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -1093,7 +1093,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings gemAuth := geminiAuth.NewGeminiAuth() - gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) + gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{ + NoBrowser: true, + }) if errGetClient != nil { log.Errorf("failed to get authenticated client: %v", errGetClient) SetOAuthSessionError(state, "Failed to get authenticated client") diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index f173c95f..dc9b1034 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -18,6 +18,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -46,6 +47,12 @@ var ( type GeminiAuth struct { } +// WebLoginOptions customizes the interactive OAuth flow. +type WebLoginOptions struct { + NoBrowser bool + Prompt func(string) (string, error) +} + // NewGeminiAuth creates a new instance of GeminiAuth. func NewGeminiAuth() *GeminiAuth { return &GeminiAuth{} @@ -59,12 +66,12 @@ func NewGeminiAuth() *GeminiAuth { // - ctx: The context for the HTTP client // - ts: The Gemini token storage containing authentication tokens // - cfg: The configuration containing proxy settings -// - noBrowser: Optional parameter to disable browser opening +// - opts: Optional parameters to customize browser and prompt behavior // // Returns: // - *http.Client: An HTTP client configured with authentication // - error: An error if the client configuration fails, nil otherwise -func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) { +func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) { // Configure proxy settings for the HTTP client if a proxy URL is provided. proxyURL, err := url.Parse(cfg.ProxyURL) if err == nil { @@ -109,7 +116,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken // If no token is found in storage, initiate the web-based OAuth flow. if ts.Token == nil { fmt.Printf("Could not load token from file, starting OAuth flow.\n") - token, err = g.getTokenFromWeb(ctx, conf, noBrowser...) + token, err = g.getTokenFromWeb(ctx, conf, opts) if err != nil { return nil, fmt.Errorf("failed to get token from web: %w", err) } @@ -205,12 +212,12 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf // Parameters: // - ctx: The context for the HTTP client // - config: The OAuth2 configuration -// - noBrowser: Optional parameter to disable browser opening +// - opts: Optional parameters to customize browser and prompt behavior // // Returns: // - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - error: An error if the token acquisition fails, nil otherwise -func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) { +func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. codeChan := make(chan string) errChan := make(chan error) @@ -250,7 +257,12 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Open the authorization URL in the user's browser. authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - if len(noBrowser) == 1 && !noBrowser[0] { + noBrowser := false + if opts != nil { + noBrowser = opts.NoBrowser + } + + if !noBrowser { fmt.Println("Opening browser for authentication...") // Check if browser is available @@ -281,11 +293,47 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Wait for the authorization code or an error. var authCode string + manualCodeChan := make(chan string, 1) + manualErrChan := make(chan error, 1) + if opts != nil && opts.Prompt != nil { + go func() { + input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") + if err != nil { + manualErrChan <- err + return + } + parsed, err := misc.ParseOAuthCallback(input) + if err != nil { + manualErrChan <- err + return + } + if parsed == nil { + return + } + if parsed.Error != "" { + manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error) + return + } + if parsed.Code == "" { + manualErrChan <- fmt.Errorf("code not found in callback") + return + } + manualCodeChan <- parsed.Code + }() + } else { + manualCodeChan = nil + manualErrChan = nil + } + select { case code := <-codeChan: authCode = code case err := <-errChan: return nil, err + case code := <-manualCodeChan: + authCode = code + case err := <-manualErrChan: + return nil, err case <-time.After(5 * time.Minute): // Timeout return nil, fmt.Errorf("oauth flow timed out") } diff --git a/internal/cmd/anthropic_login.go b/internal/cmd/anthropic_login.go index 8e9d01cd..6efd87a8 100644 --- a/internal/cmd/anthropic_login.go +++ b/internal/cmd/anthropic_login.go @@ -24,12 +24,17 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) { options = &LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } _, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts) diff --git a/internal/cmd/antigravity_login.go b/internal/cmd/antigravity_login.go index b2602638..1cd42899 100644 --- a/internal/cmd/antigravity_login.go +++ b/internal/cmd/antigravity_login.go @@ -15,11 +15,16 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) { options = &LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts) diff --git a/internal/cmd/iflow_login.go b/internal/cmd/iflow_login.go index ba43470b..cf00b63c 100644 --- a/internal/cmd/iflow_login.go +++ b/internal/cmd/iflow_login.go @@ -20,13 +20,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) { promptFn := options.Prompt if promptFn == nil { - promptFn = func(prompt string) (string, error) { - fmt.Println() - fmt.Println(prompt) - var value string - _, err := fmt.Scanln(&value) - return value, err - } + promptFn = defaultProjectPrompt() } authOpts := &sdkAuth.LoginOptions{ diff --git a/internal/cmd/login.go b/internal/cmd/login.go index de01cec5..0f079b4b 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -55,11 +55,17 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { ctx := context.Background() + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + options.Prompt = promptFn + } + loginOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, ProjectID: strings.TrimSpace(projectID), Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } authenticator := sdkAuth.NewGeminiAuthenticator() @@ -76,7 +82,10 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { } geminiAuth := gemini.NewGeminiAuth() - httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser) + httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{ + NoBrowser: options.NoBrowser, + Prompt: promptFn, + }) if errClient != nil { log.Errorf("Gemini authentication failed: %v", errClient) return @@ -90,11 +99,6 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { return } - promptFn := options.Prompt - if promptFn == nil { - promptFn = defaultProjectPrompt() - } - selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) if errSelection != nil { diff --git a/internal/cmd/openai_login.go b/internal/cmd/openai_login.go index e402e476..d981f6ae 100644 --- a/internal/cmd/openai_login.go +++ b/internal/cmd/openai_login.go @@ -35,12 +35,17 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) { options = &LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) diff --git a/internal/misc/oauth.go b/internal/misc/oauth.go index acf034b2..d5cae403 100644 --- a/internal/misc/oauth.go +++ b/internal/misc/oauth.go @@ -4,6 +4,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/url" + "strings" ) // GenerateRandomState generates a cryptographically secure random state parameter @@ -19,3 +21,81 @@ func GenerateRandomState() (string, error) { } return hex.EncodeToString(bytes), nil } + +// OAuthCallback captures the parsed OAuth callback parameters. +type OAuthCallback struct { + Code string + State string + Error string + ErrorDescription string +} + +// ParseOAuthCallback extracts OAuth parameters from a callback URL. +// It returns nil when the input is empty. +func ParseOAuthCallback(input string) (*OAuthCallback, error) { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return nil, nil + } + + candidate := trimmed + if !strings.Contains(candidate, "://") { + if strings.HasPrefix(candidate, "?") { + candidate = "http://localhost" + candidate + } else if strings.Contains(candidate, "=") { + candidate = "http://localhost/?" + candidate + } else { + return nil, fmt.Errorf("invalid callback URL") + } + } + + parsedURL, err := url.Parse(candidate) + if err != nil { + return nil, err + } + + query := parsedURL.Query() + code := strings.TrimSpace(query.Get("code")) + state := strings.TrimSpace(query.Get("state")) + errCode := strings.TrimSpace(query.Get("error")) + errDesc := strings.TrimSpace(query.Get("error_description")) + + if parsedURL.Fragment != "" { + if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil { + if code == "" { + code = strings.TrimSpace(fragQuery.Get("code")) + } + if state == "" { + state = strings.TrimSpace(fragQuery.Get("state")) + } + if errCode == "" { + errCode = strings.TrimSpace(fragQuery.Get("error")) + } + if errDesc == "" { + errDesc = strings.TrimSpace(fragQuery.Get("error_description")) + } + } + } + + if code != "" && state == "" && strings.Contains(code, "#") { + parts := strings.SplitN(code, "#", 2) + code = parts[0] + state = parts[1] + } + + if errCode == "" && errDesc != "" { + errCode = errDesc + errDesc = "" + } + + if code == "" && errCode == "" { + return nil, fmt.Errorf("callback URL missing code") + } + + return &OAuthCallback{ + Code: code, + State: state, + Error: errCode, + ErrorDescription: errDesc, + }, nil +} diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index b3d7f6c5..832bd88e 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -99,9 +99,18 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o fmt.Println("Waiting for antigravity authentication callback...") var cbRes callbackResult + manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "antigravity") select { case res := <-cbChan: cbRes = res + case manual := <-manualCh: + cbRes = callbackResult{ + Code: manual.Code, + State: manual.State, + Error: manual.Error, + } + case err = <-manualErrCh: + return nil, err case <-time.After(5 * time.Minute): return nil, fmt.Errorf("antigravity: authentication timed out") } diff --git a/sdk/auth/claude.go b/sdk/auth/claude.go index da9e5065..d88cdf29 100644 --- a/sdk/auth/claude.go +++ b/sdk/auth/claude.go @@ -98,16 +98,41 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt fmt.Println("Waiting for Claude authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { + callbackCh := make(chan *claude.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Claude") + manualDescription := "" + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *claude.OAuthResult + select { + case result = <-callbackCh: + case err = <-callbackErrCh: if strings.Contains(err.Error(), "timeout") { return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) } return nil, err + case manual := <-manualCh: + manualDescription = manual.ErrorDescription + result = &claude.OAuthResult{ + Code: manual.Code, + State: manual.State, + Error: manual.Error, + } + case err = <-manualErrCh: + return nil, err } if result.Error != "" { - return nil, claude.NewOAuthError(result.Error, "", http.StatusBadRequest) + return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest) } if result.State != state { diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index 138c2141..b0a6b4a4 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -97,16 +97,41 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts fmt.Println("Waiting for Codex authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { + callbackCh := make(chan *codex.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Codex") + manualDescription := "" + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *codex.OAuthResult + select { + case result = <-callbackCh: + case err = <-callbackErrCh: if strings.Contains(err.Error(), "timeout") { return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) } return nil, err + case manual := <-manualCh: + manualDescription = manual.ErrorDescription + result = &codex.OAuthResult{ + Code: manual.Code, + State: manual.State, + Error: manual.Error, + } + case err = <-manualErrCh: + return nil, err } if result.Error != "" { - return nil, codex.NewOAuthError(result.Error, "", http.StatusBadRequest) + return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest) } if result.State != state { diff --git a/sdk/auth/gemini.go b/sdk/auth/gemini.go index 7110101f..75ef4579 100644 --- a/sdk/auth/gemini.go +++ b/sdk/auth/gemini.go @@ -44,7 +44,10 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt } geminiAuth := gemini.NewGeminiAuth() - _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, opts.NoBrowser) + _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{ + NoBrowser: opts.NoBrowser, + Prompt: opts.Prompt, + }) if err != nil { return nil, fmt.Errorf("gemini authentication failed: %w", err) } diff --git a/sdk/auth/iflow.go b/sdk/auth/iflow.go index ee96bdaa..d7621a99 100644 --- a/sdk/auth/iflow.go +++ b/sdk/auth/iflow.go @@ -84,9 +84,32 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts fmt.Println("Waiting for iFlow authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { + callbackCh := make(chan *iflow.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "iFlow") + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *iflow.OAuthResult + select { + case result = <-callbackCh: + case err = <-callbackErrCh: return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + case manual := <-manualCh: + result = &iflow.OAuthResult{ + Code: manual.Code, + State: manual.State, + Error: manual.Error, + } + case err = <-manualErrCh: + return nil, err } if result.Error != "" { return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) diff --git a/sdk/auth/oauth_callback.go b/sdk/auth/oauth_callback.go new file mode 100644 index 00000000..3f0ac925 --- /dev/null +++ b/sdk/auth/oauth_callback.go @@ -0,0 +1,41 @@ +package auth + +import ( + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" +) + +func promptForOAuthCallback(prompt func(string) (string, error), provider string) (<-chan *misc.OAuthCallback, <-chan error) { + if prompt == nil { + return nil, nil + } + + resultCh := make(chan *misc.OAuthCallback, 1) + errCh := make(chan error, 1) + + go func() { + label := provider + if label == "" { + label = "OAuth" + } + input, err := prompt(fmt.Sprintf("Paste the %s callback URL (or press Enter to keep waiting): ", label)) + if err != nil { + errCh <- err + return + } + + parsed, err := misc.ParseOAuthCallback(input) + if err != nil { + errCh <- err + return + } + if parsed == nil { + return + } + + resultCh <- parsed + }() + + return resultCh, errCh +} From 9855615f1ee7c936c9d349c3be1babaa2559ab69 Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sat, 20 Dec 2025 19:03:38 +0800 Subject: [PATCH 2/6] fix(gemini): avoid stale manual oauth prompt and accept schemeless callbacks --- internal/auth/gemini/gemini_auth.go | 90 ++++++++++++++++++----------- internal/misc/oauth.go | 2 + 2 files changed, 57 insertions(+), 35 deletions(-) diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index dc9b1034..7b18e738 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -219,8 +219,8 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf // - error: An error if the token acquisition fails, nil otherwise func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. - codeChan := make(chan string) - errChan := make(chan error) + codeChan := make(chan string, 1) + errChan := make(chan error, 1) // Create a new HTTP server with its own multiplexer. mux := http.NewServeMux() @@ -230,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { if err := r.URL.Query().Get("error"); err != "" { _, _ = fmt.Fprintf(w, "Authentication failed: %s", err) - errChan <- fmt.Errorf("authentication failed via callback: %s", err) + select { + case errChan <- fmt.Errorf("authentication failed via callback: %s", err): + default: + } return } code := r.URL.Query().Get("code") if code == "" { _, _ = fmt.Fprint(w, "Authentication failed: code not found.") - errChan <- fmt.Errorf("code not found in callback") + select { + case errChan <- fmt.Errorf("code not found in callback"): + default: + } return } _, _ = fmt.Fprint(w, "

Authentication successful!

You can close this window.

") - codeChan <- code + select { + case codeChan <- code: + default: + } }) // Start the server in a goroutine. @@ -293,49 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Wait for the authorization code or an error. var authCode string - manualCodeChan := make(chan string, 1) - manualErrChan := make(chan error, 1) + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time if opts != nil && opts.Prompt != nil { - go func() { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case code := <-codeChan: + authCode = code + break waitForCallback + case err := <-errChan: + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case code := <-codeChan: + authCode = code + break waitForCallback + case err := <-errChan: + return nil, err + default: + } input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") if err != nil { - manualErrChan <- err - return + return nil, err } parsed, err := misc.ParseOAuthCallback(input) if err != nil { - manualErrChan <- err - return + return nil, err } if parsed == nil { - return + continue } if parsed.Error != "" { - manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error) - return + return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error) } if parsed.Code == "" { - manualErrChan <- fmt.Errorf("code not found in callback") - return + return nil, fmt.Errorf("code not found in callback") } - manualCodeChan <- parsed.Code - }() - } else { - manualCodeChan = nil - manualErrChan = nil - } - - select { - case code := <-codeChan: - authCode = code - case err := <-errChan: - return nil, err - case code := <-manualCodeChan: - authCode = code - case err := <-manualErrChan: - return nil, err - case <-time.After(5 * time.Minute): // Timeout - return nil, fmt.Errorf("oauth flow timed out") + authCode = parsed.Code + break waitForCallback + case <-timeoutTimer.C: + return nil, fmt.Errorf("oauth flow timed out") + } } // Shutdown the server. diff --git a/internal/misc/oauth.go b/internal/misc/oauth.go index d5cae403..c14f39d2 100644 --- a/internal/misc/oauth.go +++ b/internal/misc/oauth.go @@ -42,6 +42,8 @@ func ParseOAuthCallback(input string) (*OAuthCallback, error) { if !strings.Contains(candidate, "://") { if strings.HasPrefix(candidate, "?") { candidate = "http://localhost" + candidate + } else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") { + candidate = "http://" + candidate } else if strings.Contains(candidate, "=") { candidate = "http://localhost/?" + candidate } else { From 24970baa576dc76754233a6eae1582e04926435c Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 02:14:28 +0800 Subject: [PATCH 3/6] management: allow prefix updates in provider PATCH handlers --- .../api/handlers/management/config_lists.go | 369 ++++++++++-------- 1 file changed, 212 insertions(+), 157 deletions(-) diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index a0d0b169..7e42b64b 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -145,71 +145,74 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchGeminiKey(c *gin.Context) { + type geminiKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.GeminiKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *geminiKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - value.APIKey = strings.TrimSpace(value.APIKey) - value.BaseURL = strings.TrimSpace(value.BaseURL) - value.ProxyURL = strings.TrimSpace(value.ProxyURL) - value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) - if value.APIKey == "" { - // Treat empty API key as delete. - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { - h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...) - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } - if body.Match != nil { - match := strings.TrimSpace(*body.Match) - if match != "" { - out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) - removed := false - for i := range h.cfg.GeminiKey { - if !removed && h.cfg.GeminiKey[i].APIKey == match { - removed = true - continue - } - out = append(out, h.cfg.GeminiKey[i]) - } - if removed { - h.cfg.GeminiKey = out - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) + if match != "" { + for i := range h.cfg.GeminiKey { + if h.cfg.GeminiKey[i].APIKey == match { + targetIndex = i + break } } } + } + if targetIndex == -1 { c.JSON(404, gin.H{"error": "item not found"}) return } - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { - h.cfg.GeminiKey[*body.Index] = value - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } - if body.Match != nil { - match := strings.TrimSpace(*body.Match) - for i := range h.cfg.GeminiKey { - if h.cfg.GeminiKey[i].APIKey == match { - h.cfg.GeminiKey[i] = value - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } + entry := h.cfg.GeminiKey[targetIndex] + if body.Value.APIKey != nil { + trimmed := strings.TrimSpace(*body.Value.APIKey) + if trimmed == "" { + h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...) + h.cfg.SanitizeGeminiKeys() + h.persist(c) + return } + entry.APIKey = trimmed } - c.JSON(404, gin.H{"error": "item not found"}) + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + h.cfg.GeminiKey[targetIndex] = entry + h.cfg.SanitizeGeminiKeys() + h.persist(c) } + func (h *Handler) DeleteGeminiKey(c *gin.Context) { if val := strings.TrimSpace(c.Query("api-key")); val != "" { out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) @@ -268,35 +271,70 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchClaudeKey(c *gin.Context) { + type claudeKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Models *[]config.ClaudeModel `json:"models"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.ClaudeKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *claudeKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - normalizeClaudeKey(&value) + targetIndex := -1 if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) { - h.cfg.ClaudeKey[*body.Index] = value - h.cfg.SanitizeClaudeKeys() - h.persist(c) - return + targetIndex = *body.Index } - if body.Match != nil { + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) for i := range h.cfg.ClaudeKey { - if h.cfg.ClaudeKey[i].APIKey == *body.Match { - h.cfg.ClaudeKey[i] = value - h.cfg.SanitizeClaudeKeys() - h.persist(c) - return + if h.cfg.ClaudeKey[i].APIKey == match { + targetIndex = i + break } } } - c.JSON(404, gin.H{"error": "item not found"}) + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.ClaudeKey[targetIndex] + if body.Value.APIKey != nil { + entry.APIKey = strings.TrimSpace(*body.Value.APIKey) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Models != nil { + entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + normalizeClaudeKey(&entry) + h.cfg.ClaudeKey[targetIndex] = entry + h.cfg.SanitizeClaudeKeys() + h.persist(c) } + func (h *Handler) DeleteClaudeKey(c *gin.Context) { if val := c.Query("api-key"); val != "" { out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) @@ -356,62 +394,73 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) { h.persist(c) } func (h *Handler) PatchOpenAICompat(c *gin.Context) { + type openAICompatPatch struct { + Name *string `json:"name"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"` + Models *[]config.OpenAICompatibilityModel `json:"models"` + Headers *map[string]string `json:"headers"` + } var body struct { - Name *string `json:"name"` - Index *int `json:"index"` - Value *config.OpenAICompatibility `json:"value"` + Name *string `json:"name"` + Index *int `json:"index"` + Value *openAICompatPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - normalizeOpenAICompatibilityEntry(body.Value) - // If base-url becomes empty, delete the provider instead of updating - if strings.TrimSpace(body.Value.BaseURL) == "" { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { - h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...) + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Name != nil { + match := strings.TrimSpace(*body.Name) + for i := range h.cfg.OpenAICompatibility { + if h.cfg.OpenAICompatibility[i].Name == match { + targetIndex = i + break + } + } + } + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.OpenAICompatibility[targetIndex] + if body.Value.Name != nil { + entry.Name = strings.TrimSpace(*body.Value.Name) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...) h.cfg.SanitizeOpenAICompatibility() h.persist(c) return } - if body.Name != nil { - out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) - removed := false - for i := range h.cfg.OpenAICompatibility { - if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name { - removed = true - continue - } - out = append(out, h.cfg.OpenAICompatibility[i]) - } - if removed { - h.cfg.OpenAICompatibility = out - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return - } - } - c.JSON(404, gin.H{"error": "item not found"}) - return + entry.BaseURL = trimmed } - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { - h.cfg.OpenAICompatibility[*body.Index] = *body.Value - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return + if body.Value.APIKeyEntries != nil { + entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...) } - if body.Name != nil { - for i := range h.cfg.OpenAICompatibility { - if h.cfg.OpenAICompatibility[i].Name == *body.Name { - h.cfg.OpenAICompatibility[i] = *body.Value - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return - } - } + if body.Value.Models != nil { + entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...) } - c.JSON(404, gin.H{"error": "item not found"}) + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + normalizeOpenAICompatibilityEntry(&entry) + h.cfg.OpenAICompatibility[targetIndex] = entry + h.cfg.SanitizeOpenAICompatibility() + h.persist(c) } + func (h *Handler) DeleteOpenAICompat(c *gin.Context) { if name := c.Query("name"); name != "" { out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) @@ -563,66 +612,72 @@ func (h *Handler) PutCodexKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchCodexKey(c *gin.Context) { + type codexKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.CodexKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *codexKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - value.APIKey = strings.TrimSpace(value.APIKey) - value.BaseURL = strings.TrimSpace(value.BaseURL) - value.ProxyURL = strings.TrimSpace(value.ProxyURL) - value.Headers = config.NormalizeHeaders(value.Headers) - value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) - // If base-url becomes empty, delete instead of update - if value.BaseURL == "" { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { - h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...) - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - if body.Match != nil { - out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) - removed := false - for i := range h.cfg.CodexKey { - if !removed && h.cfg.CodexKey[i].APIKey == *body.Match { - removed = true - continue - } - out = append(out, h.cfg.CodexKey[i]) - } - if removed { - h.cfg.CodexKey = out - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - } - } else { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { - h.cfg.CodexKey[*body.Index] = value - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - if body.Match != nil { - for i := range h.cfg.CodexKey { - if h.cfg.CodexKey[i].APIKey == *body.Match { - h.cfg.CodexKey[i] = value - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) + for i := range h.cfg.CodexKey { + if h.cfg.CodexKey[i].APIKey == match { + targetIndex = i + break } } } - c.JSON(404, gin.H{"error": "item not found"}) + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.CodexKey[targetIndex] + if body.Value.APIKey != nil { + entry.APIKey = strings.TrimSpace(*body.Value.APIKey) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...) + h.cfg.SanitizeCodexKeys() + h.persist(c) + return + } + entry.BaseURL = trimmed + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + h.cfg.CodexKey[targetIndex] = entry + h.cfg.SanitizeCodexKeys() + h.persist(c) } + func (h *Handler) DeleteCodexKey(c *gin.Context) { if val := c.Query("api-key"); val != "" { out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) From cd0c94f48acc9b66f51e4c1710923c37c9d2255c Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 07:06:28 +0800 Subject: [PATCH 4/6] fix(sdk/auth): prevent OAuth manual prompt goroutine leak,Use timer-based manual prompt per provider and remove oauth_callback helper. --- sdk/auth/antigravity.go | 60 ++++++++++++++++++++++++++-------- sdk/auth/claude.go | 67 +++++++++++++++++++++++++++++--------- sdk/auth/codex.go | 67 +++++++++++++++++++++++++++++--------- sdk/auth/iflow.go | 56 ++++++++++++++++++++++++------- sdk/auth/oauth_callback.go | 41 ----------------------- 5 files changed, 193 insertions(+), 98 deletions(-) delete mode 100644 sdk/auth/oauth_callback.go diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index 832bd88e..ae22f772 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -99,20 +99,54 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o fmt.Println("Waiting for antigravity authentication callback...") var cbRes callbackResult - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "antigravity") - select { - case res := <-cbChan: - cbRes = res - case manual := <-manualCh: - cbRes = callbackResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case res := <-cbChan: + cbRes = res + break waitForCallback + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case res := <-cbChan: + cbRes = res + break waitForCallback + default: + } + input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + cbRes = callbackResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback + case <-timeoutTimer.C: + return nil, fmt.Errorf("antigravity: authentication timed out") } - case err = <-manualErrCh: - return nil, err - case <-time.After(5 * time.Minute): - return nil, fmt.Errorf("antigravity: authentication timed out") } if cbRes.Error != "" { diff --git a/sdk/auth/claude.go b/sdk/auth/claude.go index d88cdf29..c43b78cd 100644 --- a/sdk/auth/claude.go +++ b/sdk/auth/claude.go @@ -100,7 +100,6 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt callbackCh := make(chan *claude.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Claude") manualDescription := "" go func() { @@ -113,22 +112,58 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt }() var result *claude.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { - return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + } + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + } + return nil, err + default: + } + input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + manualDescription = parsed.ErrorDescription + result = &claude.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - return nil, err - case manual := <-manualCh: - manualDescription = manual.ErrorDescription - result = &claude.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, - } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index b0a6b4a4..99992525 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -99,7 +99,6 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts callbackCh := make(chan *codex.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Codex") manualDescription := "" go func() { @@ -112,22 +111,58 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts }() var result *codex.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { - return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + } + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + } + return nil, err + default: + } + input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + manualDescription = parsed.ErrorDescription + result = &codex.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - return nil, err - case manual := <-manualCh: - manualDescription = manual.ErrorDescription - result = &codex.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, - } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { diff --git a/sdk/auth/iflow.go b/sdk/auth/iflow.go index d7621a99..3fd82f1d 100644 --- a/sdk/auth/iflow.go +++ b/sdk/auth/iflow.go @@ -86,7 +86,6 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts callbackCh := make(chan *iflow.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "iFlow") go func() { result, errWait := oauthServer.WaitForCallback(5 * time.Minute) @@ -98,18 +97,51 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts }() var result *iflow.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) - case manual := <-manualCh: - result = &iflow.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + default: + } + input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + result = &iflow.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) diff --git a/sdk/auth/oauth_callback.go b/sdk/auth/oauth_callback.go deleted file mode 100644 index 3f0ac925..00000000 --- a/sdk/auth/oauth_callback.go +++ /dev/null @@ -1,41 +0,0 @@ -package auth - -import ( - "fmt" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" -) - -func promptForOAuthCallback(prompt func(string) (string, error), provider string) (<-chan *misc.OAuthCallback, <-chan error) { - if prompt == nil { - return nil, nil - } - - resultCh := make(chan *misc.OAuthCallback, 1) - errCh := make(chan error, 1) - - go func() { - label := provider - if label == "" { - label = "OAuth" - } - input, err := prompt(fmt.Sprintf("Paste the %s callback URL (or press Enter to keep waiting): ", label)) - if err != nil { - errCh <- err - return - } - - parsed, err := misc.ParseOAuthCallback(input) - if err != nil { - errCh <- err - return - } - if parsed == nil { - return - } - - resultCh <- parsed - }() - - return resultCh, errCh -} From 05d201ece84ef10817a0b92544cd1ca816adface Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 07:21:12 +0800 Subject: [PATCH 5/6] fix(gemini): gate callback prompt on project_id --- internal/cmd/login.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 0f079b4b..3bb0b9a5 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -58,14 +58,19 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { promptFn := options.Prompt if promptFn == nil { promptFn = defaultProjectPrompt() - options.Prompt = promptFn + } + + trimmedProjectID := strings.TrimSpace(projectID) + callbackPrompt := promptFn + if trimmedProjectID == "" { + callbackPrompt = nil } loginOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, - ProjectID: strings.TrimSpace(projectID), + ProjectID: trimmedProjectID, Metadata: map[string]string{}, - Prompt: promptFn, + Prompt: callbackPrompt, } authenticator := sdkAuth.NewGeminiAuthenticator() @@ -84,7 +89,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { geminiAuth := gemini.NewGeminiAuth() httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{ NoBrowser: options.NoBrowser, - Prompt: promptFn, + Prompt: callbackPrompt, }) if errClient != nil { log.Errorf("Gemini authentication failed: %v", errClient) @@ -99,7 +104,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { return } - selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn) + selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) if errSelection != nil { log.Errorf("Invalid project selection: %v", errSelection) From 781bc1521b827b4e2c9f60f17a5d8b5a5f28e85c Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 10:48:40 +0800 Subject: [PATCH 6/6] fix(oauth): prevent stale session timeouts after login - stop callback forwarders by instance to avoid cross-session shutdowns - clear pending sessions for a provider after successful auth --- .../api/handlers/management/auth_files.go | 66 ++++++++++++++++--- .../api/handlers/management/oauth_sessions.go | 25 +++++++ 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 4f42bd7a..41a4fde4 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -197,6 +197,19 @@ func stopCallbackForwarder(port int) { stopForwarderInstance(port, forwarder) } +func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) { + if forwarder == nil { + return + } + callbackForwardersMu.Lock() + if current := callbackForwarders[port]; current == forwarder { + delete(callbackForwarders, port) + } + callbackForwardersMu.Unlock() + + stopForwarderInstance(port, forwarder) +} + func stopForwarderInstance(port int, forwarder *callbackForwarder) { if forwarder == nil || forwarder.server == nil { return @@ -785,6 +798,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { RegisterOAuthSession(state, "anthropic") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") if errTarget != nil { @@ -792,7 +806,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start anthropic callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -801,7 +816,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(anthropicCallbackPort) + defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder) } // Helper: wait for callback file @@ -809,6 +824,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { deadline := time.Now().Add(timeout) for { + if !IsOAuthSessionPending(state, "anthropic") { + return nil, errOAuthSessionNotPending + } if time.Now().After(deadline) { SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") @@ -828,6 +846,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { // Wait up to 5 minutes resultMap, errWait := waitForFile(waitFile, 5*time.Minute) if errWait != nil { + if errors.Is(errWait, errOAuthSessionNotPending) { + return + } authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait) log.Error(claude.GetUserFriendlyMessage(authErr)) return @@ -933,6 +954,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } fmt.Println("You can now use Claude services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("anthropic") }() c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) @@ -968,6 +990,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { RegisterOAuthSession(state, "gemini") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/google/callback") if errTarget != nil { @@ -975,7 +998,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start gemini callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -984,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(geminiCallbackPort) + defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder) } // Wait for callback file written by server route @@ -993,6 +1017,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var authCode string for { + if !IsOAuthSessionPending(state, "gemini") { + return + } if time.Now().After(deadline) { log.Error("oauth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out") @@ -1168,6 +1195,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("gemini") fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() @@ -1209,6 +1237,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { RegisterOAuthSession(state, "codex") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/codex/callback") if errTarget != nil { @@ -1216,7 +1245,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start codex callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1225,7 +1255,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(codexCallbackPort) + defer stopCallbackForwarderInstance(codexCallbackPort, forwarder) } // Wait for callback file @@ -1233,6 +1263,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var code string for { + if !IsOAuthSessionPending(state, "codex") { + return + } if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) @@ -1348,6 +1381,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } fmt.Println("You can now use Codex services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("codex") }() c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) @@ -1393,6 +1427,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { RegisterOAuthSession(state, "antigravity") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") if errTarget != nil { @@ -1400,7 +1435,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start antigravity callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1409,13 +1445,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(antigravityCallbackPort) + defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder) } waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) deadline := time.Now().Add(5 * time.Minute) var authCode string for { + if !IsOAuthSessionPending(state, "antigravity") { + return + } if time.Now().After(deadline) { log.Error("oauth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out") @@ -1578,6 +1617,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("antigravity") fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1655,6 +1695,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { RegisterOAuthSession(state, "iflow") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/iflow/callback") if errTarget != nil { @@ -1662,7 +1703,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start iflow callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"}) return @@ -1671,7 +1713,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(iflowauth.CallbackPort) + defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder) } fmt.Println("Waiting for authentication...") @@ -1679,6 +1721,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var resultMap map[string]string for { + if !IsOAuthSessionPending(state, "iflow") { + return + } if time.Now().After(deadline) { SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") @@ -1745,6 +1790,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } fmt.Println("You can now use iFlow services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("iflow") }() c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go index f23b608c..05ff8d1f 100644 --- a/internal/api/handlers/management/oauth_sessions.go +++ b/internal/api/handlers/management/oauth_sessions.go @@ -111,6 +111,27 @@ func (s *oauthSessionStore) Complete(state string) { delete(s.sessions, state) } +func (s *oauthSessionStore) CompleteProvider(provider string) int { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" { + return 0 + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + removed := 0 + for state, session := range s.sessions { + if strings.EqualFold(session.Provider, provider) { + delete(s.sessions, state) + removed++ + } + } + return removed +} + func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { state = strings.TrimSpace(state) now := time.Now() @@ -153,6 +174,10 @@ func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } +func CompleteOAuthSessionsByProvider(provider string) int { + return oauthSessions.CompleteProvider(provider) +} + func GetOAuthSession(state string) (provider string, status string, ok bool) { session, ok := oauthSessions.Get(state) if !ok {