From 781bc1521b827b4e2c9f60f17a5d8b5a5f28e85c Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 10:48:40 +0800 Subject: [PATCH] fix(oauth): prevent stale session timeouts after login - stop callback forwarders by instance to avoid cross-session shutdowns - clear pending sessions for a provider after successful auth --- .../api/handlers/management/auth_files.go | 66 ++++++++++++++++--- .../api/handlers/management/oauth_sessions.go | 25 +++++++ 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 4f42bd7a..41a4fde4 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -197,6 +197,19 @@ func stopCallbackForwarder(port int) { stopForwarderInstance(port, forwarder) } +func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) { + if forwarder == nil { + return + } + callbackForwardersMu.Lock() + if current := callbackForwarders[port]; current == forwarder { + delete(callbackForwarders, port) + } + callbackForwardersMu.Unlock() + + stopForwarderInstance(port, forwarder) +} + func stopForwarderInstance(port int, forwarder *callbackForwarder) { if forwarder == nil || forwarder.server == nil { return @@ -785,6 +798,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { RegisterOAuthSession(state, "anthropic") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") if errTarget != nil { @@ -792,7 +806,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start anthropic callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -801,7 +816,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(anthropicCallbackPort) + defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder) } // Helper: wait for callback file @@ -809,6 +824,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { deadline := time.Now().Add(timeout) for { + if !IsOAuthSessionPending(state, "anthropic") { + return nil, errOAuthSessionNotPending + } if time.Now().After(deadline) { SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") @@ -828,6 +846,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { // Wait up to 5 minutes resultMap, errWait := waitForFile(waitFile, 5*time.Minute) if errWait != nil { + if errors.Is(errWait, errOAuthSessionNotPending) { + return + } authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait) log.Error(claude.GetUserFriendlyMessage(authErr)) return @@ -933,6 +954,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } fmt.Println("You can now use Claude services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("anthropic") }() c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) @@ -968,6 +990,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { RegisterOAuthSession(state, "gemini") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/google/callback") if errTarget != nil { @@ -975,7 +998,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start gemini callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -984,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(geminiCallbackPort) + defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder) } // Wait for callback file written by server route @@ -993,6 +1017,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var authCode string for { + if !IsOAuthSessionPending(state, "gemini") { + return + } if time.Now().After(deadline) { log.Error("oauth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out") @@ -1168,6 +1195,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("gemini") fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() @@ -1209,6 +1237,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { RegisterOAuthSession(state, "codex") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/codex/callback") if errTarget != nil { @@ -1216,7 +1245,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start codex callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1225,7 +1255,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(codexCallbackPort) + defer stopCallbackForwarderInstance(codexCallbackPort, forwarder) } // Wait for callback file @@ -1233,6 +1263,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var code string for { + if !IsOAuthSessionPending(state, "codex") { + return + } if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) @@ -1348,6 +1381,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } fmt.Println("You can now use Codex services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("codex") }() c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) @@ -1393,6 +1427,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { RegisterOAuthSession(state, "antigravity") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") if errTarget != nil { @@ -1400,7 +1435,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start antigravity callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1409,13 +1445,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(antigravityCallbackPort) + defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder) } waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) deadline := time.Now().Add(5 * time.Minute) var authCode string for { + if !IsOAuthSessionPending(state, "antigravity") { + return + } if time.Now().After(deadline) { log.Error("oauth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out") @@ -1578,6 +1617,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("antigravity") fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1655,6 +1695,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { RegisterOAuthSession(state, "iflow") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/iflow/callback") if errTarget != nil { @@ -1662,7 +1703,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start iflow callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"}) return @@ -1671,7 +1713,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(iflowauth.CallbackPort) + defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder) } fmt.Println("Waiting for authentication...") @@ -1679,6 +1721,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var resultMap map[string]string for { + if !IsOAuthSessionPending(state, "iflow") { + return + } if time.Now().After(deadline) { SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") @@ -1745,6 +1790,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } fmt.Println("You can now use iFlow services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("iflow") }() c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go index f23b608c..05ff8d1f 100644 --- a/internal/api/handlers/management/oauth_sessions.go +++ b/internal/api/handlers/management/oauth_sessions.go @@ -111,6 +111,27 @@ func (s *oauthSessionStore) Complete(state string) { delete(s.sessions, state) } +func (s *oauthSessionStore) CompleteProvider(provider string) int { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" { + return 0 + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + removed := 0 + for state, session := range s.sessions { + if strings.EqualFold(session.Provider, provider) { + delete(s.sessions, state) + removed++ + } + } + return removed +} + func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { state = strings.TrimSpace(state) now := time.Now() @@ -153,6 +174,10 @@ func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } +func CompleteOAuthSessionsByProvider(provider string) int { + return oauthSessions.CompleteProvider(provider) +} + func GetOAuthSession(state string) (provider string, status string, ok bool) { session, ok := oauthSessions.Get(state) if !ok {