mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
fix(gemini): avoid stale manual oauth prompt and accept schemeless callbacks
This commit is contained in:
@@ -219,8 +219,8 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
|||||||
// - error: An error if the token acquisition fails, nil otherwise
|
// - error: An error if the token acquisition fails, nil otherwise
|
||||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||||
codeChan := make(chan string)
|
codeChan := make(chan string, 1)
|
||||||
errChan := make(chan error)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
// Create a new HTTP server with its own multiplexer.
|
// Create a new HTTP server with its own multiplexer.
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@@ -230,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.URL.Query().Get("error"); err != "" {
|
if err := r.URL.Query().Get("error"); err != "" {
|
||||||
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
||||||
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
select {
|
||||||
|
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
|
||||||
|
default:
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
code := r.URL.Query().Get("code")
|
code := r.URL.Query().Get("code")
|
||||||
if code == "" {
|
if code == "" {
|
||||||
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
|
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
|
||||||
errChan <- fmt.Errorf("code not found in callback")
|
select {
|
||||||
|
case errChan <- fmt.Errorf("code not found in callback"):
|
||||||
|
default:
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||||
codeChan <- code
|
select {
|
||||||
|
case codeChan <- code:
|
||||||
|
default:
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start the server in a goroutine.
|
// Start the server in a goroutine.
|
||||||
@@ -293,49 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
|
|
||||||
// Wait for the authorization code or an error.
|
// Wait for the authorization code or an error.
|
||||||
var authCode string
|
var authCode string
|
||||||
manualCodeChan := make(chan string, 1)
|
timeoutTimer := time.NewTimer(5 * time.Minute)
|
||||||
manualErrChan := make(chan error, 1)
|
defer timeoutTimer.Stop()
|
||||||
|
|
||||||
|
var manualPromptTimer *time.Timer
|
||||||
|
var manualPromptC <-chan time.Time
|
||||||
if opts != nil && opts.Prompt != nil {
|
if opts != nil && opts.Prompt != nil {
|
||||||
go func() {
|
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||||
|
manualPromptC = manualPromptTimer.C
|
||||||
|
defer manualPromptTimer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
waitForCallback:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case code := <-codeChan:
|
||||||
|
authCode = code
|
||||||
|
break waitForCallback
|
||||||
|
case err := <-errChan:
|
||||||
|
return nil, err
|
||||||
|
case <-manualPromptC:
|
||||||
|
manualPromptC = nil
|
||||||
|
if manualPromptTimer != nil {
|
||||||
|
manualPromptTimer.Stop()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case code := <-codeChan:
|
||||||
|
authCode = code
|
||||||
|
break waitForCallback
|
||||||
|
case err := <-errChan:
|
||||||
|
return nil, err
|
||||||
|
default:
|
||||||
|
}
|
||||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
manualErrChan <- err
|
return nil, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
parsed, err := misc.ParseOAuthCallback(input)
|
parsed, err := misc.ParseOAuthCallback(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
manualErrChan <- err
|
return nil, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if parsed == nil {
|
if parsed == nil {
|
||||||
return
|
continue
|
||||||
}
|
}
|
||||||
if parsed.Error != "" {
|
if parsed.Error != "" {
|
||||||
manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error)
|
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if parsed.Code == "" {
|
if parsed.Code == "" {
|
||||||
manualErrChan <- fmt.Errorf("code not found in callback")
|
return nil, fmt.Errorf("code not found in callback")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
manualCodeChan <- parsed.Code
|
authCode = parsed.Code
|
||||||
}()
|
break waitForCallback
|
||||||
} else {
|
case <-timeoutTimer.C:
|
||||||
manualCodeChan = nil
|
return nil, fmt.Errorf("oauth flow timed out")
|
||||||
manualErrChan = nil
|
}
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case code := <-codeChan:
|
|
||||||
authCode = code
|
|
||||||
case err := <-errChan:
|
|
||||||
return nil, err
|
|
||||||
case code := <-manualCodeChan:
|
|
||||||
authCode = code
|
|
||||||
case err := <-manualErrChan:
|
|
||||||
return nil, err
|
|
||||||
case <-time.After(5 * time.Minute): // Timeout
|
|
||||||
return nil, fmt.Errorf("oauth flow timed out")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown the server.
|
// Shutdown the server.
|
||||||
|
|||||||
@@ -42,6 +42,8 @@ func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
|||||||
if !strings.Contains(candidate, "://") {
|
if !strings.Contains(candidate, "://") {
|
||||||
if strings.HasPrefix(candidate, "?") {
|
if strings.HasPrefix(candidate, "?") {
|
||||||
candidate = "http://localhost" + candidate
|
candidate = "http://localhost" + candidate
|
||||||
|
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
|
||||||
|
candidate = "http://" + candidate
|
||||||
} else if strings.Contains(candidate, "=") {
|
} else if strings.Contains(candidate, "=") {
|
||||||
candidate = "http://localhost/?" + candidate
|
candidate = "http://localhost/?" + candidate
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Reference in New Issue
Block a user