mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(oauth): add remote OAuth callback support with session management
Introduce a centralized OAuth session store with TTL-based expiration to replace the previous simple map-based status tracking. Add a new /api/oauth/callback endpoint that allows remote clients to relay OAuth callback data back to the CLI proxy, enabling OAuth flows when the callback cannot reach the local machine directly. - Add oauth_sessions.go with thread-safe session store and validation - Add oauth_callback.go with POST handler for remote callback relay - Refactor auth_files.go to use new session management APIs - Register new callback route in server.go
This commit is contained in:
@@ -36,10 +36,6 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
var (
|
||||
oauthStatus = make(map[string]string)
|
||||
)
|
||||
|
||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||
|
||||
const (
|
||||
@@ -786,6 +782,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "anthropic")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
|
||||
@@ -812,7 +810,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
data, errRead := os.ReadFile(path)
|
||||
@@ -837,13 +835,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errStr := resultMap["error"]; errStr != "" {
|
||||
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad request"
|
||||
SetOAuthSessionError(state, "Bad request")
|
||||
return
|
||||
}
|
||||
if resultMap["state"] != state {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -876,7 +874,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errDo != nil {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -887,7 +885,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
var tResp struct {
|
||||
@@ -900,7 +898,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
bundle := &claude.ClaudeAuthBundle{
|
||||
@@ -925,7 +923,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -934,10 +932,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Claude services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -968,6 +965,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
RegisterOAuthSession(state, "gemini")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/google/callback")
|
||||
@@ -996,7 +995,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -1005,13 +1004,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
authCode = m["code"]
|
||||
if authCode == "" {
|
||||
log.Errorf("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -1023,7 +1022,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
token, err := conf.Exchange(ctx, authCode)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to exchange token: %v", err)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1034,7 +1033,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Could not get user info: %v", errNewRequest)
|
||||
oauthStatus[state] = "Could not get user info"
|
||||
SetOAuthSessionError(state, "Could not get user info")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -1043,7 +1042,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
resp, errDo := authHTTPClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to execute request"
|
||||
SetOAuthSessionError(state, "Failed to execute request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1055,7 +1054,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1064,7 +1063,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
fmt.Printf("Authenticated user email: %s\n", email)
|
||||
} else {
|
||||
fmt.Println("Failed to get user email from token")
|
||||
oauthStatus[state] = "Failed to get user email from token"
|
||||
}
|
||||
|
||||
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
||||
@@ -1072,7 +1070,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
jsonData, _ := json.Marshal(token)
|
||||
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
||||
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
||||
oauthStatus[state] = "Failed to unmarshal token"
|
||||
SetOAuthSessionError(state, "Failed to unmarshal token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1098,7 +1096,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
||||
if errGetClient != nil {
|
||||
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
||||
oauthStatus[state] = "Failed to get authenticated client"
|
||||
SetOAuthSessionError(state, "Failed to get authenticated client")
|
||||
return
|
||||
}
|
||||
fmt.Println("Authentication successful.")
|
||||
@@ -1108,12 +1106,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
@@ -1121,26 +1119,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Onboarding did not return a project ID")
|
||||
oauthStatus[state] = "Failed to resolve project ID"
|
||||
SetOAuthSessionError(state, "Failed to resolve project ID")
|
||||
return
|
||||
}
|
||||
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
oauthStatus[state] = "Cloud AI API not enabled"
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1163,15 +1161,14 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1207,6 +1204,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "codex")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
|
||||
@@ -1235,7 +1234,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
if time.Now().After(deadline) {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -1245,12 +1244,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad Request"
|
||||
SetOAuthSessionError(state, "Bad Request")
|
||||
return
|
||||
}
|
||||
if m["state"] != state {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
return
|
||||
}
|
||||
@@ -1281,14 +1280,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
@@ -1299,7 +1298,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
return
|
||||
}
|
||||
@@ -1337,7 +1336,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1346,10 +1345,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Codex services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1390,6 +1388,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
params.Set("state", state)
|
||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
||||
|
||||
RegisterOAuthSession(state, "antigravity")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
|
||||
@@ -1416,7 +1416,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||
@@ -1425,18 +1425,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||
log.Errorf("Authentication failed: state mismatch")
|
||||
oauthStatus[state] = "Authentication failed: state mismatch"
|
||||
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
authCode = strings.TrimSpace(payload["code"])
|
||||
if authCode == "" {
|
||||
log.Error("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -1455,7 +1455,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
||||
oauthStatus[state] = "Failed to build token request"
|
||||
SetOAuthSessionError(state, "Failed to build token request")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
@@ -1463,7 +1463,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute token request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1475,7 +1475,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1487,7 +1487,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
}
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1496,7 +1496,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errInfoReq != nil {
|
||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
||||
oauthStatus[state] = "Failed to build user info request"
|
||||
SetOAuthSessionError(state, "Failed to build user info request")
|
||||
return
|
||||
}
|
||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||
@@ -1504,7 +1504,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoResp, errInfo := httpClient.Do(infoReq)
|
||||
if errInfo != nil {
|
||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
||||
oauthStatus[state] = "Failed to execute user info request"
|
||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1523,7 +1523,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
} else {
|
||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1571,11 +1571,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if projectID != "" {
|
||||
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||
@@ -1583,7 +1583,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
fmt.Println("You can now use Antigravity services through this CLI")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1605,11 +1604,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
|
||||
RegisterOAuthSession(state, "qwen")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||
if errPollForToken != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||
return
|
||||
}
|
||||
@@ -1628,16 +1629,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Qwen services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1650,6 +1650,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
||||
|
||||
RegisterOAuthSession(state, "iflow")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
|
||||
@@ -1676,7 +1678,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
var resultMap map[string]string
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||
return
|
||||
}
|
||||
@@ -1689,26 +1691,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||
return
|
||||
}
|
||||
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(resultMap["code"])
|
||||
if code == "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: code missing")
|
||||
return
|
||||
}
|
||||
|
||||
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||
if errExchange != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||
return
|
||||
}
|
||||
@@ -1730,7 +1732,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1740,10 +1742,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use iFlow services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -2179,16 +2180,24 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
}
|
||||
|
||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if err, ok := oauthStatus[state]; ok {
|
||||
if err != "" {
|
||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "wait"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
||||
return
|
||||
}
|
||||
delete(oauthStatus, state)
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
|
||||
_, status, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
if status != "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
100
internal/api/handlers/management/oauth_callback.go
Normal file
100
internal/api/handlers/management/oauth_callback.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type oauthCallbackRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (h *Handler) PostOAuthCallback(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req oauthCallbackRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
|
||||
return
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(req.State)
|
||||
code := strings.TrimSpace(req.Code)
|
||||
errMsg := strings.TrimSpace(req.Error)
|
||||
|
||||
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
|
||||
u, errParse := url.Parse(rawRedirect)
|
||||
if errParse != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
}
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error"))
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
||||
return
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
if code == "" && errMsg == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
|
||||
return
|
||||
}
|
||||
if sessionStatus != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(sessionProvider, canonicalProvider) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
|
||||
if errors.Is(errWrite, errOAuthSessionNotPending) {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
258
internal/api/handlers/management/oauth_sessions.go
Normal file
258
internal/api/handlers/management/oauth_sessions.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthSessionTTL = 10 * time.Minute
|
||||
maxOAuthStateLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidOAuthState = errors.New("invalid oauth state")
|
||||
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
|
||||
errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
||||
)
|
||||
|
||||
type oauthSession struct {
|
||||
Provider string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type oauthSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
sessions map[string]oauthSession
|
||||
}
|
||||
|
||||
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
||||
if ttl <= 0 {
|
||||
ttl = oauthSessionTTL
|
||||
}
|
||||
return &oauthSessionStore{
|
||||
ttl: ttl,
|
||||
sessions: make(map[string]oauthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
||||
for state, session := range s.sessions {
|
||||
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Register(state, provider string) {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if state == "" || provider == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
s.sessions[state] = oauthSession{
|
||||
Provider: provider,
|
||||
Status: "",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) SetError(state, message string) {
|
||||
state = strings.TrimSpace(state)
|
||||
message = strings.TrimSpace(message)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
if message == "" {
|
||||
message = "Authentication failed"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
session.Status = message
|
||||
session.ExpiresAt = now.Add(s.ttl)
|
||||
s.sessions[state] = session
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Complete(state string) {
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
||||
state = strings.TrimSpace(state)
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if session.Status != "" {
|
||||
return false
|
||||
}
|
||||
if provider == "" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(session.Provider, provider)
|
||||
}
|
||||
|
||||
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
||||
|
||||
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
|
||||
|
||||
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
|
||||
|
||||
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
|
||||
|
||||
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
||||
session, ok := oauthSessions.Get(state)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return session.Provider, session.Status, true
|
||||
}
|
||||
|
||||
func IsOAuthSessionPending(state, provider string) bool {
|
||||
return oauthSessions.IsPending(state, provider)
|
||||
}
|
||||
|
||||
func ValidateOAuthState(state string) error {
|
||||
trimmed := strings.TrimSpace(state)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
||||
}
|
||||
if len(trimmed) > maxOAuthStateLength {
|
||||
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
||||
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "..") {
|
||||
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
||||
}
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "anthropic", "claude":
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
return "iflow", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "qwen":
|
||||
return "qwen", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
}
|
||||
|
||||
type oauthCallbackFilePayload struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
if strings.TrimSpace(authDir) == "" {
|
||||
return "", fmt.Errorf("auth dir is empty")
|
||||
}
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
payload := oauthCallbackFilePayload{
|
||||
Code: strings.TrimSpace(code),
|
||||
State: strings.TrimSpace(state),
|
||||
Error: strings.TrimSpace(errorMessage),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
||||
return "", fmt.Errorf("write oauth callback file: %w", err)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !IsOAuthSessionPending(state, canonicalProvider) {
|
||||
return "", errOAuthSessionNotPending
|
||||
}
|
||||
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
||||
}
|
||||
Reference in New Issue
Block a user