diff --git a/internal/api/server.go b/internal/api/server.go index 30711882..5efc9175 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -6,12 +6,14 @@ package api import ( "context" + "crypto/subtle" "errors" "fmt" "net/http" "os" "path/filepath" "strings" + "time" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers" @@ -34,6 +36,9 @@ type serverOptionConfig struct { routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config) requestLoggerFactory func(*config.Config, string) logging.RequestLogger localPassword string + keepAliveEnabled bool + keepAliveTimeout time.Duration + keepAliveOnTimeout func() } // ServerOption customises HTTP server construction. @@ -71,6 +76,18 @@ func WithLocalManagementPassword(password string) ServerOption { } } +// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback. +func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption { + return func(cfg *serverOptionConfig) { + if timeout <= 0 || onTimeout == nil { + return + } + cfg.keepAliveEnabled = true + cfg.keepAliveTimeout = timeout + cfg.keepAliveOnTimeout = onTimeout + } +} + // WithRequestLoggerFactory customises request logger creation. func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption { return func(cfg *serverOptionConfig) { @@ -105,6 +122,14 @@ type Server struct { // management handler mgmt *managementHandlers.Handler + + localPassword string + + keepAliveEnabled bool + keepAliveTimeout time.Duration + keepAliveOnTimeout func() + keepAliveHeartbeat chan struct{} + keepAliveStop chan struct{} } // NewServer creates and initializes a new API server instance. @@ -174,6 +199,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk if optionState.localPassword != "" { s.mgmt.SetLocalPassword(optionState.localPassword) } + s.localPassword = optionState.localPassword // Setup routes s.setupRoutes() @@ -181,6 +207,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk optionState.routerConfigurator(engine, s.handlers, cfg) } + if optionState.keepAliveEnabled { + s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout) + } + // Create HTTP server s.server = &http.Server{ Addr: fmt.Sprintf(":%d", cfg.Port), @@ -348,6 +378,84 @@ func (s *Server) setupRoutes() { } } +func (s *Server) enableKeepAlive(timeout time.Duration, onTimeout func()) { + if timeout <= 0 || onTimeout == nil { + return + } + + s.keepAliveEnabled = true + s.keepAliveTimeout = timeout + s.keepAliveOnTimeout = onTimeout + s.keepAliveHeartbeat = make(chan struct{}, 1) + s.keepAliveStop = make(chan struct{}, 1) + + s.engine.GET("/keep-alive", s.handleKeepAlive) + + go s.watchKeepAlive() +} + +func (s *Server) handleKeepAlive(c *gin.Context) { + if s.localPassword != "" { + provided := strings.TrimSpace(c.GetHeader("Authorization")) + if provided != "" { + parts := strings.SplitN(provided, " ", 2) + if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") { + provided = parts[1] + } + } + if provided == "" { + provided = strings.TrimSpace(c.GetHeader("X-Local-Password")) + } + if subtle.ConstantTimeCompare([]byte(provided), []byte(s.localPassword)) != 1 { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid password"}) + return + } + } + + s.signalKeepAlive() + c.JSON(http.StatusOK, gin.H{"status": "ok"}) +} + +func (s *Server) signalKeepAlive() { + if !s.keepAliveEnabled { + return + } + select { + case s.keepAliveHeartbeat <- struct{}{}: + default: + } +} + +func (s *Server) watchKeepAlive() { + if !s.keepAliveEnabled { + return + } + + timer := time.NewTimer(s.keepAliveTimeout) + defer timer.Stop() + + for { + select { + case <-timer.C: + log.Warnf("keep-alive endpoint idle for %s, shutting down", s.keepAliveTimeout) + if s.keepAliveOnTimeout != nil { + s.keepAliveOnTimeout() + } + return + case <-s.keepAliveHeartbeat: + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(s.keepAliveTimeout) + case <-s.keepAliveStop: + return + } + } +} + // unifiedModelsHandler creates a unified handler for the /v1/models endpoint // that routes to different handlers based on the User-Agent header. // If User-Agent starts with "claude-cli", it routes to Claude handler, @@ -394,6 +502,13 @@ func (s *Server) Start() error { func (s *Server) Stop(ctx context.Context) error { log.Debug("Stopping API server...") + if s.keepAliveEnabled { + select { + case s.keepAliveStop <- struct{}{}: + default: + } + } + // Shutdown the HTTP server. if err := s.server.Shutdown(ctx); err != nil { return fmt.Errorf("failed to shutdown HTTP server: %v", err) diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 94b01592..cd4aaea7 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -8,7 +8,9 @@ import ( "errors" "os/signal" "syscall" + "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/api" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy" log "github.com/sirupsen/logrus" @@ -23,19 +25,30 @@ import ( // - configPath: The path to the configuration file // - localPassword: Optional password accepted for local management requests func StartService(cfg *config.Config, configPath string, localPassword string) { - service, err := cliproxy.NewBuilder(). + builder := cliproxy.NewBuilder(). WithConfig(cfg). WithConfigPath(configPath). - WithLocalManagementPassword(localPassword). - Build() + WithLocalManagementPassword(localPassword) + + ctxSignal, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + runCtx := ctxSignal + if localPassword != "" { + var keepAliveCancel context.CancelFunc + runCtx, keepAliveCancel = context.WithCancel(ctxSignal) + builder = builder.WithServerOptions(api.WithKeepAliveEndpoint(10*time.Second, func() { + log.Warn("keep-alive endpoint idle for 10s, shutting down") + keepAliveCancel() + })) + } + + service, err := builder.Build() if err != nil { log.Fatalf("failed to build proxy service: %v", err) } - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer cancel() - - err = service.Run(ctx) + err = service.Run(runCtx) if err != nil && !errors.Is(err, context.Canceled) { log.Fatalf("proxy service exited with error: %v", err) }