diff --git a/config.example.yaml b/config.example.yaml index 428df70b..d5795719 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -43,6 +43,9 @@ quota-exceeded: switch-project: true # Whether to automatically switch to another project when a quota is exceeded switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded +# When true, enable authentication for the WebSocket API (/v1/ws). +ws-auth: false + # API keys for official Generative Language API #generative-language-api-key: # - "AIzaSy...01" diff --git a/internal/access/config_access/provider.go b/internal/access/config_access/provider.go index 97a64fe2..70824524 100644 --- a/internal/access/config_access/provider.go +++ b/internal/access/config_access/provider.go @@ -57,10 +57,12 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess. authHeaderGoogle := r.Header.Get("X-Goog-Api-Key") authHeaderAnthropic := r.Header.Get("X-Api-Key") queryKey := "" + queryAuthToken := "" if r.URL != nil { queryKey = r.URL.Query().Get("key") + queryAuthToken = r.URL.Query().Get("auth_token") } - if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" { + if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" { return nil, sdkaccess.ErrNoCredentials } @@ -74,6 +76,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess. {authHeaderGoogle, "x-goog-api-key"}, {authHeaderAnthropic, "x-api-key"}, {queryKey, "query-key"}, + {queryAuthToken, "query-auth-token"}, } for _, candidate := range candidates { diff --git a/internal/api/middleware/request_logging.go b/internal/api/middleware/request_logging.go index b866e00c..d4ea6510 100644 --- a/internal/api/middleware/request_logging.go +++ b/internal/api/middleware/request_logging.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" ) // RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses. @@ -63,13 +64,11 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { // It captures the URL, method, headers, and body. The request body is read and then // restored so that it can be processed by subsequent handlers. func captureRequestInfo(c *gin.Context) (*RequestInfo, error) { - // Capture URL - url := c.Request.URL.String() - if c.Request.URL.Path != "" { - url = c.Request.URL.Path - if c.Request.URL.RawQuery != "" { - url += "?" + c.Request.URL.RawQuery - } + // Capture URL with sensitive query parameters masked + maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery) + url := c.Request.URL.Path + if maskedQuery != "" { + url += "?" + maskedQuery } // Capture method diff --git a/internal/api/server.go b/internal/api/server.go index a41861c2..f4eb81e2 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -140,8 +140,10 @@ type Server struct { currentPath string // wsRoutes tracks registered websocket upgrade paths. - wsRouteMu sync.Mutex - wsRoutes map[string]struct{} + wsRouteMu sync.Mutex + wsRoutes map[string]struct{} + wsAuthChanged func(bool, bool) + wsAuthEnabled atomic.Bool // management handler mgmt *managementHandlers.Handler @@ -235,6 +237,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk envManagementSecret: envManagementSecret, wsRoutes: make(map[string]struct{}), } + s.wsAuthEnabled.Store(cfg.WebsocketAuth) // Save initial YAML snapshot s.oldConfigYaml, _ = yaml.Marshal(cfg) s.applyAccessConfig(nil, cfg) @@ -398,10 +401,20 @@ func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) { s.wsRoutes[trimmed] = struct{}{} s.wsRouteMu.Unlock() - s.engine.GET(trimmed, func(c *gin.Context) { + authMiddleware := AuthMiddleware(s.accessManager) + conditionalAuth := func(c *gin.Context) { + if !s.wsAuthEnabled.Load() { + c.Next() + return + } + authMiddleware(c) + } + finalHandler := func(c *gin.Context) { handler.ServeHTTP(c.Writer, c.Request) c.Abort() - }) + } + + s.engine.GET(trimmed, conditionalAuth, finalHandler) } func (s *Server) registerManagementRoutes() { @@ -803,6 +816,10 @@ func (s *Server) UpdateClients(cfg *config.Config) { s.applyAccessConfig(oldCfg, cfg) s.cfg = cfg + s.wsAuthEnabled.Store(cfg.WebsocketAuth) + if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth { + s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth) + } managementasset.SetCurrentConfig(cfg) // Save YAML snapshot for next comparison s.oldConfigYaml, _ = yaml.Marshal(cfg) @@ -843,6 +860,13 @@ func (s *Server) UpdateClients(cfg *config.Config) { ) } +func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) { + if s == nil { + return + } + s.wsAuthChanged = fn +} + // (management handlers moved to internal/api/handlers/management) // AuthMiddleware returns a Gin middleware handler that authenticates requests diff --git a/internal/config/config.go b/internal/config/config.go index 169eecc2..bc4d217a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -40,6 +40,9 @@ type Config struct { // QuotaExceeded defines the behavior when a quota is exceeded. QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"` + // WebsocketAuth enables or disables authentication for the WebSocket API. + WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"` + // GlAPIKey is the API key for the generative language API. GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"` diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index 904fa797..2933a0bb 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) @@ -23,7 +24,7 @@ func GinLogrusLogger() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() path := c.Request.URL.Path - raw := c.Request.URL.RawQuery + raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery) c.Next() diff --git a/internal/util/provider.go b/internal/util/provider.go index 5f4dcd19..8c6cefdb 100644 --- a/internal/util/provider.go +++ b/internal/util/provider.go @@ -4,6 +4,7 @@ package util import ( + "net/url" "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -188,3 +189,56 @@ func MaskSensitiveHeaderValue(key, value string) string { return value } } + +// MaskSensitiveQuery masks sensitive query parameters, e.g. auth_token, within the raw query string. +func MaskSensitiveQuery(raw string) string { + if raw == "" { + return "" + } + parts := strings.Split(raw, "&") + changed := false + for i, part := range parts { + if part == "" { + continue + } + keyPart := part + valuePart := "" + if idx := strings.Index(part, "="); idx >= 0 { + keyPart = part[:idx] + valuePart = part[idx+1:] + } + decodedKey, err := url.QueryUnescape(keyPart) + if err != nil { + decodedKey = keyPart + } + if !shouldMaskQueryParam(decodedKey) { + continue + } + decodedValue, err := url.QueryUnescape(valuePart) + if err != nil { + decodedValue = valuePart + } + masked := HideAPIKey(strings.TrimSpace(decodedValue)) + parts[i] = keyPart + "=" + url.QueryEscape(masked) + changed = true + } + if !changed { + return raw + } + return strings.Join(parts, "&") +} + +func shouldMaskQueryParam(key string) bool { + key = strings.ToLower(strings.TrimSpace(key)) + if key == "" { + return false + } + key = strings.TrimSuffix(key, "[]") + if key == "key" || strings.Contains(key, "api-key") || strings.Contains(key, "apikey") || strings.Contains(key, "api_key") { + return true + } + if strings.Contains(key, "token") || strings.Contains(key, "secret") { + return true + } + return false +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 85b48aae..93694710 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1204,6 +1204,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldCfg.ProxyURL != newCfg.ProxyURL { changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL)) } + if oldCfg.WebsocketAuth != newCfg.WebsocketAuth { + changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth)) + } // Quota-exceeded behavior if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index b0f4605b..ada70eb5 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -421,6 +421,22 @@ func (s *Service) Run(ctx context.Context) error { s.ensureWebsocketGateway() if s.server != nil && s.wsGateway != nil { s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler()) + s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) { + if oldEnabled == newEnabled { + return + } + if !oldEnabled && newEnabled { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if errStop := s.wsGateway.Stop(ctx); errStop != nil { + log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop) + return + } + log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication") + return + } + log.Debugf("ws-auth disabled; existing websocket sessions remain connected") + }) } if s.hooks.OnBeforeStart != nil { @@ -460,7 +476,6 @@ func (s *Service) Run(ctx context.Context) error { s.cfg = newCfg s.cfgMu.Unlock() s.rebindExecutors() - } watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)