feat(ws): add WebSocket auth

This commit is contained in:
hkfires
2025-10-25 21:40:20 +08:00
parent ea6065f1b1
commit 359b8de44e
9 changed files with 119 additions and 14 deletions

View File

@@ -43,6 +43,9 @@ quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
# API keys for official Generative Language API # API keys for official Generative Language API
#generative-language-api-key: #generative-language-api-key:
# - "AIzaSy...01" # - "AIzaSy...01"

View File

@@ -57,10 +57,12 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key") authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
authHeaderAnthropic := r.Header.Get("X-Api-Key") authHeaderAnthropic := r.Header.Get("X-Api-Key")
queryKey := "" queryKey := ""
queryAuthToken := ""
if r.URL != nil { if r.URL != nil {
queryKey = r.URL.Query().Get("key") queryKey = r.URL.Query().Get("key")
queryAuthToken = r.URL.Query().Get("auth_token")
} }
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" { if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
return nil, sdkaccess.ErrNoCredentials return nil, sdkaccess.ErrNoCredentials
} }
@@ -74,6 +76,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
{authHeaderGoogle, "x-goog-api-key"}, {authHeaderGoogle, "x-goog-api-key"},
{authHeaderAnthropic, "x-api-key"}, {authHeaderAnthropic, "x-api-key"},
{queryKey, "query-key"}, {queryKey, "query-key"},
{queryAuthToken, "query-auth-token"},
} }
for _, candidate := range candidates { for _, candidate := range candidates {

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
) )
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses. // RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
@@ -63,13 +64,11 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
// It captures the URL, method, headers, and body. The request body is read and then // It captures the URL, method, headers, and body. The request body is read and then
// restored so that it can be processed by subsequent handlers. // restored so that it can be processed by subsequent handlers.
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) { func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
// Capture URL // Capture URL with sensitive query parameters masked
url := c.Request.URL.String() maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
if c.Request.URL.Path != "" { url := c.Request.URL.Path
url = c.Request.URL.Path if maskedQuery != "" {
if c.Request.URL.RawQuery != "" { url += "?" + maskedQuery
url += "?" + c.Request.URL.RawQuery
}
} }
// Capture method // Capture method

View File

@@ -140,8 +140,10 @@ type Server struct {
currentPath string currentPath string
// wsRoutes tracks registered websocket upgrade paths. // wsRoutes tracks registered websocket upgrade paths.
wsRouteMu sync.Mutex wsRouteMu sync.Mutex
wsRoutes map[string]struct{} wsRoutes map[string]struct{}
wsAuthChanged func(bool, bool)
wsAuthEnabled atomic.Bool
// management handler // management handler
mgmt *managementHandlers.Handler mgmt *managementHandlers.Handler
@@ -235,6 +237,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
envManagementSecret: envManagementSecret, envManagementSecret: envManagementSecret,
wsRoutes: make(map[string]struct{}), wsRoutes: make(map[string]struct{}),
} }
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
// Save initial YAML snapshot // Save initial YAML snapshot
s.oldConfigYaml, _ = yaml.Marshal(cfg) s.oldConfigYaml, _ = yaml.Marshal(cfg)
s.applyAccessConfig(nil, cfg) s.applyAccessConfig(nil, cfg)
@@ -398,10 +401,20 @@ func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
s.wsRoutes[trimmed] = struct{}{} s.wsRoutes[trimmed] = struct{}{}
s.wsRouteMu.Unlock() s.wsRouteMu.Unlock()
s.engine.GET(trimmed, func(c *gin.Context) { authMiddleware := AuthMiddleware(s.accessManager)
conditionalAuth := func(c *gin.Context) {
if !s.wsAuthEnabled.Load() {
c.Next()
return
}
authMiddleware(c)
}
finalHandler := func(c *gin.Context) {
handler.ServeHTTP(c.Writer, c.Request) handler.ServeHTTP(c.Writer, c.Request)
c.Abort() c.Abort()
}) }
s.engine.GET(trimmed, conditionalAuth, finalHandler)
} }
func (s *Server) registerManagementRoutes() { func (s *Server) registerManagementRoutes() {
@@ -803,6 +816,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
s.applyAccessConfig(oldCfg, cfg) s.applyAccessConfig(oldCfg, cfg)
s.cfg = cfg s.cfg = cfg
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
}
managementasset.SetCurrentConfig(cfg) managementasset.SetCurrentConfig(cfg)
// Save YAML snapshot for next comparison // Save YAML snapshot for next comparison
s.oldConfigYaml, _ = yaml.Marshal(cfg) s.oldConfigYaml, _ = yaml.Marshal(cfg)
@@ -843,6 +860,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
) )
} }
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
if s == nil {
return
}
s.wsAuthChanged = fn
}
// (management handlers moved to internal/api/handlers/management) // (management handlers moved to internal/api/handlers/management)
// AuthMiddleware returns a Gin middleware handler that authenticates requests // AuthMiddleware returns a Gin middleware handler that authenticates requests

