fix(oauth): prevent stale session timeouts after login

- stop callback forwarders by instance to avoid cross-session shutdowns
  - clear pending sessions for a provider after successful auth
This commit is contained in:
Supra4E8C
2025-12-21 10:48:40 +08:00
parent 05d201ece8
commit 781bc1521b
2 changed files with 81 additions and 10 deletions

View File

@@ -197,6 +197,19 @@ func stopCallbackForwarder(port int) {
stopForwarderInstance(port, forwarder)
}
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
if forwarder == nil {
return
}
callbackForwardersMu.Lock()
if current := callbackForwarders[port]; current == forwarder {
delete(callbackForwarders, port)
}
callbackForwardersMu.Unlock()
stopForwarderInstance(port, forwarder)
}
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
if forwarder == nil || forwarder.server == nil {
return
@@ -785,6 +798,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
RegisterOAuthSession(state, "anthropic")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
if errTarget != nil {
@@ -792,7 +806,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
var errStart error
if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
@@ -801,7 +816,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
go func() {
if isWebUI {
defer stopCallbackForwarder(anthropicCallbackPort)
defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
}
// Helper: wait for callback file
@@ -809,6 +824,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
deadline := time.Now().Add(timeout)
for {
if !IsOAuthSessionPending(state, "anthropic") {
return nil, errOAuthSessionNotPending
}
if time.Now().After(deadline) {
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
return nil, fmt.Errorf("timeout waiting for OAuth callback")
@@ -828,6 +846,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Wait up to 5 minutes
resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
if errWait != nil {
if errors.Is(errWait, errOAuthSessionNotPending) {
return
}
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
log.Error(claude.GetUserFriendlyMessage(authErr))
return
@@ -933,6 +954,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
}
fmt.Println("You can now use Claude services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("anthropic")
}()
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
@@ -968,6 +990,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
RegisterOAuthSession(state, "gemini")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/google/callback")
if errTarget != nil {
@@ -975,7 +998,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
var errStart error
if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start gemini callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
@@ -984,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
go func() {
if isWebUI {
defer stopCallbackForwarder(geminiCallbackPort)
defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
}
// Wait for callback file written by server route
@@ -993,6 +1017,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute)
var authCode string
for {
if !IsOAuthSessionPending(state, "gemini") {
return
}
if time.Now().After(deadline) {
log.Error("oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out")
@@ -1168,6 +1195,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
}
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("gemini")
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
}()
@@ -1209,6 +1237,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
RegisterOAuthSession(state, "codex")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
if errTarget != nil {
@@ -1216,7 +1245,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
var errStart error
if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start codex callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
@@ -1225,7 +1255,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
go func() {
if isWebUI {
defer stopCallbackForwarder(codexCallbackPort)
defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
}
// Wait for callback file
@@ -1233,6 +1263,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute)
var code string
for {
if !IsOAuthSessionPending(state, "codex") {
return
}
if time.Now().After(deadline) {
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
log.Error(codex.GetUserFriendlyMessage(authErr))
@@ -1348,6 +1381,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
}
fmt.Println("You can now use Codex services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("codex")
}()
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
@@ -1393,6 +1427,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
RegisterOAuthSession(state, "antigravity")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
if errTarget != nil {
@@ -1400,7 +1435,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
var errStart error
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
@@ -1409,13 +1445,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
go func() {
if isWebUI {
defer stopCallbackForwarder(antigravityCallbackPort)
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var authCode string
for {
if !IsOAuthSessionPending(state, "antigravity") {
return
}
if time.Now().After(deadline) {
log.Error("oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out")
@@ -1578,6 +1617,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
}
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("antigravity")
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if projectID != "" {
fmt.Printf("Using GCP project: %s\n", projectID)
@@ -1655,6 +1695,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
RegisterOAuthSession(state, "iflow")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
if errTarget != nil {
@@ -1662,7 +1703,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
var errStart error
if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start iflow callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
return
@@ -1671,7 +1713,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
go func() {
if isWebUI {
defer stopCallbackForwarder(iflowauth.CallbackPort)
defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
}
fmt.Println("Waiting for authentication...")
@@ -1679,6 +1721,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute)
var resultMap map[string]string
for {
if !IsOAuthSessionPending(state, "iflow") {
return
}
if time.Now().After(deadline) {
SetOAuthSessionError(state, "Authentication failed")
fmt.Println("Authentication failed: timeout waiting for callback")
@@ -1745,6 +1790,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
}
fmt.Println("You can now use iFlow services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("iflow")
}()
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})

View File

@@ -111,6 +111,27 @@ func (s *oauthSessionStore) Complete(state string) {
delete(s.sessions, state)
}
func (s *oauthSessionStore) CompleteProvider(provider string) int {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" {
return 0
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.purgeExpiredLocked(now)
removed := 0
for state, session := range s.sessions {
if strings.EqualFold(session.Provider, provider) {
delete(s.sessions, state)
removed++
}
}
return removed
}
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
state = strings.TrimSpace(state)
now := time.Now()
@@ -153,6 +174,10 @@ func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state,
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
func CompleteOAuthSessionsByProvider(provider string) int {
return oauthSessions.CompleteProvider(provider)
}
func GetOAuthSession(state string) (provider string, status string, ok bool) {
session, ok := oauthSessions.Get(state)
if !ok {