mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
Introduce a centralized OAuth session store with TTL-based expiration to replace the previous simple map-based status tracking. Add a new /api/oauth/callback endpoint that allows remote clients to relay OAuth callback data back to the CLI proxy, enabling OAuth flows when the callback cannot reach the local machine directly.
- Add oauth_sessions.go with thread-safe session store and validation
- Add oauth_callback.go with POST handler for remote callback relay
- Refactor auth_files.go to use new session management APIs
- Register new callback route in server.go
259 lines
6.0 KiB
Go
259 lines
6.0 KiB
Go
package management
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// oauthSessionTTL is the default lifetime of an OAuth session. It is the
// fallback used by newOAuthSessionStore when a non-positive TTL is supplied
// and the TTL of the package-level store.
const oauthSessionTTL = 10 * time.Minute

// maxOAuthStateLength is the largest length, in bytes, of a trimmed state
// value that ValidateOAuthState accepts.
const maxOAuthStateLength = 128
|
|
|
|
// errInvalidOAuthState is the sentinel wrapped by every validation failure
// reported by ValidateOAuthState.
var errInvalidOAuthState = errors.New("invalid oauth state")

// errUnsupportedOAuthFlow is returned by NormalizeOAuthProvider when the
// provider name is not one of the recognized aliases.
var errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")

