fix(gemini): avoid stale manual oauth prompt and accept schemeless callbacks

This commit is contained in:
Supra4E8C
2025-12-20 19:03:38 +08:00
parent 93414f1baa
commit 9855615f1e
2 changed files with 57 additions and 35 deletions

View File

@@ -219,8 +219,8 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - error: An error if the token acquisition fails, nil otherwise // - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function. // Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string) codeChan := make(chan string, 1)
errChan := make(chan error) errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer. // Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -230,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" { if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err) _, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
errChan <- fmt.Errorf("authentication failed via callback: %s", err) select {
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
default:
}
return return
} }
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
if code == "" { if code == "" {
_, _ = fmt.Fprint(w, "Authentication failed: code not found.") _, _ = fmt.Fprint(w, "Authentication failed: code not found.")
errChan <- fmt.Errorf("code not found in callback") select {
case errChan <- fmt.Errorf("code not found in callback"):
default:
}
return return
} }
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>") _, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
codeChan <- code select {
case codeChan <- code:
default:
}
}) })
// Start the server in a goroutine. // Start the server in a goroutine.
@@ -293,50 +302,61 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Wait for the authorization code or an error. // Wait for the authorization code or an error.
var authCode string var authCode string
manualCodeChan := make(chan string, 1) timeoutTimer := time.NewTimer(5 * time.Minute)
manualErrChan := make(chan error, 1) defer timeoutTimer.Stop()
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts != nil && opts.Prompt != nil { if opts != nil && opts.Prompt != nil {
go func() { manualPromptTimer = time.NewTimer(15 * time.Second)
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") manualPromptC = manualPromptTimer.C
if err != nil { defer manualPromptTimer.Stop()
manualErrChan <- err
return
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
manualErrChan <- err
return
}
if parsed == nil {
return
}
if parsed.Error != "" {
manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error)
return
}
if parsed.Code == "" {
manualErrChan <- fmt.Errorf("code not found in callback")
return
}
manualCodeChan <- parsed.Code
}()
} else {
manualCodeChan = nil
manualErrChan = nil
} }
waitForCallback:
for {
select { select {
case code := <-codeChan: case code := <-codeChan:
authCode = code authCode = code
break waitForCallback
case err := <-errChan: case err := <-errChan:
return nil, err return nil, err
case code := <-manualCodeChan: case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case code := <-codeChan:
authCode = code authCode = code
case err := <-manualErrChan: break waitForCallback
case err := <-errChan:
return nil, err return nil, err
case <-time.After(5 * time.Minute): // Timeout default:
}
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil {
return nil, err
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
return nil, err
}
if parsed == nil {
continue
}
if parsed.Error != "" {
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
}
if parsed.Code == "" {
return nil, fmt.Errorf("code not found in callback")
}
authCode = parsed.Code
break waitForCallback
case <-timeoutTimer.C:
return nil, fmt.Errorf("oauth flow timed out") return nil, fmt.Errorf("oauth flow timed out")
} }
}
// Shutdown the server. // Shutdown the server.
if err := server.Shutdown(ctx); err != nil { if err := server.Shutdown(ctx); err != nil {

View File

@@ -42,6 +42,8 @@ func ParseOAuthCallback(input string) (*OAuthCallback, error) {
if !strings.Contains(candidate, "://") { if !strings.Contains(candidate, "://") {
if strings.HasPrefix(candidate, "?") { if strings.HasPrefix(candidate, "?") {
candidate = "http://localhost" + candidate candidate = "http://localhost" + candidate
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
candidate = "http://" + candidate
} else if strings.Contains(candidate, "=") { } else if strings.Contains(candidate, "=") {
candidate = "http://localhost/?" + candidate candidate = "http://localhost/?" + candidate
} else { } else {