From cd0c94f48acc9b66f51e4c1710923c37c9d2255c Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 07:06:28 +0800 Subject: [PATCH] fix(sdk/auth): prevent OAuth manual prompt goroutine leak,Use timer-based manual prompt per provider and remove oauth_callback helper. --- sdk/auth/antigravity.go | 60 ++++++++++++++++++++++++++-------- sdk/auth/claude.go | 67 +++++++++++++++++++++++++++++--------- sdk/auth/codex.go | 67 +++++++++++++++++++++++++++++--------- sdk/auth/iflow.go | 56 ++++++++++++++++++++++++------- sdk/auth/oauth_callback.go | 41 ----------------------- 5 files changed, 193 insertions(+), 98 deletions(-) delete mode 100644 sdk/auth/oauth_callback.go diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index 832bd88e..ae22f772 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -99,20 +99,54 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o fmt.Println("Waiting for antigravity authentication callback...") var cbRes callbackResult - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "antigravity") - select { - case res := <-cbChan: - cbRes = res - case manual := <-manualCh: - cbRes = callbackResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case res := <-cbChan: + cbRes = res + break waitForCallback + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case res := <-cbChan: + cbRes = res + break waitForCallback + default: + } + input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + cbRes = callbackResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback + case <-timeoutTimer.C: + return nil, fmt.Errorf("antigravity: authentication timed out") } - case err = <-manualErrCh: - return nil, err - case <-time.After(5 * time.Minute): - return nil, fmt.Errorf("antigravity: authentication timed out") } if cbRes.Error != "" { diff --git a/sdk/auth/claude.go b/sdk/auth/claude.go index d88cdf29..c43b78cd 100644 --- a/sdk/auth/claude.go +++ b/sdk/auth/claude.go @@ -100,7 +100,6 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt callbackCh := make(chan *claude.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Claude") manualDescription := "" go func() { @@ -113,22 +112,58 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt }() var result *claude.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { - return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + } + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + } + return nil, err + default: + } + input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + manualDescription = parsed.ErrorDescription + result = &claude.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - return nil, err - case manual := <-manualCh: - manualDescription = manual.ErrorDescription - result = &claude.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, - } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index b0a6b4a4..99992525 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -99,7 +99,6 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts callbackCh := make(chan *codex.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Codex") manualDescription := "" go func() { @@ -112,22 +111,58 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts }() var result *codex.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { - return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + } + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + } + return nil, err + default: + } + input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + manualDescription = parsed.ErrorDescription + result = &codex.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - return nil, err - case manual := <-manualCh: - manualDescription = manual.ErrorDescription - result = &codex.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, - } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { diff --git a/sdk/auth/iflow.go b/sdk/auth/iflow.go index d7621a99..3fd82f1d 100644 --- a/sdk/auth/iflow.go +++ b/sdk/auth/iflow.go @@ -86,7 +86,6 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts callbackCh := make(chan *iflow.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "iFlow") go func() { result, errWait := oauthServer.WaitForCallback(5 * time.Minute) @@ -98,18 +97,51 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts }() var result *iflow.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) - case manual := <-manualCh: - result = &iflow.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + default: + } + input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + result = &iflow.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) diff --git a/sdk/auth/oauth_callback.go b/sdk/auth/oauth_callback.go deleted file mode 100644 index 3f0ac925..00000000 --- a/sdk/auth/oauth_callback.go +++ /dev/null @@ -1,41 +0,0 @@ -package auth - -import ( - "fmt" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" -) - -func promptForOAuthCallback(prompt func(string) (string, error), provider string) (<-chan *misc.OAuthCallback, <-chan error) { - if prompt == nil { - return nil, nil - } - - resultCh := make(chan *misc.OAuthCallback, 1) - errCh := make(chan error, 1) - - go func() { - label := provider - if label == "" { - label = "OAuth" - } - input, err := prompt(fmt.Sprintf("Paste the %s callback URL (or press Enter to keep waiting): ", label)) - if err != nil { - errCh <- err - return - } - - parsed, err := misc.ParseOAuthCallback(input) - if err != nil { - errCh <- err - return - } - if parsed == nil { - return - } - - resultCh <- parsed - }() - - return resultCh, errCh -}