// errOAuthSessionNotPending is returned by
// WriteOAuthCallbackFileForPendingSession when no pending session matches the
// given state and provider.
var errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
|
|
|
// oauthSession is the record the session store keeps for one OAuth flow,
// keyed in the store by its state value.
type oauthSession struct {
	// Provider is the provider name as stored by Register (trimmed and
	// lower-cased).
	Provider string
	// Status is empty while the session is pending; SetError stores the
	// failure message here.
	Status string
	// CreatedAt is the registration time; ExpiresAt is the instant after
	// which the entry is eligible for purging.
	CreatedAt, ExpiresAt time.Time
}
|
|
|
|
type oauthSessionStore struct {
|
|
mu sync.RWMutex
|
|
ttl time.Duration
|
|
sessions map[string]oauthSession
|
|
}
|
|
|
|
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
|
if ttl <= 0 {
|
|
ttl = oauthSessionTTL
|
|
}
|
|
return &oauthSessionStore{
|
|
ttl: ttl,
|
|
sessions: make(map[string]oauthSession),
|
|
}
|
|
}
|
|
|
|
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
|
for state, session := range s.sessions {
|
|
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
|
delete(s.sessions, state)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *oauthSessionStore) Register(state, provider string) {
|
|
state = strings.TrimSpace(state)
|
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
|
if state == "" || provider == "" {
|
|
return
|
|
}
|
|
now := time.Now()
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.purgeExpiredLocked(now)
|
|
s.sessions[state] = oauthSession{
|
|
Provider: provider,
|
|
Status: "",
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(s.ttl),
|
|
}
|
|
}
|
|
|
|
func (s *oauthSessionStore) SetError(state, message string) {
|
|
state = strings.TrimSpace(state)
|
|
message = strings.TrimSpace(message)
|
|
if state == "" {
|
|
return
|
|
}
|
|
if message == "" {
|
|
message = "Authentication failed"
|
|
}
|
|
now := time.Now()
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.purgeExpiredLocked(now)
|
|
session, ok := s.sessions[state]
|
|
if !ok {
|
|
return
|
|
}
|
|
session.Status = message
|
|
session.ExpiresAt = now.Add(s.ttl)
|
|
s.sessions[state] = session
|
|
}
|
|
|
|
func (s *oauthSessionStore) Complete(state string) {
|
|
state = strings.TrimSpace(state)
|
|
if state == "" {
|
|
return
|
|
}
|
|
now := time.Now()
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.purgeExpiredLocked(now)
|
|
delete(s.sessions, state)
|
|
}
|
|
|
|
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
|
state = strings.TrimSpace(state)
|
|
now := time.Now()
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.purgeExpiredLocked(now)
|
|
session, ok := s.sessions[state]
|
|
return session, ok
|
|
}
|
|
|
|
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
|
state = strings.TrimSpace(state)
|
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
|
now := time.Now()
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.purgeExpiredLocked(now)
|
|
session, ok := s.sessions[state]
|
|
if !ok {
|
|
return false
|
|
}
|
|
if session.Status != "" {
|
|
return false
|
|
}
|
|
if provider == "" {
|
|
return true
|
|
}
|
|
return strings.EqualFold(session.Provider, provider)
|
|
}
|
|
|
|
// oauthSessions is the package-level session store used by the exported
// OAuth session helper functions; its entries use the default TTL.
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
|
|
|
// RegisterOAuthSession records a pending OAuth session for state and provider
// in the package-level session store.
func RegisterOAuthSession(state, provider string) {
	oauthSessions.Register(state, provider)
}
|
|
|
|
// SetOAuthSessionError stores message as the error status of the session
// identified by state in the package-level session store.
func SetOAuthSessionError(state, message string) {
	oauthSessions.SetError(state, message)
}
|
|
|
|
// CompleteOAuthSession removes the session identified by state from the
// package-level session store.
func CompleteOAuthSession(state string) {
	oauthSessions.Complete(state)
}
|
|
|
|
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
|
session, ok := oauthSessions.Get(state)
|
|
if !ok {
|
|
return "", "", false
|
|
}
|
|
return session.Provider, session.Status, true
|
|
}
|
|
|
|
func IsOAuthSessionPending(state, provider string) bool {
|
|
return oauthSessions.IsPending(state, provider)
|
|
}
|
|
|
|
func ValidateOAuthState(state string) error {
|
|
trimmed := strings.TrimSpace(state)
|
|
if trimmed == "" {
|
|
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
|
}
|
|
if len(trimmed) > maxOAuthStateLength {
|
|
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
|
}
|
|
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
|
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
|
}
|
|
if strings.Contains(trimmed, "..") {
|
|
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
|
}
|
|
for _, r := range trimmed {
|
|
switch {
|
|
case r >= 'a' && r <= 'z':
|
|
case r >= 'A' && r <= 'Z':
|
|
case r >= '0' && r <= '9':
|
|
case r == '-' || r == '_' || r == '.':
|
|
default:
|
|
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func NormalizeOAuthProvider(provider string) (string, error) {
|
|
switch strings.ToLower(strings.TrimSpace(provider)) {
|
|
case "anthropic", "claude":
|
|
return "anthropic", nil
|
|
case "codex", "openai":
|
|
return "codex", nil
|
|
case "gemini", "google":
|
|
return "gemini", nil
|
|
case "iflow", "i-flow":
|
|
return "iflow", nil
|
|
case "antigravity", "anti-gravity":
|
|
return "antigravity", nil
|
|
case "qwen":
|
|
return "qwen", nil
|
|
default:
|
|
return "", errUnsupportedOAuthFlow
|
|
}
|
|
}
|
|
|
|
// oauthCallbackFilePayload is the JSON document that WriteOAuthCallbackFile
// stores in a callback file.
type oauthCallbackFilePayload struct {
	// Code is the trimmed code value passed to WriteOAuthCallbackFile.
	Code string `json:"code"`
	// State is the trimmed OAuth state the callback belongs to.
	State string `json:"state"`
	// Error is the trimmed error message passed to WriteOAuthCallbackFile.
	Error string `json:"error"`
}
|
|
|
|
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
|
if strings.TrimSpace(authDir) == "" {
|
|
return "", fmt.Errorf("auth dir is empty")
|
|
}
|
|
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if err := ValidateOAuthState(state); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
|
filePath := filepath.Join(authDir, fileName)
|
|
payload := oauthCallbackFilePayload{
|
|
Code: strings.TrimSpace(code),
|
|
State: strings.TrimSpace(state),
|
|
Error: strings.TrimSpace(errorMessage),
|
|
}
|
|
data, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
|
}
|
|
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
|
return "", fmt.Errorf("write oauth callback file: %w", err)
|
|
}
|
|
return filePath, nil
|
|
}
|
|
|
|
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
|
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if !IsOAuthSessionPending(state, canonicalProvider) {
|
|
return "", errOAuthSessionNotPending
|
|
}
|
|
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
|
}
|