Merge pull request #1527 from HEUDavid/feat/auth-hook

feat(auth): add post-auth hook mechanism
This commit is contained in:
Luis Pater
2026-02-24 05:33:13 +08:00
committed by GitHub
13 changed files with 222 additions and 6 deletions

View File

@@ -945,11 +945,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
if store == nil { if store == nil {
return "", fmt.Errorf("token store unavailable") return "", fmt.Errorf("token store unavailable")
} }
if h.postAuthHook != nil {
if err := h.postAuthHook(ctx, record); err != nil {
return "", fmt.Errorf("post-auth hook failed: %w", err)
}
}
return store.Save(ctx, record) return store.Save(ctx, record)
} }
func (h *Handler) RequestAnthropicToken(c *gin.Context) { func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...") fmt.Println("Initializing Claude authentication...")
@@ -1094,6 +1100,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
@@ -1352,6 +1359,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
func (h *Handler) RequestCodexToken(c *gin.Context) { func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Codex authentication...") fmt.Println("Initializing Codex authentication...")
@@ -1497,6 +1505,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
func (h *Handler) RequestAntigravityToken(c *gin.Context) { func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Antigravity authentication...") fmt.Println("Initializing Antigravity authentication...")
@@ -1661,6 +1670,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
func (h *Handler) RequestQwenToken(c *gin.Context) { func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Qwen authentication...") fmt.Println("Initializing Qwen authentication...")
@@ -1716,6 +1726,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
func (h *Handler) RequestKimiToken(c *gin.Context) { func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Kimi authentication...") fmt.Println("Initializing Kimi authentication...")
@@ -1792,6 +1803,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
func (h *Handler) RequestIFlowToken(c *gin.Context) { func (h *Handler) RequestIFlowToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing iFlow authentication...") fmt.Println("Initializing iFlow authentication...")
@@ -2412,3 +2424,12 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
} }
c.JSON(http.StatusOK, gin.H{"status": "wait"}) c.JSON(http.StatusOK, gin.H{"status": "wait"})
} }
// PopulateAuthContext extracts request info and adds it to the context
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
info := &coreauth.RequestInfo{
Query: c.Request.URL.Query(),
Headers: c.Request.Header,
}
return coreauth.WithRequestInfo(ctx, info)
}

View File

@@ -47,6 +47,7 @@ type Handler struct {
allowRemoteOverride bool allowRemoteOverride bool
envSecret string envSecret string
logDir string logDir string
postAuthHook coreauth.PostAuthHook
} }
// NewHandler creates a new management handler instance. // NewHandler creates a new management handler instance.
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
h.logDir = dir h.logDir = dir
} }
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints. // Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key. // All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true. // Additionally, remote access requires allow-remote-management=true.

View File

@@ -51,6 +51,7 @@ type serverOptionConfig struct {
keepAliveEnabled bool keepAliveEnabled bool
keepAliveTimeout time.Duration keepAliveTimeout time.Duration
keepAliveOnTimeout func() keepAliveOnTimeout func()
postAuthHook auth.PostAuthHook
} }
// ServerOption customises HTTP server construction. // ServerOption customises HTTP server construction.
@@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
} }
} }
// WithPostAuthHook registers a hook to be called after auth record creation.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.postAuthHook = hook
}
}
// Server represents the main API server. // Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration. // It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct { type Server struct {
@@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
} }
logDir := logging.ResolveLogDirectory(cfg) logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir) s.mgmt.SetLogDirectory(logDir)
if optionState.postAuthHook != nil {
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
}
s.localPassword = optionState.localPassword s.localPassword = optionState.localPassword
// Setup routes // Setup routes

View File

@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
// Expire is the timestamp when the current access token expires. // Expire is the timestamp when the current access token expires.
Expire string `json:"expired"` Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
} }
// SaveTokenToFile serializes the Claude token storage to a JSON file. // SaveTokenToFile serializes the Claude token storage to a JSON file.
// This method creates the necessary directory structure and writes the token // This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage. // data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
// //
// Parameters: // Parameters:
// - authFilePath: The full path where the token file should be saved // - authFilePath: The full path where the token file should be saved
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close() _ = f.Close()
}() }()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
// Encode and write the token data as JSON // Encode and write the token data as JSON
if err = json.NewEncoder(f).Encode(ts); err != nil { if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err) return fmt.Errorf("failed to write token to file: %w", err)
} }
return nil return nil

View File

@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
Type string `json:"type"` Type string `json:"type"`
// Expire is the timestamp when the current access token expires. // Expire is the timestamp when the current access token expires.
Expire string `json:"expired"` Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
} }
// SaveTokenToFile serializes the Codex token storage to a JSON file. // SaveTokenToFile serializes the Codex token storage to a JSON file.
// This method creates the necessary directory structure and writes the token // This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage. // data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
// //
// Parameters: // Parameters:
// - authFilePath: The full path where the token file should be saved // - authFilePath: The full path where the token file should be saved
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close() _ = f.Close()
}() }()
if err = json.NewEncoder(f).Encode(ts); err != nil { // Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err) return fmt.Errorf("failed to write token to file: %w", err)
} }
return nil return nil

View File

