feat(auth): add callback forwarder support for Web UI in OAuth flows

- Introduced callback forwarders for Anthropic, Gemini, Codex, and iFlow OAuth flows.
- Added `is_webui` query parameter detection to enhance Web UI compatibility.
- Implemented mechanisms to start and stop callback forwarders dynamically.
- Improved error handling and logging for callback server initialization.
This commit is contained in:
Luis Pater
2025-10-06 01:52:42 +08:00
parent 49c52a01b0
commit 4fd70d5f1a
2 changed files with 301 additions and 4 deletions

View File

@@ -5,14 +5,17 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
@@ -38,6 +41,23 @@ var (
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
const (
anthropicCallbackPort = 54545
geminiCallbackPort = 8085
codexCallbackPort = 1455
)
type callbackForwarder struct {
provider string
server *http.Server
done chan struct{}
}
var (
callbackForwardersMu sync.Mutex
callbackForwarders = make(map[int]*callbackForwarder)
)
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
if len(meta) == 0 {
return time.Time{}, false
@@ -91,6 +111,120 @@ func parseLastRefreshValue(v any) (time.Time, bool) {
return time.Time{}, false
}
func isWebUIRequest(c *gin.Context) bool {
raw := strings.TrimSpace(c.Query("is_webui"))
if raw == "" {
return false
}
switch strings.ToLower(raw) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) {
callbackForwardersMu.Lock()
prev := callbackForwarders[port]
if prev != nil {
delete(callbackForwarders, port)
}
callbackForwardersMu.Unlock()
if prev != nil {
stopForwarderInstance(port, prev)
}
addr := fmt.Sprintf("127.0.0.1:%d", port)
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
target := targetBase
if raw := r.URL.RawQuery; raw != "" {
if strings.Contains(target, "?") {
target = target + "&" + raw
} else {
target = target + "?" + raw
}
}
w.Header().Set("Cache-Control", "no-store")
http.Redirect(w, r, target, http.StatusFound)
})
srv := &http.Server{
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
done := make(chan struct{})
go func() {
if errServe := srv.Serve(ln); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
log.WithError(errServe).Warnf("callback forwarder for %s stopped unexpectedly", provider)
}
close(done)
}()
forwarder := &callbackForwarder{
provider: provider,
server: srv,
done: done,
}
callbackForwardersMu.Lock()
callbackForwarders[port] = forwarder
callbackForwardersMu.Unlock()
log.Infof("callback forwarder for %s listening on %s", provider, addr)
return forwarder, nil
}
func stopCallbackForwarder(port int) {
callbackForwardersMu.Lock()
forwarder := callbackForwarders[port]
if forwarder != nil {
delete(callbackForwarders, port)
}
callbackForwardersMu.Unlock()
stopForwarderInstance(port, forwarder)
}
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
if forwarder == nil || forwarder.server == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := forwarder.server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.WithError(err).Warnf("failed to shut down callback forwarder on port %d", port)
}
select {
case <-forwarder.done:
case <-time.After(2 * time.Second):
}
log.Infof("callback forwarder on port %d stopped", port)
}
func (h *Handler) managementCallbackURL(path string) (string, error) {
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
return "", fmt.Errorf("server port is not configured")
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("http://127.0.0.1:%d%s", h.cfg.Port, path), nil
}
// List auth files
func (h *Handler) ListAuthFiles(c *gin.Context) {
entries, err := os.ReadDir(h.cfg.AuthDir)
@@ -390,9 +524,27 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
log.Fatalf("Failed to generate authorization URL: %v", err)
return
}
// Override redirect_uri in authorization URL to current server port
isWebUI := isWebUIRequest(c)
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute anthropic callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarder(anthropicCallbackPort)
}
// Helper: wait for callback file
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
@@ -553,7 +705,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
isWebUI := isWebUIRequest(c)
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/google/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute gemini callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start gemini callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarder(geminiCallbackPort)
}
// Wait for callback file written by server route
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
fmt.Println("Waiting for authentication callback...")
@@ -779,7 +950,26 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
return
}
isWebUI := isWebUIRequest(c)
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute codex callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start codex callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarder(codexCallbackPort)
}
// Wait for callback file
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
@@ -966,6 +1156,103 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
state := fmt.Sprintf("ifl-%d", time.Now().UnixNano())
authSvc := iflowauth.NewIFlowAuth(h.cfg)
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
isWebUI := isWebUIRequest(c)
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute iflow callback target")
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start iflow callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
return
}
go func() {
defer stopCallbackForwarder(iflowauth.CallbackPort)
fmt.Println("Waiting for authentication...")
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var resultMap map[string]string
for {
if time.Now().After(deadline) {
oauthStatus[state] = "Authentication failed"
fmt.Println("Authentication failed: timeout waiting for callback")
return
}
if data, errR := os.ReadFile(waitFile); errR == nil {
_ = os.Remove(waitFile)
_ = json.Unmarshal(data, &resultMap)
break
}
time.Sleep(500 * time.Millisecond)
}
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
oauthStatus[state] = "Authentication failed"
fmt.Printf("Authentication failed: %s\n", errStr)
return
}
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
oauthStatus[state] = "Authentication failed"
fmt.Println("Authentication failed: state mismatch")
return
}
code := strings.TrimSpace(resultMap["code"])
if code == "" {
oauthStatus[state] = "Authentication failed"
fmt.Println("Authentication failed: code missing")
return
}
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
if errExchange != nil {
oauthStatus[state] = "Authentication failed"
fmt.Printf("Authentication failed: %v\n", errExchange)
return
}
tokenStorage := authSvc.CreateTokenStorage(tokenData)
identifier := strings.TrimSpace(tokenStorage.Email)
if identifier == "" {
identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
tokenStorage.Email = identifier
}
record := &coreauth.Auth{
ID: fmt.Sprintf("iflow-%s.json", identifier),
Provider: "iflow",
FileName: fmt.Sprintf("iflow-%s.json", identifier),
Storage: tokenStorage,
Metadata: map[string]any{"email": identifier, "api_key": tokenStorage.APIKey},
Attributes: map[string]string{"api_key": tokenStorage.APIKey},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens"
log.Fatalf("Failed to save authentication tokens: %v", errSave)
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if tokenStorage.APIKey != "" {
fmt.Println("API key obtained and saved")
}
fmt.Println("You can now use iFlow services through this CLI")
delete(oauthStatus, state)
}()
oauthStatus[state] = ""
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
return
}
oauthServer := iflowauth.NewOAuthServer(iflowauth.CallbackPort)
if err := oauthServer.Start(); err != nil {
oauthStatus[state] = "Failed to start authentication server"
@@ -974,8 +1261,6 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
return
}
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
go func() {
fmt.Println("Waiting for authentication...")
defer func() {
@@ -1043,7 +1328,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
}()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) GetAuthStatus(c *gin.Context) {