diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index dc9b1034..7b18e738 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -219,8 +219,8 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf // - error: An error if the token acquisition fails, nil otherwise func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. - codeChan := make(chan string) - errChan := make(chan error) + codeChan := make(chan string, 1) + errChan := make(chan error, 1) // Create a new HTTP server with its own multiplexer. mux := http.NewServeMux() @@ -230,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { if err := r.URL.Query().Get("error"); err != "" { _, _ = fmt.Fprintf(w, "Authentication failed: %s", err) - errChan <- fmt.Errorf("authentication failed via callback: %s", err) + select { + case errChan <- fmt.Errorf("authentication failed via callback: %s", err): + default: + } return } code := r.URL.Query().Get("code") if code == "" { _, _ = fmt.Fprint(w, "Authentication failed: code not found.") - errChan <- fmt.Errorf("code not found in callback") + select { + case errChan <- fmt.Errorf("code not found in callback"): + default: + } return } _, _ = fmt.Fprint(w, "
You can close this window.
") - codeChan <- code + select { + case codeChan <- code: + default: + } }) // Start the server in a goroutine. @@ -293,49 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Wait for the authorization code or an error. var authCode string - manualCodeChan := make(chan string, 1) - manualErrChan := make(chan error, 1) + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time if opts != nil && opts.Prompt != nil { - go func() { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case code := <-codeChan: + authCode = code + break waitForCallback + case err := <-errChan: + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case code := <-codeChan: + authCode = code + break waitForCallback + case err := <-errChan: + return nil, err + default: + } input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") if err != nil { - manualErrChan <- err - return + return nil, err } parsed, err := misc.ParseOAuthCallback(input) if err != nil { - manualErrChan <- err - return + return nil, err } if parsed == nil { - return + continue } if parsed.Error != "" { - manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error) - return + return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error) } if parsed.Code == "" { - manualErrChan <- fmt.Errorf("code not found in callback") - return + return nil, fmt.Errorf("code not found in callback") } - manualCodeChan <- parsed.Code - }() - } else { - manualCodeChan = nil - manualErrChan = nil - } - - select { - case code := <-codeChan: - authCode = code - case err := <-errChan: - return nil, err - case code := <-manualCodeChan: - authCode = code - case err := <-manualErrChan: - return nil, err - case <-time.After(5 * time.Minute): // Timeout - return nil, fmt.Errorf("oauth flow timed out") + authCode = parsed.Code + break waitForCallback + case <-timeoutTimer.C: + return nil, fmt.Errorf("oauth flow timed out") + } } // Shutdown the server. diff --git a/internal/misc/oauth.go b/internal/misc/oauth.go index d5cae403..c14f39d2 100644 --- a/internal/misc/oauth.go +++ b/internal/misc/oauth.go @@ -42,6 +42,8 @@ func ParseOAuthCallback(input string) (*OAuthCallback, error) { if !strings.Contains(candidate, "://") { if strings.HasPrefix(candidate, "?") { candidate = "http://localhost" + candidate + } else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") { + candidate = "http://" + candidate } else if strings.Contains(candidate, "=") { candidate = "http://localhost/?" + candidate } else {