fix(gemini): avoid stale manual oauth prompt and accept schemeless callbacks

This commit is contained in:
Supra4E8C
2025-12-20 19:03:38 +08:00
parent 93414f1baa
commit 9855615f1e
2 changed files with 57 additions and 35 deletions

View File

@@ -219,8 +219,8 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string)
errChan := make(chan error)
codeChan := make(chan string, 1)
errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux()
@@ -230,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
select {
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
default:
}
return
}
code := r.URL.Query().Get("code")
if code == "" {
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
errChan <- fmt.Errorf("code not found in callback")
select {
case errChan <- fmt.Errorf("code not found in callback"):
default:
}
return
}
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
codeChan <- code
select {
case codeChan <- code:
default:
}
})
// Start the server in a goroutine.
@@ -293,49 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Wait for the authorization code or an error.
var authCode string
manualCodeChan := make(chan string, 1)
manualErrChan := make(chan error, 1)
timeoutTimer := time.NewTimer(5 * time.Minute)
defer timeoutTimer.Stop()
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts != nil && opts.Prompt != nil {
go func() {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
default:
}
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil {
manualErrChan <- err
return
return nil, err
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
manualErrChan <- err
return
return nil, err
}
if parsed == nil {
return
continue
}
if parsed.Error != "" {
manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error)
return
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
}
if parsed.Code == "" {
manualErrChan <- fmt.Errorf("code not found in callback")
return
return nil, fmt.Errorf("code not found in callback")
}
manualCodeChan <- parsed.Code
}()
} else {
manualCodeChan = nil
manualErrChan = nil
}
select {
case code := <-codeChan:
authCode = code
case err := <-errChan:
return nil, err
case code := <-manualCodeChan:
authCode = code
case err := <-manualErrChan:
return nil, err
case <-time.After(5 * time.Minute): // Timeout
return nil, fmt.Errorf("oauth flow timed out")
authCode = parsed.Code
break waitForCallback
case <-timeoutTimer.C:
return nil, fmt.Errorf("oauth flow timed out")
}
}
// Shutdown the server.

View File

@@ -42,6 +42,8 @@ func ParseOAuthCallback(input string) (*OAuthCallback, error) {
if !strings.Contains(candidate, "://") {
if strings.HasPrefix(candidate, "?") {
candidate = "http://localhost" + candidate
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
candidate = "http://" + candidate
} else if strings.Contains(candidate, "=") {
candidate = "http://localhost/?" + candidate
} else {