package management import ( "encoding/json" "errors" "fmt" "os" "path/filepath" "strings" "sync" "time" ) const ( oauthSessionTTL = 10 * time.Minute maxOAuthStateLength = 128 ) var ( errInvalidOAuthState = errors.New("invalid oauth state") errUnsupportedOAuthFlow = errors.New("unsupported oauth provider") errOAuthSessionNotPending = errors.New("oauth session is not pending") ) type oauthSession struct { Provider string Status string CreatedAt time.Time ExpiresAt time.Time } type oauthSessionStore struct { mu sync.RWMutex ttl time.Duration sessions map[string]oauthSession } func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore { if ttl <= 0 { ttl = oauthSessionTTL } return &oauthSessionStore{ ttl: ttl, sessions: make(map[string]oauthSession), } } func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) { for state, session := range s.sessions { if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { delete(s.sessions, state) } } } func (s *oauthSessionStore) Register(state, provider string) { state = strings.TrimSpace(state) provider = strings.ToLower(strings.TrimSpace(provider)) if state == "" || provider == "" { return } now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.purgeExpiredLocked(now) s.sessions[state] = oauthSession{ Provider: provider, Status: "", CreatedAt: now, ExpiresAt: now.Add(s.ttl), } } func (s *oauthSessionStore) SetError(state, message string) { state = strings.TrimSpace(state) message = strings.TrimSpace(message) if state == "" { return } if message == "" { message = "Authentication failed" } now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.purgeExpiredLocked(now) session, ok := s.sessions[state] if !ok { return } session.Status = message session.ExpiresAt = now.Add(s.ttl) s.sessions[state] = session } func (s *oauthSessionStore) Complete(state string) { state = strings.TrimSpace(state) if state == "" { return } now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.purgeExpiredLocked(now) delete(s.sessions, state) } func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { state = strings.TrimSpace(state) now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.purgeExpiredLocked(now) session, ok := s.sessions[state] return session, ok } func (s *oauthSessionStore) IsPending(state, provider string) bool { state = strings.TrimSpace(state) provider = strings.ToLower(strings.TrimSpace(provider)) now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.purgeExpiredLocked(now) session, ok := s.sessions[state] if !ok { return false } if session.Status != "" { return false } if provider == "" { return true } return strings.EqualFold(session.Provider, provider) } var oauthSessions = newOAuthSessionStore(oauthSessionTTL) func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) } func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) } func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } func GetOAuthSession(state string) (provider string, status string, ok bool) { session, ok := oauthSessions.Get(state) if !ok { return "", "", false } return session.Provider, session.Status, true } func IsOAuthSessionPending(state, provider string) bool { return oauthSessions.IsPending(state, provider) } func ValidateOAuthState(state string) error { trimmed := strings.TrimSpace(state) if trimmed == "" { return fmt.Errorf("%w: empty", errInvalidOAuthState) } if len(trimmed) > maxOAuthStateLength { return fmt.Errorf("%w: too long", errInvalidOAuthState) } if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") { return fmt.Errorf("%w: contains path separator", errInvalidOAuthState) } if strings.Contains(trimmed, "..") { return fmt.Errorf("%w: contains '..'", errInvalidOAuthState) } for _, r := range trimmed { switch { case r >= 'a' && r <= 'z': case r >= 'A' && r <= 'Z': case r >= '0' && r <= '9': case r == '-' || r == '_' || r == '.': default: return fmt.Errorf("%w: invalid character", errInvalidOAuthState) } } return nil } func NormalizeOAuthProvider(provider string) (string, error) { switch strings.ToLower(strings.TrimSpace(provider)) { case "anthropic", "claude": return "anthropic", nil case "codex", "openai": return "codex", nil case "gemini", "google": return "gemini", nil case "iflow", "i-flow": return "iflow", nil case "antigravity", "anti-gravity": return "antigravity", nil case "qwen": return "qwen", nil default: return "", errUnsupportedOAuthFlow } } type oauthCallbackFilePayload struct { Code string `json:"code"` State string `json:"state"` Error string `json:"error"` } func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) { if strings.TrimSpace(authDir) == "" { return "", fmt.Errorf("auth dir is empty") } canonicalProvider, err := NormalizeOAuthProvider(provider) if err != nil { return "", err } if err := ValidateOAuthState(state); err != nil { return "", err } fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state) filePath := filepath.Join(authDir, fileName) payload := oauthCallbackFilePayload{ Code: strings.TrimSpace(code), State: strings.TrimSpace(state), Error: strings.TrimSpace(errorMessage), } data, err := json.Marshal(payload) if err != nil { return "", fmt.Errorf("marshal oauth callback payload: %w", err) } if err := os.WriteFile(filePath, data, 0o600); err != nil { return "", fmt.Errorf("write oauth callback file: %w", err) } return filePath, nil } func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) { canonicalProvider, err := NormalizeOAuthProvider(provider) if err != nil { return "", err } if !IsOAuthSessionPending(state, canonicalProvider) { return "", errOAuthSessionNotPending } return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage) }