From cfa8ddb59f3dbddd7288c565163b5cd4b7ae25ad Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Fri, 19 Dec 2025 00:38:29 +0800 Subject: [PATCH 1/2] feat(oauth): add remote OAuth callback support with session management Introduce a centralized OAuth session store with TTL-based expiration to replace the previous simple map-based status tracking. Add a new /api/oauth/callback endpoint that allows remote clients to relay OAuth callback data back to the CLI proxy, enabling OAuth flows when the callback cannot reach the local machine directly. - Add oauth_sessions.go with thread-safe session store and validation - Add oauth_callback.go with POST handler for remote callback relay - Refactor auth_files.go to use new session management APIs - Register new callback route in server.go --- .../api/handlers/management/auth_files.go | 165 +++++------ .../api/handlers/management/oauth_callback.go | 100 +++++++ .../api/handlers/management/oauth_sessions.go | 258 ++++++++++++++++++ internal/api/server.go | 32 ++- 4 files changed, 466 insertions(+), 89 deletions(-) create mode 100644 internal/api/handlers/management/oauth_callback.go create mode 100644 internal/api/handlers/management/oauth_sessions.go diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index e29dc29f..433bae92 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -36,10 +36,6 @@ import ( "golang.org/x/oauth2/google" ) -var ( - oauthStatus = make(map[string]string) -) - var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( @@ -786,6 +782,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { return } + RegisterOAuthSession(state, "anthropic") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") @@ -812,7 +810,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { deadline := time.Now().Add(timeout) for { if time.Now().After(deadline) { - oauthStatus[state] = "Timeout waiting for OAuth callback" + SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") } data, errRead := os.ReadFile(path) @@ -837,13 +835,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errStr := resultMap["error"]; errStr != "" { oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(claude.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad request" + SetOAuthSessionError(state, "Bad request") return } if resultMap["state"] != state { authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"])) log.Error(claude.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "State code error" + SetOAuthSessionError(state, "State code error") return } @@ -876,7 +874,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errDo != nil { authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") return } defer func() { @@ -887,7 +885,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) return } var tResp struct { @@ -900,7 +898,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } if errU := json.Unmarshal(respBody, &tResp); errU != nil { log.Errorf("failed to parse token response: %v", errU) - oauthStatus[state] = "Failed to parse token response" + SetOAuthSessionError(state, "Failed to parse token response") return } bundle := &claude.ClaudeAuthBundle{ @@ -925,7 +923,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + SetOAuthSessionError(state, "Failed to save authentication tokens") return } @@ -934,10 +932,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Claude services through this CLI") - delete(oauthStatus, state) + CompleteOAuthSession(state) }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -968,6 +965,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) + RegisterOAuthSession(state, "gemini") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/google/callback") @@ -996,7 +995,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + SetOAuthSessionError(state, "OAuth flow timed out") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1005,13 +1004,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := m["error"]; errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") return } authCode = m["code"] if authCode == "" { log.Errorf("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + SetOAuthSessionError(state, "Authentication failed: code not found") return } break @@ -1023,7 +1022,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { token, err := conf.Exchange(ctx, authCode) if err != nil { log.Errorf("Failed to exchange token: %v", err) - oauthStatus[state] = "Failed to exchange token" + SetOAuthSessionError(state, "Failed to exchange token") return } @@ -1034,7 +1033,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errNewRequest != nil { log.Errorf("Could not get user info: %v", errNewRequest) - oauthStatus[state] = "Could not get user info" + SetOAuthSessionError(state, "Could not get user info") return } req.Header.Set("Content-Type", "application/json") @@ -1043,7 +1042,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { resp, errDo := authHTTPClient.Do(req) if errDo != nil { log.Errorf("Failed to execute request: %v", errDo) - oauthStatus[state] = "Failed to execute request" + SetOAuthSessionError(state, "Failed to execute request") return } defer func() { @@ -1055,7 +1054,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { bodyBytes, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) return } @@ -1064,7 +1063,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Printf("Authenticated user email: %s\n", email) } else { fmt.Println("Failed to get user email from token") - oauthStatus[state] = "Failed to get user email from token" } // Marshal/unmarshal oauth2.Token to generic map and enrich fields @@ -1072,7 +1070,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { jsonData, _ := json.Marshal(token) if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { log.Errorf("Failed to unmarshal token: %v", errUnmarshal) - oauthStatus[state] = "Failed to unmarshal token" + SetOAuthSessionError(state, "Failed to unmarshal token") return } @@ -1098,7 +1096,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) if errGetClient != nil { log.Errorf("failed to get authenticated client: %v", errGetClient) - oauthStatus[state] = "Failed to get authenticated client" + SetOAuthSessionError(state, "Failed to get authenticated client") return } fmt.Println("Authentication successful.") @@ -1108,12 +1106,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) if errAll != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") return } if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) - oauthStatus[state] = "Failed to verify Cloud AI API status" + SetOAuthSessionError(state, "Failed to verify Cloud AI API status") return } ts.ProjectID = strings.Join(projects, ",") @@ -1121,26 +1119,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } else { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") return } if strings.TrimSpace(ts.ProjectID) == "" { log.Error("Onboarding did not return a project ID") - oauthStatus[state] = "Failed to resolve project ID" + SetOAuthSessionError(state, "Failed to resolve project ID") return } isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - oauthStatus[state] = "Failed to verify Cloud AI API status" + SetOAuthSessionError(state, "Failed to verify Cloud AI API status") return } ts.Checked = isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the selected project") - oauthStatus[state] = "Cloud AI API not enabled" + SetOAuthSessionError(state, "Cloud AI API not enabled") return } } @@ -1163,15 +1161,14 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + SetOAuthSessionError(state, "Failed to save token to file") return } - delete(oauthStatus, state) + CompleteOAuthSession(state) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1207,6 +1204,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { return } + RegisterOAuthSession(state, "codex") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/codex/callback") @@ -1235,7 +1234,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "Timeout waiting for OAuth callback" + SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1245,12 +1244,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if errStr := m["error"]; errStr != "" { oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(codex.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad Request" + SetOAuthSessionError(state, "Bad Request") return } if m["state"] != state { authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) - oauthStatus[state] = "State code error" + SetOAuthSessionError(state, "State code error") log.Error(codex.GetUserFriendlyMessage(authErr)) return } @@ -1281,14 +1280,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) return } @@ -1299,7 +1298,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { ExpiresIn int `json:"expires_in"` } if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { - oauthStatus[state] = "Failed to parse token response" + SetOAuthSessionError(state, "Failed to parse token response") log.Errorf("failed to parse token response: %v", errU) return } @@ -1337,7 +1336,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" + SetOAuthSessionError(state, "Failed to save authentication tokens") log.Errorf("Failed to save authentication tokens: %v", errSave) return } @@ -1346,10 +1345,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Codex services through this CLI") - delete(oauthStatus, state) + CompleteOAuthSession(state) }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1390,6 +1388,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { params.Set("state", state) authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() + RegisterOAuthSession(state, "antigravity") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") @@ -1416,7 +1416,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + SetOAuthSessionError(state, "OAuth flow timed out") return } if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { @@ -1425,18 +1425,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := strings.TrimSpace(payload["error"]); errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") return } if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { log.Errorf("Authentication failed: state mismatch") - oauthStatus[state] = "Authentication failed: state mismatch" + SetOAuthSessionError(state, "Authentication failed: state mismatch") return } authCode = strings.TrimSpace(payload["code"]) if authCode == "" { log.Error("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + SetOAuthSessionError(state, "Authentication failed: code not found") return } break @@ -1455,7 +1455,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) if errNewRequest != nil { log.Errorf("Failed to build token request: %v", errNewRequest) - oauthStatus[state] = "Failed to build token request" + SetOAuthSessionError(state, "Failed to build token request") return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -1463,7 +1463,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { log.Errorf("Failed to execute token request: %v", errDo) - oauthStatus[state] = "Failed to exchange token" + SetOAuthSessionError(state, "Failed to exchange token") return } defer func() { @@ -1475,7 +1475,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { bodyBytes, _ := io.ReadAll(resp.Body) log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) return } @@ -1487,7 +1487,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { log.Errorf("Failed to parse token response: %v", errDecode) - oauthStatus[state] = "Failed to parse token response" + SetOAuthSessionError(state, "Failed to parse token response") return } @@ -1496,7 +1496,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errInfoReq != nil { log.Errorf("Failed to build user info request: %v", errInfoReq) - oauthStatus[state] = "Failed to build user info request" + SetOAuthSessionError(state, "Failed to build user info request") return } infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) @@ -1504,7 +1504,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoResp, errInfo := httpClient.Do(infoReq) if errInfo != nil { log.Errorf("Failed to execute user info request: %v", errInfo) - oauthStatus[state] = "Failed to execute user info request" + SetOAuthSessionError(state, "Failed to execute user info request") return } defer func() { @@ -1523,7 +1523,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } else { bodyBytes, _ := io.ReadAll(infoResp.Body) log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) return } } @@ -1571,11 +1571,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + SetOAuthSessionError(state, "Failed to save token to file") return } - delete(oauthStatus, state) + CompleteOAuthSession(state) fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1583,7 +1583,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { fmt.Println("You can now use Antigravity services through this CLI") }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1605,11 +1604,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { } authURL := deviceFlow.VerificationURIComplete + RegisterOAuthSession(state, "qwen") + go func() { fmt.Println("Waiting for authentication...") tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) if errPollForToken != nil { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errPollForToken) return } @@ -1628,16 +1629,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + SetOAuthSessionError(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) fmt.Println("You can now use Qwen services through this CLI") - delete(oauthStatus, state) + CompleteOAuthSession(state) }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1650,6 +1650,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { authSvc := iflowauth.NewIFlowAuth(h.cfg) authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort) + RegisterOAuthSession(state, "iflow") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/iflow/callback") @@ -1676,7 +1678,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { var resultMap map[string]string for { if time.Now().After(deadline) { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") return } @@ -1689,26 +1691,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %s\n", errStr) return } if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: state mismatch") return } code := strings.TrimSpace(resultMap["code"]) if code == "" { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: code missing") return } tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) if errExchange != nil { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errExchange) return } @@ -1730,7 +1732,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" + SetOAuthSessionError(state, "Failed to save authentication tokens") log.Errorf("Failed to save authentication tokens: %v", errSave) return } @@ -1740,10 +1742,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use iFlow services through this CLI") - delete(oauthStatus, state) + CompleteOAuthSession(state) }() - oauthStatus[state] = "" c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -2179,16 +2180,24 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec } func (h *Handler) GetAuthStatus(c *gin.Context) { - state := c.Query("state") - if err, ok := oauthStatus[state]; ok { - if err != "" { - c.JSON(200, gin.H{"status": "error", "error": err}) - } else { - c.JSON(200, gin.H{"status": "wait"}) - return - } - } else { - c.JSON(200, gin.H{"status": "ok"}) + state := strings.TrimSpace(c.Query("state")) + if state == "" { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"}) + return } - delete(oauthStatus, state) + if err := ValidateOAuthState(state); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"}) + return + } + + _, status, ok := GetOAuthSession(state) + if !ok { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + return + } + if status != "" { + c.JSON(http.StatusOK, gin.H{"status": "error", "error": status}) + return + } + c.JSON(http.StatusOK, gin.H{"status": "wait"}) } diff --git a/internal/api/handlers/management/oauth_callback.go b/internal/api/handlers/management/oauth_callback.go new file mode 100644 index 00000000..c69a332e --- /dev/null +++ b/internal/api/handlers/management/oauth_callback.go @@ -0,0 +1,100 @@ +package management + +import ( + "errors" + "net/http" + "net/url" + "strings" + + "github.com/gin-gonic/gin" +) + +type oauthCallbackRequest struct { + Provider string `json:"provider"` + RedirectURL string `json:"redirect_url"` + Code string `json:"code"` + State string `json:"state"` + Error string `json:"error"` +} + +func (h *Handler) PostOAuthCallback(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"}) + return + } + + var req oauthCallbackRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"}) + return + } + + canonicalProvider, err := NormalizeOAuthProvider(req.Provider) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"}) + return + } + + state := strings.TrimSpace(req.State) + code := strings.TrimSpace(req.Code) + errMsg := strings.TrimSpace(req.Error) + + if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" { + u, errParse := url.Parse(rawRedirect) + if errParse != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"}) + return + } + q := u.Query() + if state == "" { + state = strings.TrimSpace(q.Get("state")) + } + if code == "" { + code = strings.TrimSpace(q.Get("code")) + } + if errMsg == "" { + errMsg = strings.TrimSpace(q.Get("error")) + if errMsg == "" { + errMsg = strings.TrimSpace(q.Get("error_description")) + } + } + } + + if state == "" { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"}) + return + } + if err := ValidateOAuthState(state); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"}) + return + } + if code == "" && errMsg == "" { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"}) + return + } + + sessionProvider, sessionStatus, ok := GetOAuthSession(state) + if !ok { + c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"}) + return + } + if sessionStatus != "" { + c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"}) + return + } + if !strings.EqualFold(sessionProvider, canonicalProvider) { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"}) + return + } + + if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil { + if errors.Is(errWrite, errOAuthSessionNotPending) { + c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "ok"}) +} diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go new file mode 100644 index 00000000..f23b608c --- /dev/null +++ b/internal/api/handlers/management/oauth_sessions.go @@ -0,0 +1,258 @@ +package management + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +const ( + oauthSessionTTL = 10 * time.Minute + maxOAuthStateLength = 128 +) + +var ( + errInvalidOAuthState = errors.New("invalid oauth state") + errUnsupportedOAuthFlow = errors.New("unsupported oauth provider") + errOAuthSessionNotPending = errors.New("oauth session is not pending") +) + +type oauthSession struct { + Provider string + Status string + CreatedAt time.Time + ExpiresAt time.Time +} + +type oauthSessionStore struct { + mu sync.RWMutex + ttl time.Duration + sessions map[string]oauthSession +} + +func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore { + if ttl <= 0 { + ttl = oauthSessionTTL + } + return &oauthSessionStore{ + ttl: ttl, + sessions: make(map[string]oauthSession), + } +} + +func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) { + for state, session := range s.sessions { + if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { + delete(s.sessions, state) + } + } +} + +func (s *oauthSessionStore) Register(state, provider string) { + state = strings.TrimSpace(state) + provider = strings.ToLower(strings.TrimSpace(provider)) + if state == "" || provider == "" { + return + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + s.sessions[state] = oauthSession{ + Provider: provider, + Status: "", + CreatedAt: now, + ExpiresAt: now.Add(s.ttl), + } +} + +func (s *oauthSessionStore) SetError(state, message string) { + state = strings.TrimSpace(state) + message = strings.TrimSpace(message) + if state == "" { + return + } + if message == "" { + message = "Authentication failed" + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + if !ok { + return + } + session.Status = message + session.ExpiresAt = now.Add(s.ttl) + s.sessions[state] = session +} + +func (s *oauthSessionStore) Complete(state string) { + state = strings.TrimSpace(state) + if state == "" { + return + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + delete(s.sessions, state) +} + +func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { + state = strings.TrimSpace(state) + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + return session, ok +} + +func (s *oauthSessionStore) IsPending(state, provider string) bool { + state = strings.TrimSpace(state) + provider = strings.ToLower(strings.TrimSpace(provider)) + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + if !ok { + return false + } + if session.Status != "" { + return false + } + if provider == "" { + return true + } + return strings.EqualFold(session.Provider, provider) +} + +var oauthSessions = newOAuthSessionStore(oauthSessionTTL) + +func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) } + +func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) } + +func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } + +func GetOAuthSession(state string) (provider string, status string, ok bool) { + session, ok := oauthSessions.Get(state) + if !ok { + return "", "", false + } + return session.Provider, session.Status, true +} + +func IsOAuthSessionPending(state, provider string) bool { + return oauthSessions.IsPending(state, provider) +} + +func ValidateOAuthState(state string) error { + trimmed := strings.TrimSpace(state) + if trimmed == "" { + return fmt.Errorf("%w: empty", errInvalidOAuthState) + } + if len(trimmed) > maxOAuthStateLength { + return fmt.Errorf("%w: too long", errInvalidOAuthState) + } + if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") { + return fmt.Errorf("%w: contains path separator", errInvalidOAuthState) + } + if strings.Contains(trimmed, "..") { + return fmt.Errorf("%w: contains '..'", errInvalidOAuthState) + } + for _, r := range trimmed { + switch { + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9': + case r == '-' || r == '_' || r == '.': + default: + return fmt.Errorf("%w: invalid character", errInvalidOAuthState) + } + } + return nil +} + +func NormalizeOAuthProvider(provider string) (string, error) { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "anthropic", "claude": + return "anthropic", nil + case "codex", "openai": + return "codex", nil + case "gemini", "google": + return "gemini", nil + case "iflow", "i-flow": + return "iflow", nil + case "antigravity", "anti-gravity": + return "antigravity", nil + case "qwen": + return "qwen", nil + default: + return "", errUnsupportedOAuthFlow + } +} + +type oauthCallbackFilePayload struct { + Code string `json:"code"` + State string `json:"state"` + Error string `json:"error"` +} + +func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) { + if strings.TrimSpace(authDir) == "" { + return "", fmt.Errorf("auth dir is empty") + } + canonicalProvider, err := NormalizeOAuthProvider(provider) + if err != nil { + return "", err + } + if err := ValidateOAuthState(state); err != nil { + return "", err + } + + fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state) + filePath := filepath.Join(authDir, fileName) + payload := oauthCallbackFilePayload{ + Code: strings.TrimSpace(code), + State: strings.TrimSpace(state), + Error: strings.TrimSpace(errorMessage), + } + data, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("marshal oauth callback payload: %w", err) + } + if err := os.WriteFile(filePath, data, 0o600); err != nil { + return "", fmt.Errorf("write oauth callback file: %w", err) + } + return filePath, nil +} + +func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) { + canonicalProvider, err := NormalizeOAuthProvider(provider) + if err != nil { + return "", err + } + if !IsOAuthSessionPending(state, canonicalProvider) { + return "", errOAuthSessionNotPending + } + return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage) +} diff --git a/internal/api/server.go b/internal/api/server.go index e6d03bc3..d6fe91bf 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -354,10 +354,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") - // Persist to a temporary file keyed by state + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -367,9 +368,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -379,9 +382,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -391,9 +396,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -403,9 +410,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -577,6 +586,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) + mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } } From 1b358c931cc84c5a8786fd2f6816ca0181637031 Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Fri, 19 Dec 2025 12:15:22 +0800 Subject: [PATCH 2/2] fix: restore get-auth-status ok fallback and document it --- internal/api/handlers/management/auth_files.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 433bae92..bf5a5b9c 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -2182,7 +2182,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec func (h *Handler) GetAuthStatus(c *gin.Context) { state := strings.TrimSpace(c.Query("state")) if state == "" { - c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"}) + c.JSON(http.StatusOK, gin.H{"status": "ok"}) return } if err := ValidateOAuthState(state); err != nil {