mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
feat(ws): add WebSocket auth
This commit is contained in:
@@ -43,6 +43,9 @@ quota-exceeded:
|
|||||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||||
|
|
||||||
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
|
ws-auth: false
|
||||||
|
|
||||||
# API keys for official Generative Language API
|
# API keys for official Generative Language API
|
||||||
#generative-language-api-key:
|
#generative-language-api-key:
|
||||||
# - "AIzaSy...01"
|
# - "AIzaSy...01"
|
||||||
|
|||||||
@@ -57,10 +57,12 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
|||||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||||
authHeaderAnthropic := r.Header.Get("X-Api-Key")
|
authHeaderAnthropic := r.Header.Get("X-Api-Key")
|
||||||
queryKey := ""
|
queryKey := ""
|
||||||
|
queryAuthToken := ""
|
||||||
if r.URL != nil {
|
if r.URL != nil {
|
||||||
queryKey = r.URL.Query().Get("key")
|
queryKey = r.URL.Query().Get("key")
|
||||||
|
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||||
}
|
}
|
||||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" {
|
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||||
return nil, sdkaccess.ErrNoCredentials
|
return nil, sdkaccess.ErrNoCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,6 +76,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
|||||||
{authHeaderGoogle, "x-goog-api-key"},
|
{authHeaderGoogle, "x-goog-api-key"},
|
||||||
{authHeaderAnthropic, "x-api-key"},
|
{authHeaderAnthropic, "x-api-key"},
|
||||||
{queryKey, "query-key"},
|
{queryKey, "query-key"},
|
||||||
|
{queryAuthToken, "query-auth-token"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, candidate := range candidates {
|
for _, candidate := range candidates {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||||
@@ -63,13 +64,11 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
// It captures the URL, method, headers, and body. The request body is read and then
|
// It captures the URL, method, headers, and body. The request body is read and then
|
||||||
// restored so that it can be processed by subsequent handlers.
|
// restored so that it can be processed by subsequent handlers.
|
||||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||||
// Capture URL
|
// Capture URL with sensitive query parameters masked
|
||||||
url := c.Request.URL.String()
|
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||||
if c.Request.URL.Path != "" {
|
url := c.Request.URL.Path
|
||||||
url = c.Request.URL.Path
|
if maskedQuery != "" {
|
||||||
if c.Request.URL.RawQuery != "" {
|
url += "?" + maskedQuery
|
||||||
url += "?" + c.Request.URL.RawQuery
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture method
|
// Capture method
|
||||||
|
|||||||
@@ -142,6 +142,8 @@ type Server struct {
|
|||||||
// wsRoutes tracks registered websocket upgrade paths.
|
// wsRoutes tracks registered websocket upgrade paths.
|
||||||
wsRouteMu sync.Mutex
|
wsRouteMu sync.Mutex
|
||||||
wsRoutes map[string]struct{}
|
wsRoutes map[string]struct{}
|
||||||
|
wsAuthChanged func(bool, bool)
|
||||||
|
wsAuthEnabled atomic.Bool
|
||||||
|
|
||||||
// management handler
|
// management handler
|
||||||
mgmt *managementHandlers.Handler
|
mgmt *managementHandlers.Handler
|
||||||
@@ -235,6 +237,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
envManagementSecret: envManagementSecret,
|
envManagementSecret: envManagementSecret,
|
||||||
wsRoutes: make(map[string]struct{}),
|
wsRoutes: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
||||||
// Save initial YAML snapshot
|
// Save initial YAML snapshot
|
||||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||||
s.applyAccessConfig(nil, cfg)
|
s.applyAccessConfig(nil, cfg)
|
||||||
@@ -398,10 +401,20 @@ func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
|
|||||||
s.wsRoutes[trimmed] = struct{}{}
|
s.wsRoutes[trimmed] = struct{}{}
|
||||||
s.wsRouteMu.Unlock()
|
s.wsRouteMu.Unlock()
|
||||||
|
|
||||||
s.engine.GET(trimmed, func(c *gin.Context) {
|
authMiddleware := AuthMiddleware(s.accessManager)
|
||||||
|
conditionalAuth := func(c *gin.Context) {
|
||||||
|
if !s.wsAuthEnabled.Load() {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authMiddleware(c)
|
||||||
|
}
|
||||||
|
finalHandler := func(c *gin.Context) {
|
||||||
handler.ServeHTTP(c.Writer, c.Request)
|
handler.ServeHTTP(c.Writer, c.Request)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
})
|
}
|
||||||
|
|
||||||
|
s.engine.GET(trimmed, conditionalAuth, finalHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) registerManagementRoutes() {
|
func (s *Server) registerManagementRoutes() {
|
||||||
@@ -803,6 +816,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
|
|
||||||
s.applyAccessConfig(oldCfg, cfg)
|
s.applyAccessConfig(oldCfg, cfg)
|
||||||
s.cfg = cfg
|
s.cfg = cfg
|
||||||
|
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
||||||
|
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
|
||||||
|
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
|
||||||
|
}
|
||||||
managementasset.SetCurrentConfig(cfg)
|
managementasset.SetCurrentConfig(cfg)
|
||||||
// Save YAML snapshot for next comparison
|
// Save YAML snapshot for next comparison
|
||||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||||
@@ -843,6 +860,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.wsAuthChanged = fn
|
||||||
|
}
|
||||||
|
|
||||||
// (management handlers moved to internal/api/handlers/management)
|
// (management handlers moved to internal/api/handlers/management)
|
||||||
|
|
||||||
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
||||||
|
|||||||
@@ -40,6 +40,9 @@ type Config struct {
|
|||||||
// QuotaExceeded defines the behavior when a quota is exceeded.
|
// QuotaExceeded defines the behavior when a quota is exceeded.
|
||||||
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
|
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
|
||||||
|
|
||||||
|
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
||||||
|
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
||||||
|
|
||||||
// GlAPIKey is the API key for the generative language API.
|
// GlAPIKey is the API key for the generative language API.
|
||||||
GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"`
|
GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"`
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ func GinLogrusLogger() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
raw := c.Request.URL.RawQuery
|
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -188,3 +189,56 @@ func MaskSensitiveHeaderValue(key, value string) string {
|
|||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaskSensitiveQuery masks sensitive query parameters, e.g. auth_token, within the raw query string.
|
||||||
|
func MaskSensitiveQuery(raw string) string {
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parts := strings.Split(raw, "&")
|
||||||
|
changed := false
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keyPart := part
|
||||||
|
valuePart := ""
|
||||||
|
if idx := strings.Index(part, "="); idx >= 0 {
|
||||||
|
keyPart = part[:idx]
|
||||||
|
valuePart = part[idx+1:]
|
||||||
|
}
|
||||||
|
decodedKey, err := url.QueryUnescape(keyPart)
|
||||||
|
if err != nil {
|
||||||
|
decodedKey = keyPart
|
||||||
|
}
|
||||||
|
if !shouldMaskQueryParam(decodedKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
decodedValue, err := url.QueryUnescape(valuePart)
|
||||||
|
if err != nil {
|
||||||
|
decodedValue = valuePart
|
||||||
|
}
|
||||||
|
masked := HideAPIKey(strings.TrimSpace(decodedValue))
|
||||||
|
parts[i] = keyPart + "=" + url.QueryEscape(masked)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "&")
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldMaskQueryParam(key string) bool {
|
||||||
|
key = strings.ToLower(strings.TrimSpace(key))
|
||||||
|
if key == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key = strings.TrimSuffix(key, "[]")
|
||||||
|
if key == "key" || strings.Contains(key, "api-key") || strings.Contains(key, "apikey") || strings.Contains(key, "api_key") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.Contains(key, "token") || strings.Contains(key, "secret") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -1204,6 +1204,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
|
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
|
||||||
}
|
}
|
||||||
|
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
|
||||||
|
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
|
||||||
|
}
|
||||||
|
|
||||||
// Quota-exceeded behavior
|
// Quota-exceeded behavior
|
||||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||||
|
|||||||
@@ -421,6 +421,22 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
s.ensureWebsocketGateway()
|
s.ensureWebsocketGateway()
|
||||||
if s.server != nil && s.wsGateway != nil {
|
if s.server != nil && s.wsGateway != nil {
|
||||||
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
|
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
|
||||||
|
s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) {
|
||||||
|
if oldEnabled == newEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !oldEnabled && newEnabled {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if errStop := s.wsGateway.Stop(ctx); errStop != nil {
|
||||||
|
log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("ws-auth disabled; existing websocket sessions remain connected")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.hooks.OnBeforeStart != nil {
|
if s.hooks.OnBeforeStart != nil {
|
||||||
@@ -460,7 +476,6 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
s.cfg = newCfg
|
s.cfg = newCfg
|
||||||
s.cfgMu.Unlock()
|
s.cfgMu.Unlock()
|
||||||
s.rebindExecutors()
|
s.rebindExecutors()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)
|
watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)
|
||||||
|
|||||||
Reference in New Issue
Block a user