From 4fd70d5f1af9d08b95fb56eb32bf2dd13b005275 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 6 Oct 2025 01:52:42 +0800 Subject: [PATCH] feat(auth): add callback forwarder support for Web UI in OAuth flows - Introduced callback forwarders for Anthropic, Gemini, Codex, and iFlow OAuth flows. - Added `is_webui` query parameter detection to enhance Web UI compatibility. - Implemented mechanisms to start and stop callback forwarders dynamically. - Improved error handling and logging for callback server initialization. --- .../api/handlers/management/auth_files.go | 293 +++++++++++++++++- internal/api/server.go | 12 + 2 files changed, 301 insertions(+), 4 deletions(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 053d57dd..ac4e9b27 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -5,14 +5,17 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" + "net" "net/http" "net/url" "os" "path/filepath" "strconv" "strings" + "sync" "time" "github.com/gin-gonic/gin" @@ -38,6 +41,23 @@ var ( var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} +const ( + anthropicCallbackPort = 54545 + geminiCallbackPort = 8085 + codexCallbackPort = 1455 +) + +type callbackForwarder struct { + provider string + server *http.Server + done chan struct{} +} + +var ( + callbackForwardersMu sync.Mutex + callbackForwarders = make(map[int]*callbackForwarder) +) + func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) { if len(meta) == 0 { return time.Time{}, false @@ -91,6 +111,120 @@ func parseLastRefreshValue(v any) (time.Time, bool) { return time.Time{}, false } +func isWebUIRequest(c *gin.Context) bool { + raw := strings.TrimSpace(c.Query("is_webui")) + if raw == "" { + return false + } + switch strings.ToLower(raw) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) { + callbackForwardersMu.Lock() + prev := callbackForwarders[port] + if prev != nil { + delete(callbackForwarders, port) + } + callbackForwardersMu.Unlock() + + if prev != nil { + stopForwarderInstance(port, prev) + } + + addr := fmt.Sprintf("127.0.0.1:%d", port) + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to listen on %s: %w", addr, err) + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + target := targetBase + if raw := r.URL.RawQuery; raw != "" { + if strings.Contains(target, "?") { + target = target + "&" + raw + } else { + target = target + "?" + raw + } + } + w.Header().Set("Cache-Control", "no-store") + http.Redirect(w, r, target, http.StatusFound) + }) + + srv := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + done := make(chan struct{}) + + go func() { + if errServe := srv.Serve(ln); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) { + log.WithError(errServe).Warnf("callback forwarder for %s stopped unexpectedly", provider) + } + close(done) + }() + + forwarder := &callbackForwarder{ + provider: provider, + server: srv, + done: done, + } + + callbackForwardersMu.Lock() + callbackForwarders[port] = forwarder + callbackForwardersMu.Unlock() + + log.Infof("callback forwarder for %s listening on %s", provider, addr) + + return forwarder, nil +} + +func stopCallbackForwarder(port int) { + callbackForwardersMu.Lock() + forwarder := callbackForwarders[port] + if forwarder != nil { + delete(callbackForwarders, port) + } + callbackForwardersMu.Unlock() + + stopForwarderInstance(port, forwarder) +} + +func stopForwarderInstance(port int, forwarder *callbackForwarder) { + if forwarder == nil || forwarder.server == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if err := forwarder.server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.WithError(err).Warnf("failed to shut down callback forwarder on port %d", port) + } + + select { + case <-forwarder.done: + case <-time.After(2 * time.Second): + } + + log.Infof("callback forwarder on port %d stopped", port) +} + +func (h *Handler) managementCallbackURL(path string) (string, error) { + if h == nil || h.cfg == nil || h.cfg.Port <= 0 { + return "", fmt.Errorf("server port is not configured") + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return fmt.Sprintf("http://127.0.0.1:%d%s", h.cfg.Port, path), nil +} + // List auth files func (h *Handler) ListAuthFiles(c *gin.Context) { entries, err := os.ReadDir(h.cfg.AuthDir) @@ -390,9 +524,27 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { log.Fatalf("Failed to generate authorization URL: %v", err) return } - // Override redirect_uri in authorization URL to current server port + + isWebUI := isWebUIRequest(c) + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute anthropic callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start anthropic callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } go func() { + if isWebUI { + defer stopCallbackForwarder(anthropicCallbackPort) + } + // Helper: wait for callback file waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state)) waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { @@ -553,7 +705,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) + isWebUI := isWebUIRequest(c) + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/google/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute gemini callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start gemini callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } + go func() { + if isWebUI { + defer stopCallbackForwarder(geminiCallbackPort) + } + // Wait for callback file written by server route waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state)) fmt.Println("Waiting for authentication callback...") @@ -779,7 +950,26 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { return } + isWebUI := isWebUIRequest(c) + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/codex/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute codex callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start codex callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } + go func() { + if isWebUI { + defer stopCallbackForwarder(codexCallbackPort) + } + // Wait for callback file waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state)) deadline := time.Now().Add(5 * time.Minute) @@ -966,6 +1156,103 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { state := fmt.Sprintf("ifl-%d", time.Now().UnixNano()) authSvc := iflowauth.NewIFlowAuth(h.cfg) + authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort) + + isWebUI := isWebUIRequest(c) + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/iflow/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute iflow callback target") + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"}) + return + } + if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start iflow callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"}) + return + } + + go func() { + defer stopCallbackForwarder(iflowauth.CallbackPort) + fmt.Println("Waiting for authentication...") + + waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state)) + deadline := time.Now().Add(5 * time.Minute) + var resultMap map[string]string + for { + if time.Now().After(deadline) { + oauthStatus[state] = "Authentication failed" + fmt.Println("Authentication failed: timeout waiting for callback") + return + } + if data, errR := os.ReadFile(waitFile); errR == nil { + _ = os.Remove(waitFile) + _ = json.Unmarshal(data, &resultMap) + break + } + time.Sleep(500 * time.Millisecond) + } + + if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { + oauthStatus[state] = "Authentication failed" + fmt.Printf("Authentication failed: %s\n", errStr) + return + } + if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { + oauthStatus[state] = "Authentication failed" + fmt.Println("Authentication failed: state mismatch") + return + } + + code := strings.TrimSpace(resultMap["code"]) + if code == "" { + oauthStatus[state] = "Authentication failed" + fmt.Println("Authentication failed: code missing") + return + } + + tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) + if errExchange != nil { + oauthStatus[state] = "Authentication failed" + fmt.Printf("Authentication failed: %v\n", errExchange) + return + } + + tokenStorage := authSvc.CreateTokenStorage(tokenData) + identifier := strings.TrimSpace(tokenStorage.Email) + if identifier == "" { + identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli()) + tokenStorage.Email = identifier + } + record := &coreauth.Auth{ + ID: fmt.Sprintf("iflow-%s.json", identifier), + Provider: "iflow", + FileName: fmt.Sprintf("iflow-%s.json", identifier), + Storage: tokenStorage, + Metadata: map[string]any{"email": identifier, "api_key": tokenStorage.APIKey}, + Attributes: map[string]string{"api_key": tokenStorage.APIKey}, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + oauthStatus[state] = "Failed to save authentication tokens" + log.Fatalf("Failed to save authentication tokens: %v", errSave) + return + } + + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + if tokenStorage.APIKey != "" { + fmt.Println("API key obtained and saved") + } + fmt.Println("You can now use iFlow services through this CLI") + delete(oauthStatus, state) + }() + + oauthStatus[state] = "" + c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) + return + } + oauthServer := iflowauth.NewOAuthServer(iflowauth.CallbackPort) if err := oauthServer.Start(); err != nil { oauthStatus[state] = "Failed to start authentication server" @@ -974,8 +1261,6 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { return } - authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort) - go func() { fmt.Println("Waiting for authentication...") defer func() { @@ -1043,7 +1328,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { }() oauthStatus[state] = "" - c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) + c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) } func (h *Handler) GetAuthStatus(c *gin.Context) { diff --git a/internal/api/server.go b/internal/api/server.go index 2fb74e03..5eefb78a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -320,6 +320,18 @@ func (s *Server) setupRoutes() { c.String(http.StatusOK, "

Authentication successful!

You can close this window.

") }) + s.engine.GET("/iflow/callback", func(c *gin.Context) { + code := c.Query("code") + state := c.Query("state") + errStr := c.Query("error") + if state != "" { + file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state) + _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + } + c.Header("Content-Type", "text/html; charset=utf-8") + c.String(http.StatusOK, "

Authentication successful!

You can close this window.

") + }) + // Management routes are registered lazily by registerManagementRoutes when a secret is configured. }