feat/auth-hook: add post auth hook

This commit is contained in:
HEUDavid
2026-02-10 07:26:08 +08:00
parent a146c6c0aa
commit 94563d622c
7 changed files with 113 additions and 1 deletions

View File

@@ -864,11 +864,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
if store == nil { if store == nil {
return "", fmt.Errorf("token store unavailable") return "", fmt.Errorf("token store unavailable")
} }
if h.postAuthHook != nil {
if err := h.postAuthHook(ctx, record); err != nil {
return "", fmt.Errorf("post-auth hook failed: %w", err)
}
}
return store.Save(ctx, record) return store.Save(ctx, record)
} }
func (h *Handler) RequestAnthropicToken(c *gin.Context) { func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...") fmt.Println("Initializing Claude authentication...")
@@ -1013,6 +1019,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
@@ -1247,6 +1254,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
func (h *Handler) RequestCodexToken(c *gin.Context) { func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Codex authentication...") fmt.Println("Initializing Codex authentication...")
@@ -1392,6 +1400,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
func (h *Handler) RequestAntigravityToken(c *gin.Context) { func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Antigravity authentication...") fmt.Println("Initializing Antigravity authentication...")
@@ -1556,6 +1565,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
func (h *Handler) RequestQwenToken(c *gin.Context) { func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Qwen authentication...") fmt.Println("Initializing Qwen authentication...")
@@ -1611,6 +1621,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
func (h *Handler) RequestKimiToken(c *gin.Context) { func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Kimi authentication...") fmt.Println("Initializing Kimi authentication...")
@@ -1687,6 +1698,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
func (h *Handler) RequestIFlowToken(c *gin.Context) { func (h *Handler) RequestIFlowToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing iFlow authentication...") fmt.Println("Initializing iFlow authentication...")
@@ -2266,3 +2278,28 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
} }
c.JSON(http.StatusOK, gin.H{"status": "wait"}) c.JSON(http.StatusOK, gin.H{"status": "wait"})
} }
// PopulateAuthContext extracts request info and adds it to the context
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
info := &coreauth.RequestInfo{
Query: make(map[string]string),
Headers: make(map[string]string),
}
// Capture all query parameters
for k, v := range c.Request.URL.Query() {
if len(v) > 0 {
info.Query[k] = v[0]
}
}
// Capture specific headers relevant for logging/auditing
headers := []string{"User-Agent", "X-Forwarded-For", "X-Real-IP", "Referer"}
for _, h := range headers {
if val := c.GetHeader(h); val != "" {
info.Headers[h] = val
}
}
return context.WithValue(ctx, "request_info", info)
}

View File

@@ -47,6 +47,7 @@ type Handler struct {
allowRemoteOverride bool allowRemoteOverride bool
envSecret string envSecret string
logDir string logDir string
postAuthHook coreauth.PostAuthHook
} }
// NewHandler creates a new management handler instance. // NewHandler creates a new management handler instance.
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
h.logDir = dir h.logDir = dir
} }
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints. // Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key. // All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true. // Additionally, remote access requires allow-remote-management=true.

View File

@@ -51,6 +51,7 @@ type serverOptionConfig struct {
keepAliveEnabled bool keepAliveEnabled bool
keepAliveTimeout time.Duration keepAliveTimeout time.Duration
keepAliveOnTimeout func() keepAliveOnTimeout func()
postAuthHook auth.PostAuthHook
} }
// ServerOption customises HTTP server construction. // ServerOption customises HTTP server construction.
@@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
} }
} }
// WithPostAuthHook registers a hook to be called after auth record creation.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.postAuthHook = hook
}
}
// Server represents the main API server. // Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration. // It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct { type Server struct {
@@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
} }
logDir := logging.ResolveLogDirectory(cfg) logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir) s.mgmt.SetLogDirectory(logDir)
if optionState.postAuthHook != nil {
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
}
s.localPassword = optionState.localPassword s.localPassword = optionState.localPassword
// Setup routes // Setup routes

View File

@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
// Type indicates the authentication provider type, always "gemini" for this storage. // Type indicates the authentication provider type, always "gemini" for this storage.
Type string `json:"type"` Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
} }
// SaveTokenToFile serializes the Gemini token storage to a JSON file. // SaveTokenToFile serializes the Gemini token storage to a JSON file.
// This method creates the necessary directory structure and writes the token // This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage. // data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
// //
// Parameters: // Parameters:
// - authFilePath: The full path where the token file should be saved // - authFilePath: The full path where the token file should be saved
@@ -63,7 +73,24 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
} }
}() }()
if err = json.NewEncoder(f).Encode(ts); err != nil { // Convert struct to map for merging
data := make(map[string]any)
temp, errJson := json.Marshal(ts)
if errJson != nil {
return fmt.Errorf("failed to marshal struct: %w", errJson)
}
if errUnmarshal := json.Unmarshal(temp, &data); errUnmarshal != nil {
return fmt.Errorf("failed to unmarshal struct map: %w", errUnmarshal)
}
// Merge extra metadata
if ts.Metadata != nil {
for k, v := range ts.Metadata {
data[k] = v
}
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err) return fmt.Errorf("failed to write token to file: %w", err)
} }
return nil return nil

View File

@@ -62,8 +62,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
return "", fmt.Errorf("auth filestore: create dir failed: %w", err) return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
} }
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
type metadataSetter interface {
SetMetadata(map[string]any)
}
switch { switch {
case auth.Storage != nil: case auth.Storage != nil:
if setter, ok := auth.Storage.(metadataSetter); ok {
setter.SetMetadata(auth.Metadata)
}
if err = auth.Storage.SaveTokenToFile(path); err != nil { if err = auth.Storage.SaveTokenToFile(path); err != nil {
return "", err return "", err
} }

View File

@@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@@ -12,6 +13,18 @@ import (
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth" baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
) )
// PostAuthHook defines a function that is called after an Auth record is created
// but before it is persisted to storage. This allows for modification of the
// Auth record (e.g., injecting metadata) based on external context.
type PostAuthHook func(context.Context, *Auth) error
// RequestInfo holds information extracted from the HTTP request.
// It is injected into the context passed to PostAuthHook.
type RequestInfo struct {
Query map[string]string
Headers map[string]string
}
// Auth encapsulates the runtime state and metadata associated with a single credential. // Auth encapsulates the runtime state and metadata associated with a single credential.
type Auth struct { type Auth struct {
// ID uniquely identifies the auth record across restarts. // ID uniquely identifies the auth record across restarts.

View File

@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
return b return b
} }
// WithPostAuthHook registers a hook to be called after an Auth record is created
// but before it is persisted to storage.
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
if hook == nil {
return b
}
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
return b
}
// Build validates inputs, applies defaults, and returns a ready-to-run service. // Build validates inputs, applies defaults, and returns a ready-to-run service.
func (b *Builder) Build() (*Service, error) { func (b *Builder) Build() (*Service, error) {
if b.cfg == nil { if b.cfg == nil {