@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
// Type indicates the authentication provider type, always "gemini" for this storage. // Type indicates the authentication provider type, always "gemini" for this storage.
Type string `json:"type"` Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
} }
// SaveTokenToFile serializes the Gemini token storage to a JSON file. // SaveTokenToFile serializes the Gemini token storage to a JSON file.
// This method creates the necessary directory structure and writes the token // This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage. // data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
// //
// Parameters: // Parameters:
// - authFilePath: The full path where the token file should be saved // - authFilePath: The full path where the token file should be saved
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath) misc.LogSavingCredentials(authFilePath)
ts.Type = "gemini" ts.Type = "gemini"
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err) return fmt.Errorf("failed to create directory: %v", err)
} }
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
} }
}() }()
if err = json.NewEncoder(f).Encode(ts); err != nil { enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err) return fmt.Errorf("failed to write token to file: %w", err)
} }
return nil return nil

View File

@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
Scope string `json:"scope"` Scope string `json:"scope"`
Cookie string `json:"cookie"` Cookie string `json:"cookie"`
Type string `json:"type"` Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
} }
// SaveTokenToFile serialises the token storage to disk. // SaveTokenToFile serialises the token storage to disk.
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
} }
defer func() { _ = f.Close() }() defer func() { _ = f.Close() }()
if err = json.NewEncoder(f).Encode(ts); err != nil { // Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("iflow token: encode token failed: %w", err) return fmt.Errorf("iflow token: encode token failed: %w", err)
} }
return nil return nil

View File

@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
Expired string `json:"expired,omitempty"` Expired string `json:"expired,omitempty"`
// Type indicates the authentication provider type, always "kimi" for this storage. // Type indicates the authentication provider type, always "kimi" for this storage.
Type string `json:"type"` Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
} }
// KimiTokenData holds the raw OAuth token response from Kimi. // KimiTokenData holds the raw OAuth token response from Kimi.
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close() _ = f.Close()
}() }()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(f) encoder := json.NewEncoder(f)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
if err = encoder.Encode(ts); err != nil { if err = encoder.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err) return fmt.Errorf("failed to write token to file: %w", err)
} }
return nil return nil

View File

@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
Type string `json:"type"` Type string `json:"type"`
// Expire is the timestamp when the current access token expires. // Expire is the timestamp when the current access token expires.
Expire string `json:"expired"` Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
} }
// SaveTokenToFile serializes the Qwen token storage to a JSON file. // SaveTokenToFile serializes the Qwen token storage to a JSON file.
// This method creates the necessary directory structure and writes the token // This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage. // data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
// //
// Parameters: // Parameters:
// - authFilePath: The full path where the token file should be saved // - authFilePath: The full path where the token file should be saved
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close() _ = f.Close()
}() }()
if err = json.NewEncoder(f).Encode(ts); err != nil { // Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err) return fmt.Errorf("failed to write token to file: %w", err)
} }
return nil return nil

View File

@@ -1,6 +1,7 @@
package misc package misc
import ( import (
"encoding/json"
"fmt" "fmt"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
func LogCredentialSeparator() { func LogCredentialSeparator() {
log.Debug(credentialSeparator) log.Debug(credentialSeparator)
} }
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
var data map[string]any
// Fast path: if source is already a map, just copy it to avoid mutation of original
if srcMap, ok := source.(map[string]any); ok {
data = make(map[string]any, len(srcMap)+len(metadata))
for k, v := range srcMap {
data[k] = v
}
} else {
// Slow path: marshal to JSON and back to map to respect JSON tags
temp, err := json.Marshal(source)
if err != nil {
return nil, fmt.Errorf("failed to marshal source: %w", err)
}
if err := json.Unmarshal(temp, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
}
// Merge extra metadata
if metadata != nil {
if data == nil {
data = make(map[string]any)
}
for k, v := range metadata {
data[k] = v
}
}
return data, nil
}

View File

@@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
return "", fmt.Errorf("auth filestore: create dir failed: %w", err) return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
} }
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
type metadataSetter interface {
SetMetadata(map[string]any)
}
switch { switch {
case auth.Storage != nil: case auth.Storage != nil:
if setter, ok := auth.Storage.(metadataSetter); ok {
setter.SetMetadata(auth.Metadata)
}
if err = auth.Storage.SaveTokenToFile(path); err != nil { if err = auth.Storage.SaveTokenToFile(path); err != nil {
return "", err return "", err
} }

View File

@@ -1,9 +1,12 @@
package auth package auth
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -12,6 +15,33 @@ import (
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth" baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
) )
// PostAuthHook defines a function that is called after an Auth record is created
// but before it is persisted to storage. This allows for modification of the
// Auth record (e.g., injecting metadata) based on external context.
type PostAuthHook func(context.Context, *Auth) error
// RequestInfo holds information extracted from the HTTP request.
// It is injected into the context passed to PostAuthHook.
type RequestInfo struct {
Query url.Values
Headers http.Header
}
type requestInfoKey struct{}
// WithRequestInfo returns a new context with the given RequestInfo attached.
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, info)
}
// GetRequestInfo retrieves the RequestInfo from the context, if present.
func GetRequestInfo(ctx context.Context) *RequestInfo {
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
return val
}
return nil
}
// Auth encapsulates the runtime state and metadata associated with a single credential. // Auth encapsulates the runtime state and metadata associated with a single credential.
type Auth struct { type Auth struct {
// ID uniquely identifies the auth record across restarts. // ID uniquely identifies the auth record across restarts.

View File

@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
return b return b
} }
// WithPostAuthHook registers a hook to be called after an Auth record is created
// but before it is persisted to storage.
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
if hook == nil {
return b
}
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
return b
}
// Build validates inputs, applies defaults, and returns a ready-to-run service. // Build validates inputs, applies defaults, and returns a ready-to-run service.
func (b *Builder) Build() (*Service, error) { func (b *Builder) Build() (*Service, error) {
if b.cfg == nil { if b.cfg == nil {