fix(sdk/auth): prevent OAuth manual prompt goroutine leak,Use timer-based manual prompt per provider and remove oauth_callback helper.

This commit is contained in:
Supra4E8C
2025-12-21 07:06:28 +08:00
parent 24970baa57
commit cd0c94f48a
5 changed files with 193 additions and 98 deletions

View File

@@ -99,20 +99,54 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
fmt.Println("Waiting for antigravity authentication callback...")
var cbRes callbackResult
manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "antigravity")
select {
case res := <-cbChan:
cbRes = res
case manual := <-manualCh:
cbRes = callbackResult{
Code: manual.Code,
State: manual.State,
Error: manual.Error,
timeoutTimer := time.NewTimer(5 * time.Minute)
defer timeoutTimer.Stop()
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case res := <-cbChan:
cbRes = res
break waitForCallback
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case res := <-cbChan:
cbRes = res
break waitForCallback
default:
}
input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
cbRes = callbackResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
case <-timeoutTimer.C:
return nil, fmt.Errorf("antigravity: authentication timed out")
}
case err = <-manualErrCh:
return nil, err
case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("antigravity: authentication timed out")
}
if cbRes.Error != "" {

View File

@@ -100,7 +100,6 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
callbackCh := make(chan *claude.OAuthResult, 1)
callbackErrCh := make(chan error, 1)
manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Claude")
manualDescription := ""
go func() {
@@ -113,22 +112,58 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
}()
var result *claude.OAuthResult
select {
case result = <-callbackCh:
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
}
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
}
return nil, err
default:
}
input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
manualDescription = parsed.ErrorDescription
result = &claude.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
return nil, err
case manual := <-manualCh:
manualDescription = manual.ErrorDescription
result = &claude.OAuthResult{
Code: manual.Code,
State: manual.State,
Error: manual.Error,
}
case err = <-manualErrCh:
return nil, err
}
if result.Error != "" {

View File

@@ -99,7 +99,6 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
callbackCh := make(chan *codex.OAuthResult, 1)
callbackErrCh := make(chan error, 1)
manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Codex")
manualDescription := ""
go func() {
@@ -112,22 +111,58 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
}()
var result *codex.OAuthResult
select {
case result = <-callbackCh:
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
}
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
}
return nil, err
default:
}
input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
manualDescription = parsed.ErrorDescription
result = &codex.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
return nil, err
case manual := <-manualCh:
manualDescription = manual.ErrorDescription
result = &codex.OAuthResult{
Code: manual.Code,
State: manual.State,
Error: manual.Error,
}
case err = <-manualErrCh:
return nil, err
}
if result.Error != "" {

View File

@@ -86,7 +86,6 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
callbackCh := make(chan *iflow.OAuthResult, 1)
callbackErrCh := make(chan error, 1)
manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "iFlow")
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
@@ -98,18 +97,51 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
}()
var result *iflow.OAuthResult
select {
case result = <-callbackCh:
case err = <-callbackErrCh:
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
case manual := <-manualCh:
result = &iflow.OAuthResult{
Code: manual.Code,
State: manual.State,
Error: manual.Error,
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
default:
}
input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
result = &iflow.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
case err = <-manualErrCh:
return nil, err
}
if result.Error != "" {
return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error)

View File

@@ -1,41 +0,0 @@
package auth
import (
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
func promptForOAuthCallback(prompt func(string) (string, error), provider string) (<-chan *misc.OAuthCallback, <-chan error) {
if prompt == nil {
return nil, nil
}
resultCh := make(chan *misc.OAuthCallback, 1)
errCh := make(chan error, 1)
go func() {
label := provider
if label == "" {
label = "OAuth"
}
input, err := prompt(fmt.Sprintf("Paste the %s callback URL (or press Enter to keep waiting): ", label))
if err != nil {
errCh <- err
return
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
errCh <- err
return
}
if parsed == nil {
return
}
resultCh <- parsed
}()
return resultCh, errCh
}