mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 13:00:52 +08:00
feat(auth): add callback forwarder support for Web UI in OAuth flows
- Introduced callback forwarders for Anthropic, Gemini, Codex, and iFlow OAuth flows. - Added `is_webui` query parameter detection to enhance Web UI compatibility. - Implemented mechanisms to start and stop callback forwarders dynamically. - Improved error handling and logging for callback server initialization.
This commit is contained in:
@@ -5,14 +5,17 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -38,6 +41,23 @@ var (
|
|||||||
|
|
||||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||||
|
|
||||||
|
const (
|
||||||
|
anthropicCallbackPort = 54545
|
||||||
|
geminiCallbackPort = 8085
|
||||||
|
codexCallbackPort = 1455
|
||||||
|
)
|
||||||
|
|
||||||
|
type callbackForwarder struct {
|
||||||
|
provider string
|
||||||
|
server *http.Server
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
callbackForwardersMu sync.Mutex
|
||||||
|
callbackForwarders = make(map[int]*callbackForwarder)
|
||||||
|
)
|
||||||
|
|
||||||
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
|
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
|
||||||
if len(meta) == 0 {
|
if len(meta) == 0 {
|
||||||
return time.Time{}, false
|
return time.Time{}, false
|
||||||
@@ -91,6 +111,120 @@ func parseLastRefreshValue(v any) (time.Time, bool) {
|
|||||||
return time.Time{}, false
|
return time.Time{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isWebUIRequest(c *gin.Context) bool {
|
||||||
|
raw := strings.TrimSpace(c.Query("is_webui"))
|
||||||
|
if raw == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch strings.ToLower(raw) {
|
||||||
|
case "1", "true", "yes", "on":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) {
|
||||||
|
callbackForwardersMu.Lock()
|
||||||
|
prev := callbackForwarders[port]
|
||||||
|
if prev != nil {
|
||||||
|
delete(callbackForwarders, port)
|
||||||
|
}
|
||||||
|
callbackForwardersMu.Unlock()
|
||||||
|
|
||||||
|
if prev != nil {
|
||||||
|
stopForwarderInstance(port, prev)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
ln, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
target := targetBase
|
||||||
|
if raw := r.URL.RawQuery; raw != "" {
|
||||||
|
if strings.Contains(target, "?") {
|
||||||
|
target = target + "&" + raw
|
||||||
|
} else {
|
||||||
|
target = target + "?" + raw
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
|
http.Redirect(w, r, target, http.StatusFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
WriteTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if errServe := srv.Serve(ln); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
||||||
|
log.WithError(errServe).Warnf("callback forwarder for %s stopped unexpectedly", provider)
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
forwarder := &callbackForwarder{
|
||||||
|
provider: provider,
|
||||||
|
server: srv,
|
||||||
|
done: done,
|
||||||
|
}
|
||||||
|
|
||||||
|
callbackForwardersMu.Lock()
|
||||||
|
callbackForwarders[port] = forwarder
|
||||||
|
callbackForwardersMu.Unlock()
|
||||||
|
|
||||||
|
log.Infof("callback forwarder for %s listening on %s", provider, addr)
|
||||||
|
|
||||||
|
return forwarder, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopCallbackForwarder(port int) {
|
||||||
|
callbackForwardersMu.Lock()
|
||||||
|
forwarder := callbackForwarders[port]
|
||||||
|
if forwarder != nil {
|
||||||
|
delete(callbackForwarders, port)
|
||||||
|
}
|
||||||
|
callbackForwardersMu.Unlock()
|
||||||
|
|
||||||
|
stopForwarderInstance(port, forwarder)
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||||
|
if forwarder == nil || forwarder.server == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := forwarder.server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.WithError(err).Warnf("failed to shut down callback forwarder on port %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-forwarder.done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("callback forwarder on port %d stopped", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
||||||
|
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
||||||
|
return "", fmt.Errorf("server port is not configured")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(path, "/") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("http://127.0.0.1:%d%s", h.cfg.Port, path), nil
|
||||||
|
}
|
||||||
|
|
||||||
// List auth files
|
// List auth files
|
||||||
func (h *Handler) ListAuthFiles(c *gin.Context) {
|
func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||||
entries, err := os.ReadDir(h.cfg.AuthDir)
|
entries, err := os.ReadDir(h.cfg.AuthDir)
|
||||||
@@ -390,9 +524,27 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
log.Fatalf("Failed to generate authorization URL: %v", err)
|
log.Fatalf("Failed to generate authorization URL: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Override redirect_uri in authorization URL to current server port
|
|
||||||
|
isWebUI := isWebUIRequest(c)
|
||||||
|
if isWebUI {
|
||||||
|
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
|
||||||
|
if errTarget != nil {
|
||||||
|
log.WithError(errTarget).Error("failed to compute anthropic callback target")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
|
||||||
|
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
if isWebUI {
|
||||||
|
defer stopCallbackForwarder(anthropicCallbackPort)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper: wait for callback file
|
// Helper: wait for callback file
|
||||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
|
||||||
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
|
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
|
||||||
@@ -553,7 +705,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||||
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||||
|
|
||||||
|
isWebUI := isWebUIRequest(c)
|
||||||
|
if isWebUI {
|
||||||
|
targetURL, errTarget := h.managementCallbackURL("/google/callback")
|
||||||
|
if errTarget != nil {
|
||||||
|
log.WithError(errTarget).Error("failed to compute gemini callback target")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
|
||||||
|
log.WithError(errStart).Error("failed to start gemini callback forwarder")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
if isWebUI {
|
||||||
|
defer stopCallbackForwarder(geminiCallbackPort)
|
||||||
|
}
|
||||||
|
|
||||||
// Wait for callback file written by server route
|
// Wait for callback file written by server route
|
||||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
|
||||||
fmt.Println("Waiting for authentication callback...")
|
fmt.Println("Waiting for authentication callback...")
|
||||||
@@ -779,7 +950,26 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isWebUI := isWebUIRequest(c)
|
||||||
|
if isWebUI {
|
||||||
|
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
|
||||||
|
if errTarget != nil {
|
||||||
|
log.WithError(errTarget).Error("failed to compute codex callback target")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
|
||||||
|
log.WithError(errStart).Error("failed to start codex callback forwarder")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
if isWebUI {
|
||||||
|
defer stopCallbackForwarder(codexCallbackPort)
|
||||||
|
}
|
||||||
|
|
||||||
// Wait for callback file
|
// Wait for callback file
|
||||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
|
||||||
deadline := time.Now().Add(5 * time.Minute)
|
deadline := time.Now().Add(5 * time.Minute)
|
||||||
@@ -966,6 +1156,103 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
|
|
||||||
state := fmt.Sprintf("ifl-%d", time.Now().UnixNano())
|
state := fmt.Sprintf("ifl-%d", time.Now().UnixNano())
|
||||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||||
|
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
||||||
|
|
||||||
|
isWebUI := isWebUIRequest(c)
|
||||||
|
if isWebUI {
|
||||||
|
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
|
||||||
|
if errTarget != nil {
|
||||||
|
log.WithError(errTarget).Error("failed to compute iflow callback target")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
|
||||||
|
log.WithError(errStart).Error("failed to start iflow callback forwarder")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer stopCallbackForwarder(iflowauth.CallbackPort)
|
||||||
|
fmt.Println("Waiting for authentication...")
|
||||||
|
|
||||||
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state))
|
||||||
|
deadline := time.Now().Add(5 * time.Minute)
|
||||||
|
var resultMap map[string]string
|
||||||
|
for {
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
oauthStatus[state] = "Authentication failed"
|
||||||
|
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
|
_ = os.Remove(waitFile)
|
||||||
|
_ = json.Unmarshal(data, &resultMap)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||||
|
oauthStatus[state] = "Authentication failed"
|
||||||
|
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||||
|
oauthStatus[state] = "Authentication failed"
|
||||||
|
fmt.Println("Authentication failed: state mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := strings.TrimSpace(resultMap["code"])
|
||||||
|
if code == "" {
|
||||||
|
oauthStatus[state] = "Authentication failed"
|
||||||
|
fmt.Println("Authentication failed: code missing")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||||
|
if errExchange != nil {
|
||||||
|
oauthStatus[state] = "Authentication failed"
|
||||||
|
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStorage := authSvc.CreateTokenStorage(tokenData)
|
||||||
|
identifier := strings.TrimSpace(tokenStorage.Email)
|
||||||
|
if identifier == "" {
|
||||||
|
identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
|
||||||
|
tokenStorage.Email = identifier
|
||||||
|
}
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fmt.Sprintf("iflow-%s.json", identifier),
|
||||||
|
Provider: "iflow",
|
||||||
|
FileName: fmt.Sprintf("iflow-%s.json", identifier),
|
||||||
|
Storage: tokenStorage,
|
||||||
|
Metadata: map[string]any{"email": identifier, "api_key": tokenStorage.APIKey},
|
||||||
|
Attributes: map[string]string{"api_key": tokenStorage.APIKey},
|
||||||
|
}
|
||||||
|
|
||||||
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
|
if errSave != nil {
|
||||||
|
oauthStatus[state] = "Failed to save authentication tokens"
|
||||||
|
log.Fatalf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
|
if tokenStorage.APIKey != "" {
|
||||||
|
fmt.Println("API key obtained and saved")
|
||||||
|
}
|
||||||
|
fmt.Println("You can now use iFlow services through this CLI")
|
||||||
|
delete(oauthStatus, state)
|
||||||
|
}()
|
||||||
|
|
||||||
|
oauthStatus[state] = ""
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
oauthServer := iflowauth.NewOAuthServer(iflowauth.CallbackPort)
|
oauthServer := iflowauth.NewOAuthServer(iflowauth.CallbackPort)
|
||||||
if err := oauthServer.Start(); err != nil {
|
if err := oauthServer.Start(); err != nil {
|
||||||
oauthStatus[state] = "Failed to start authentication server"
|
oauthStatus[state] = "Failed to start authentication server"
|
||||||
@@ -974,8 +1261,6 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
fmt.Println("Waiting for authentication...")
|
fmt.Println("Waiting for authentication...")
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1043,7 +1328,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
oauthStatus[state] = ""
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||||
|
|||||||
@@ -320,6 +320,18 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/iflow/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if state != "" {
|
||||||
|
file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state)
|
||||||
|
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||||
|
})
|
||||||
|
|
||||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user