mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 10:54:03 +08:00
feat/auth-hook: add post auth hook
This commit is contained in:
@@ -864,11 +864,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
|||||||
if store == nil {
|
if store == nil {
|
||||||
return "", fmt.Errorf("token store unavailable")
|
return "", fmt.Errorf("token store unavailable")
|
||||||
}
|
}
|
||||||
|
if h.postAuthHook != nil {
|
||||||
|
if err := h.postAuthHook(ctx, record); err != nil {
|
||||||
|
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return store.Save(ctx, record)
|
return store.Save(ctx, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Claude authentication...")
|
fmt.Println("Initializing Claude authentication...")
|
||||||
|
|
||||||
@@ -1013,6 +1019,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||||
|
|
||||||
@@ -1247,6 +1254,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Codex authentication...")
|
fmt.Println("Initializing Codex authentication...")
|
||||||
|
|
||||||
@@ -1392,6 +1400,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Antigravity authentication...")
|
fmt.Println("Initializing Antigravity authentication...")
|
||||||
|
|
||||||
@@ -1556,6 +1565,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Qwen authentication...")
|
fmt.Println("Initializing Qwen authentication...")
|
||||||
|
|
||||||
@@ -1611,6 +1621,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Kimi authentication...")
|
fmt.Println("Initializing Kimi authentication...")
|
||||||
|
|
||||||
@@ -1687,6 +1698,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing iFlow authentication...")
|
fmt.Println("Initializing iFlow authentication...")
|
||||||
|
|
||||||
@@ -2266,3 +2278,28 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PopulateAuthContext extracts request info and adds it to the context
|
||||||
|
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||||
|
info := &coreauth.RequestInfo{
|
||||||
|
Query: make(map[string]string),
|
||||||
|
Headers: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture all query parameters
|
||||||
|
for k, v := range c.Request.URL.Query() {
|
||||||
|
if len(v) > 0 {
|
||||||
|
info.Query[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture specific headers relevant for logging/auditing
|
||||||
|
headers := []string{"User-Agent", "X-Forwarded-For", "X-Real-IP", "Referer"}
|
||||||
|
for _, h := range headers {
|
||||||
|
if val := c.GetHeader(h); val != "" {
|
||||||
|
info.Headers[h] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return context.WithValue(ctx, "request_info", info)
|
||||||
|
}
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type Handler struct {
|
|||||||
allowRemoteOverride bool
|
allowRemoteOverride bool
|
||||||
envSecret string
|
envSecret string
|
||||||
logDir string
|
logDir string
|
||||||
|
postAuthHook coreauth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new management handler instance.
|
// NewHandler creates a new management handler instance.
|
||||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
|||||||
h.logDir = dir
|
h.logDir = dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||||
|
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||||
|
h.postAuthHook = hook
|
||||||
|
}
|
||||||
|
|
||||||
// Middleware enforces access control for management endpoints.
|
// Middleware enforces access control for management endpoints.
|
||||||
// All requests (local and remote) require a valid management key.
|
// All requests (local and remote) require a valid management key.
|
||||||
// Additionally, remote access requires allow-remote-management=true.
|
// Additionally, remote access requires allow-remote-management=true.
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ type serverOptionConfig struct {
|
|||||||
keepAliveEnabled bool
|
keepAliveEnabled bool
|
||||||
keepAliveTimeout time.Duration
|
keepAliveTimeout time.Duration
|
||||||
keepAliveOnTimeout func()
|
keepAliveOnTimeout func()
|
||||||
|
postAuthHook auth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerOption customises HTTP server construction.
|
// ServerOption customises HTTP server construction.
|
||||||
@@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||||
|
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||||
|
return func(cfg *serverOptionConfig) {
|
||||||
|
cfg.postAuthHook = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Server represents the main API server.
|
// Server represents the main API server.
|
||||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
}
|
}
|
||||||
logDir := logging.ResolveLogDirectory(cfg)
|
logDir := logging.ResolveLogDirectory(cfg)
|
||||||
s.mgmt.SetLogDirectory(logDir)
|
s.mgmt.SetLogDirectory(logDir)
|
||||||
|
if optionState.postAuthHook != nil {
|
||||||
|
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||||
|
}
|
||||||
s.localPassword = optionState.localPassword
|
s.localPassword = optionState.localPassword
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
|
|||||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
|||||||
|
|
||||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -63,7 +73,24 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Convert struct to map for merging
|
||||||
|
data := make(map[string]any)
|
||||||
|
temp, errJson := json.Marshal(ts)
|
||||||
|
if errJson != nil {
|
||||||
|
return fmt.Errorf("failed to marshal struct: %w", errJson)
|
||||||
|
}
|
||||||
|
if errUnmarshal := json.Unmarshal(temp, &data); errUnmarshal != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal struct map: %w", errUnmarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge extra metadata
|
||||||
|
if ts.Metadata != nil {
|
||||||
|
for k, v := range ts.Metadata {
|
||||||
|
data[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -62,8 +62,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
|||||||
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
|
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
|
||||||
|
type metadataSetter interface {
|
||||||
|
SetMetadata(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case auth.Storage != nil:
|
case auth.Storage != nil:
|
||||||
|
if setter, ok := auth.Storage.(metadataSetter); ok {
|
||||||
|
setter.SetMetadata(auth.Metadata)
|
||||||
|
}
|
||||||
if err = auth.Storage.SaveTokenToFile(path); err != nil {
|
if err = auth.Storage.SaveTokenToFile(path); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -12,6 +13,18 @@ import (
|
|||||||
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PostAuthHook defines a function that is called after an Auth record is created
|
||||||
|
// but before it is persisted to storage. This allows for modification of the
|
||||||
|
// Auth record (e.g., injecting metadata) based on external context.
|
||||||
|
type PostAuthHook func(context.Context, *Auth) error
|
||||||
|
|
||||||
|
// RequestInfo holds information extracted from the HTTP request.
|
||||||
|
// It is injected into the context passed to PostAuthHook.
|
||||||
|
type RequestInfo struct {
|
||||||
|
Query map[string]string
|
||||||
|
Headers map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
// ID uniquely identifies the auth record across restarts.
|
// ID uniquely identifies the auth record across restarts.
|
||||||
|
|||||||
@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPostAuthHook registers a hook to be called after an Auth record is created
|
||||||
|
// but before it is persisted to storage.
|
||||||
|
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
|
||||||
|
if hook == nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// Build validates inputs, applies defaults, and returns a ready-to-run service.
|
// Build validates inputs, applies defaults, and returns a ready-to-run service.
|
||||||
func (b *Builder) Build() (*Service, error) {
|
func (b *Builder) Build() (*Service, error) {
|
||||||
if b.cfg == nil {
|
if b.cfg == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user