View File

@@ -40,6 +40,9 @@ type Config struct {
// QuotaExceeded defines the behavior when a quota is exceeded. // QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"` QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
// WebsocketAuth enables or disables authentication for the WebSocket API.
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
// GlAPIKey is the API key for the generative language API. // GlAPIKey is the API key for the generative language API.
GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"` GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"`

View File

@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -23,7 +24,7 @@ func GinLogrusLogger() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
start := time.Now() start := time.Now()
path := c.Request.URL.Path path := c.Request.URL.Path
raw := c.Request.URL.RawQuery raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
c.Next() c.Next()

View File

@@ -4,6 +4,7 @@
package util package util
import ( import (
"net/url"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -188,3 +189,56 @@ func MaskSensitiveHeaderValue(key, value string) string {
return value return value
} }
} }
// MaskSensitiveQuery masks sensitive query parameters, e.g. auth_token, within the raw query string.
func MaskSensitiveQuery(raw string) string {
if raw == "" {
return ""
}
parts := strings.Split(raw, "&")
changed := false
for i, part := range parts {
if part == "" {
continue
}
keyPart := part
valuePart := ""
if idx := strings.Index(part, "="); idx >= 0 {
keyPart = part[:idx]
valuePart = part[idx+1:]
}
decodedKey, err := url.QueryUnescape(keyPart)
if err != nil {
decodedKey = keyPart
}
if !shouldMaskQueryParam(decodedKey) {
continue
}
decodedValue, err := url.QueryUnescape(valuePart)
if err != nil {
decodedValue = valuePart
}
masked := HideAPIKey(strings.TrimSpace(decodedValue))
parts[i] = keyPart + "=" + url.QueryEscape(masked)
changed = true
}
if !changed {
return raw
}
return strings.Join(parts, "&")
}
func shouldMaskQueryParam(key string) bool {
key = strings.ToLower(strings.TrimSpace(key))
if key == "" {
return false
}
key = strings.TrimSuffix(key, "[]")
if key == "key" || strings.Contains(key, "api-key") || strings.Contains(key, "apikey") || strings.Contains(key, "api_key") {
return true
}
if strings.Contains(key, "token") || strings.Contains(key, "secret") {
return true
}
return false
}

View File

@@ -1204,6 +1204,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.ProxyURL != newCfg.ProxyURL { if oldCfg.ProxyURL != newCfg.ProxyURL {
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL)) changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
} }
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
}
// Quota-exceeded behavior // Quota-exceeded behavior
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject { if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {

View File

@@ -421,6 +421,22 @@ func (s *Service) Run(ctx context.Context) error {
s.ensureWebsocketGateway() s.ensureWebsocketGateway()
if s.server != nil && s.wsGateway != nil { if s.server != nil && s.wsGateway != nil {
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler()) s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) {
if oldEnabled == newEnabled {
return
}
if !oldEnabled && newEnabled {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if errStop := s.wsGateway.Stop(ctx); errStop != nil {
log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop)
return
}
log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication")
return
}
log.Debugf("ws-auth disabled; existing websocket sessions remain connected")
})
} }
if s.hooks.OnBeforeStart != nil { if s.hooks.OnBeforeStart != nil {
@@ -460,7 +476,6 @@ func (s *Service) Run(ctx context.Context) error {
s.cfg = newCfg s.cfg = newCfg
s.cfgMu.Unlock() s.cfgMu.Unlock()
s.rebindExecutors() s.rebindExecutors()
} }
watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback) watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)