feat(server): add keep-alive endpoint with timeout handling

- Introduced a keep-alive endpoint to monitor service activity.
- Added timeout-specific shutdown functionality when the endpoint is idle.
- Implemented password-protected access for the keep-alive endpoint.
- Updated server startup to support configurable keep-alive options.
This commit is contained in:
Luis Pater
2025-09-26 01:45:30 +08:00
parent cf734f7e7b
commit 5a50856fc1
2 changed files with 135 additions and 7 deletions

View File

@@ -6,12 +6,14 @@ package api
import (
"context"
"crypto/subtle"
"errors"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
@@ -34,6 +36,9 @@ type serverOptionConfig struct {
routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)
requestLoggerFactory func(*config.Config, string) logging.RequestLogger
localPassword string
keepAliveEnabled bool
keepAliveTimeout time.Duration
keepAliveOnTimeout func()
}
// ServerOption customises HTTP server construction.
@@ -71,6 +76,18 @@ func WithLocalManagementPassword(password string) ServerOption {
}
}
// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
return func(cfg *serverOptionConfig) {
if timeout <= 0 || onTimeout == nil {
return
}
cfg.keepAliveEnabled = true
cfg.keepAliveTimeout = timeout
cfg.keepAliveOnTimeout = onTimeout
}
}
// WithRequestLoggerFactory customises request logger creation.
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
return func(cfg *serverOptionConfig) {
@@ -105,6 +122,14 @@ type Server struct {
// management handler
mgmt *managementHandlers.Handler
localPassword string
keepAliveEnabled bool
keepAliveTimeout time.Duration
keepAliveOnTimeout func()
keepAliveHeartbeat chan struct{}
keepAliveStop chan struct{}
}
// NewServer creates and initializes a new API server instance.
@@ -174,6 +199,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
if optionState.localPassword != "" {
s.mgmt.SetLocalPassword(optionState.localPassword)
}
s.localPassword = optionState.localPassword
// Setup routes
s.setupRoutes()
@@ -181,6 +207,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
optionState.routerConfigurator(engine, s.handlers, cfg)
}
if optionState.keepAliveEnabled {
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
}
// Create HTTP server
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),
@@ -348,6 +378,84 @@ func (s *Server) setupRoutes() {
}
}
func (s *Server) enableKeepAlive(timeout time.Duration, onTimeout func()) {
if timeout <= 0 || onTimeout == nil {
return
}
s.keepAliveEnabled = true
s.keepAliveTimeout = timeout
s.keepAliveOnTimeout = onTimeout
s.keepAliveHeartbeat = make(chan struct{}, 1)
s.keepAliveStop = make(chan struct{}, 1)
s.engine.GET("/keep-alive", s.handleKeepAlive)
go s.watchKeepAlive()
}
func (s *Server) handleKeepAlive(c *gin.Context) {
if s.localPassword != "" {
provided := strings.TrimSpace(c.GetHeader("Authorization"))
if provided != "" {
parts := strings.SplitN(provided, " ", 2)
if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") {
provided = parts[1]
}
}
if provided == "" {
provided = strings.TrimSpace(c.GetHeader("X-Local-Password"))
}
if subtle.ConstantTimeCompare([]byte(provided), []byte(s.localPassword)) != 1 {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid password"})
return
}
}
s.signalKeepAlive()
c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
func (s *Server) signalKeepAlive() {
if !s.keepAliveEnabled {
return
}
select {
case s.keepAliveHeartbeat <- struct{}{}:
default:
}
}
func (s *Server) watchKeepAlive() {
if !s.keepAliveEnabled {
return
}
timer := time.NewTimer(s.keepAliveTimeout)
defer timer.Stop()
for {
select {
case <-timer.C:
log.Warnf("keep-alive endpoint idle for %s, shutting down", s.keepAliveTimeout)
if s.keepAliveOnTimeout != nil {
s.keepAliveOnTimeout()
}
return
case <-s.keepAliveHeartbeat:
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(s.keepAliveTimeout)
case <-s.keepAliveStop:
return
}
}
}
// unifiedModelsHandler creates a unified handler for the /v1/models endpoint
// that routes to different handlers based on the User-Agent header.
// If User-Agent starts with "claude-cli", it routes to Claude handler,
@@ -394,6 +502,13 @@ func (s *Server) Start() error {
func (s *Server) Stop(ctx context.Context) error {
log.Debug("Stopping API server...")
if s.keepAliveEnabled {
select {
case s.keepAliveStop <- struct{}{}:
default:
}
}
// Shutdown the HTTP server.
if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown HTTP server: %v", err)

View File

@@ -8,7 +8,9 @@ import (
"errors"
"os/signal"
"syscall"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
log "github.com/sirupsen/logrus"
@@ -23,19 +25,30 @@ import (
// - configPath: The path to the configuration file
// - localPassword: Optional password accepted for local management requests
func StartService(cfg *config.Config, configPath string, localPassword string) {
service, err := cliproxy.NewBuilder().
builder := cliproxy.NewBuilder().
WithConfig(cfg).
WithConfigPath(configPath).
WithLocalManagementPassword(localPassword).
Build()
WithLocalManagementPassword(localPassword)
ctxSignal, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
runCtx := ctxSignal
if localPassword != "" {
var keepAliveCancel context.CancelFunc
runCtx, keepAliveCancel = context.WithCancel(ctxSignal)
builder = builder.WithServerOptions(api.WithKeepAliveEndpoint(10*time.Second, func() {
log.Warn("keep-alive endpoint idle for 10s, shutting down")
keepAliveCancel()
}))
}
service, err := builder.Build()
if err != nil {
log.Fatalf("failed to build proxy service: %v", err)
}
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
err = service.Run(ctx)
err = service.Run(runCtx)
if err != nil && !errors.Is(err, context.Canceled) {
log.Fatalf("proxy service exited with error: %v", err)
}