mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(auth): add callback forwarder support for Web UI in OAuth flows
- Introduced callback forwarders for Anthropic, Gemini, Codex, and iFlow OAuth flows. - Added `is_webui` query parameter detection to enhance Web UI compatibility. - Implemented mechanisms to start and stop callback forwarders dynamically. - Improved error handling and logging for callback server initialization.
This commit is contained in:
@@ -5,14 +5,17 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -38,6 +41,23 @@ var (
|
||||
|
||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||
|
||||
const (
|
||||
anthropicCallbackPort = 54545
|
||||
geminiCallbackPort = 8085
|
||||
codexCallbackPort = 1455
|
||||
)
|
||||
|
||||
type callbackForwarder struct {
|
||||
provider string
|
||||
server *http.Server
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
var (
|
||||
callbackForwardersMu sync.Mutex
|
||||
callbackForwarders = make(map[int]*callbackForwarder)
|
||||
)
|
||||
|
||||
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
|
||||
if len(meta) == 0 {
|
||||
return time.Time{}, false
|
||||
@@ -91,6 +111,120 @@ func parseLastRefreshValue(v any) (time.Time, bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func isWebUIRequest(c *gin.Context) bool {
|
||||
raw := strings.TrimSpace(c.Query("is_webui"))
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
switch strings.ToLower(raw) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) {
|
||||
callbackForwardersMu.Lock()
|
||||
prev := callbackForwarders[port]
|
||||
if prev != nil {
|
||||
delete(callbackForwarders, port)
|
||||
}
|
||||
callbackForwardersMu.Unlock()
|
||||
|
||||
if prev != nil {
|
||||
stopForwarderInstance(port, prev)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
target := targetBase
|
||||
if raw := r.URL.RawQuery; raw != "" {
|
||||
if strings.Contains(target, "?") {
|
||||
target = target + "&" + raw
|
||||
} else {
|
||||
target = target + "?" + raw
|
||||
}
|
||||
}
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
http.Redirect(w, r, target, http.StatusFound)
|
||||
})
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
}
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
if errServe := srv.Serve(ln); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
||||
log.WithError(errServe).Warnf("callback forwarder for %s stopped unexpectedly", provider)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
forwarder := &callbackForwarder{
|
||||
provider: provider,
|
||||
server: srv,
|
||||
done: done,
|
||||
}
|
||||
|
||||
callbackForwardersMu.Lock()
|
||||
callbackForwarders[port] = forwarder
|
||||
callbackForwardersMu.Unlock()
|
||||
|
||||
log.Infof("callback forwarder for %s listening on %s", provider, addr)
|
||||
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func stopCallbackForwarder(port int) {
|
||||
callbackForwardersMu.Lock()
|
||||
forwarder := callbackForwarders[port]
|
||||
if forwarder != nil {
|
||||
delete(callbackForwarders, port)
|
||||
}
|
||||
callbackForwardersMu.Unlock()
|
||||
|
||||
stopForwarderInstance(port, forwarder)
|
||||
}
|
||||
|
||||
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||
if forwarder == nil || forwarder.server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := forwarder.server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.WithError(err).Warnf("failed to shut down callback forwarder on port %d", port)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-forwarder.done:
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
|
||||
log.Infof("callback forwarder on port %d stopped", port)
|
||||
}
|
||||
|
||||
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
||||
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
||||
return "", fmt.Errorf("server port is not configured")
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
return fmt.Sprintf("http://127.0.0.1:%d%s", h.cfg.Port, path), nil
|
||||
}
|
||||
|
||||
// List auth files
|
||||
func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||
entries, err := os.ReadDir(h.cfg.AuthDir)
|
||||
@@ -390,9 +524,27 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
log.Fatalf("Failed to generate authorization URL: %v", err)
|
||||
return
|
||||
}
|
||||
// Override redirect_uri in authorization URL to current server port
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute anthropic callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(anthropicCallbackPort)
|
||||
}
|
||||
|
||||
// Helper: wait for callback file
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
|
||||
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
|
||||
@@ -553,7 +705,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/google/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute gemini callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start gemini callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(geminiCallbackPort)
|
||||
}
|
||||
|
||||
// Wait for callback file written by server route
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
|
||||
fmt.Println("Waiting for authentication callback...")
|
||||
@@ -779,7 +950,26 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute codex callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start codex callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(codexCallbackPort)
|
||||
}
|
||||
|
||||
// Wait for callback file
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
@@ -966,6 +1156,103 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
|
||||
state := fmt.Sprintf("ifl-%d", time.Now().UnixNano())
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute iflow callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start iflow callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stopCallbackForwarder(iflowauth.CallbackPort)
|
||||
fmt.Println("Waiting for authentication...")
|
||||
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var resultMap map[string]string
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
_ = os.Remove(waitFile)
|
||||
_ = json.Unmarshal(data, &resultMap)
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||
return
|
||||
}
|
||||
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
fmt.Println("Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(resultMap["code"])
|
||||
if code == "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
fmt.Println("Authentication failed: code missing")
|
||||
return
|
||||
}
|
||||
|
||||
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||
if errExchange != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||
return
|
||||
}
|
||||
|
||||
tokenStorage := authSvc.CreateTokenStorage(tokenData)
|
||||
identifier := strings.TrimSpace(tokenStorage.Email)
|
||||
if identifier == "" {
|
||||
identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
|
||||
tokenStorage.Email = identifier
|
||||
}
|
||||
record := &coreauth.Auth{
|
||||
ID: fmt.Sprintf("iflow-%s.json", identifier),
|
||||
Provider: "iflow",
|
||||
FileName: fmt.Sprintf("iflow-%s.json", identifier),
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{"email": identifier, "api_key": tokenStorage.APIKey},
|
||||
Attributes: map[string]string{"api_key": tokenStorage.APIKey},
|
||||
}
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
log.Fatalf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if tokenStorage.APIKey != "" {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use iFlow services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
return
|
||||
}
|
||||
|
||||
oauthServer := iflowauth.NewOAuthServer(iflowauth.CallbackPort)
|
||||
if err := oauthServer.Start(); err != nil {
|
||||
oauthStatus[state] = "Failed to start authentication server"
|
||||
@@ -974,8 +1261,6 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
defer func() {
|
||||
@@ -1043,7 +1328,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
|
||||
Reference in New Issue
Block a user