fix(gemini): avoid stale manual oauth prompt and accept schemeless callbacks

This commit is contained in:
Supra4E8C
2025-12-20 19:03:38 +08:00
parent 93414f1baa
commit 9855615f1e
2 changed files with 57 additions and 35 deletions

View File

@@ -219,8 +219,8 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - error: An error if the token acquisition fails, nil otherwise // - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function. // Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string) codeChan := make(chan string, 1)
errChan := make(chan error) errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer. // Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -230,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" { if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err) _, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
errChan <- fmt.Errorf("authentication failed via callback: %s", err) select {
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
default:
}
return return
} }
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
if code == "" { if code == "" {
_, _ = fmt.Fprint(w, "Authentication failed: code not found.") _, _ = fmt.Fprint(w, "Authentication failed: code not found.")
errChan <- fmt.Errorf("code not found in callback") select {
case errChan <- fmt.Errorf("code not found in callback"):
default:
}
return return
} }
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>") _, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
codeChan <- code select {
case codeChan <- code:
default:
}
}) })
// Start the server in a goroutine. // Start the server in a goroutine.
@@ -293,49 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Wait for the authorization code or an error. // Wait for the authorization code or an error.
var authCode string var authCode string
manualCodeChan := make(chan string, 1) timeoutTimer := time.NewTimer(5 * time.Minute)
manualErrChan := make(chan error, 1) defer timeoutTimer.Stop()
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts != nil && opts.Prompt != nil { if opts != nil && opts.Prompt != nil {
go func() { manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
default:
}
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil { if err != nil {
manualErrChan <- err return nil, err
return
} }
parsed, err := misc.ParseOAuthCallback(input) parsed, err := misc.ParseOAuthCallback(input)
if err != nil { if err != nil {
manualErrChan <- err return nil, err
return
} }
if parsed == nil { if parsed == nil {
return continue
} }
if parsed.Error != "" { if parsed.Error != "" {
manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error) return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
return
} }
if parsed.Code == "" { if parsed.Code == "" {
manualErrChan <- fmt.Errorf("code not found in callback") return nil, fmt.Errorf("code not found in callback")
return
} }
manualCodeChan <- parsed.Code authCode = parsed.Code
}() break waitForCallback
} else { case <-timeoutTimer.C:
manualCodeChan = nil return nil, fmt.Errorf("oauth flow timed out")
manualErrChan = nil }
}
select {
case code := <-codeChan:
authCode = code
case err := <-errChan:
return nil, err
case code := <-manualCodeChan:
authCode = code
case err := <-manualErrChan:
return nil, err
case <-time.After(5 * time.Minute): // Timeout
return nil, fmt.Errorf("oauth flow timed out")
} }
// Shutdown the server. // Shutdown the server.

View File

@@ -42,6 +42,8 @@ func ParseOAuthCallback(input string) (*OAuthCallback, error) {
if !strings.Contains(candidate, "://") { if !strings.Contains(candidate, "://") {
if strings.HasPrefix(candidate, "?") { if strings.HasPrefix(candidate, "?") {
candidate = "http://localhost" + candidate candidate = "http://localhost" + candidate
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
candidate = "http://" + candidate
} else if strings.Contains(candidate, "=") { } else if strings.Contains(candidate, "=") {
candidate = "http://localhost/?" + candidate candidate = "http://localhost/?" + candidate
} else { } else {