mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(server): add keep-alive endpoint with timeout handling
- Introduced a keep-alive endpoint to monitor service activity. - Added timeout-specific shutdown functionality when the endpoint is idle. - Implemented password-protected access for the keep-alive endpoint. - Updated server startup to support configurable keep-alive options.
This commit is contained in:
@@ -6,12 +6,14 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
|
||||||
@@ -34,6 +36,9 @@ type serverOptionConfig struct {
|
|||||||
routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)
|
routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)
|
||||||
requestLoggerFactory func(*config.Config, string) logging.RequestLogger
|
requestLoggerFactory func(*config.Config, string) logging.RequestLogger
|
||||||
localPassword string
|
localPassword string
|
||||||
|
keepAliveEnabled bool
|
||||||
|
keepAliveTimeout time.Duration
|
||||||
|
keepAliveOnTimeout func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerOption customises HTTP server construction.
|
// ServerOption customises HTTP server construction.
|
||||||
@@ -71,6 +76,18 @@ func WithLocalManagementPassword(password string) ServerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
|
||||||
|
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
|
||||||
|
return func(cfg *serverOptionConfig) {
|
||||||
|
if timeout <= 0 || onTimeout == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.keepAliveEnabled = true
|
||||||
|
cfg.keepAliveTimeout = timeout
|
||||||
|
cfg.keepAliveOnTimeout = onTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithRequestLoggerFactory customises request logger creation.
|
// WithRequestLoggerFactory customises request logger creation.
|
||||||
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
|
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
|
||||||
return func(cfg *serverOptionConfig) {
|
return func(cfg *serverOptionConfig) {
|
||||||
@@ -105,6 +122,14 @@ type Server struct {
|
|||||||
|
|
||||||
// management handler
|
// management handler
|
||||||
mgmt *managementHandlers.Handler
|
mgmt *managementHandlers.Handler
|
||||||
|
|
||||||
|
localPassword string
|
||||||
|
|
||||||
|
keepAliveEnabled bool
|
||||||
|
keepAliveTimeout time.Duration
|
||||||
|
keepAliveOnTimeout func()
|
||||||
|
keepAliveHeartbeat chan struct{}
|
||||||
|
keepAliveStop chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates and initializes a new API server instance.
|
// NewServer creates and initializes a new API server instance.
|
||||||
@@ -174,6 +199,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
if optionState.localPassword != "" {
|
if optionState.localPassword != "" {
|
||||||
s.mgmt.SetLocalPassword(optionState.localPassword)
|
s.mgmt.SetLocalPassword(optionState.localPassword)
|
||||||
}
|
}
|
||||||
|
s.localPassword = optionState.localPassword
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
s.setupRoutes()
|
s.setupRoutes()
|
||||||
@@ -181,6 +207,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
optionState.routerConfigurator(engine, s.handlers, cfg)
|
optionState.routerConfigurator(engine, s.handlers, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if optionState.keepAliveEnabled {
|
||||||
|
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
// Create HTTP server
|
// Create HTTP server
|
||||||
s.server = &http.Server{
|
s.server = &http.Server{
|
||||||
Addr: fmt.Sprintf(":%d", cfg.Port),
|
Addr: fmt.Sprintf(":%d", cfg.Port),
|
||||||
@@ -348,6 +378,84 @@ func (s *Server) setupRoutes() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) enableKeepAlive(timeout time.Duration, onTimeout func()) {
|
||||||
|
if timeout <= 0 || onTimeout == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.keepAliveEnabled = true
|
||||||
|
s.keepAliveTimeout = timeout
|
||||||
|
s.keepAliveOnTimeout = onTimeout
|
||||||
|
s.keepAliveHeartbeat = make(chan struct{}, 1)
|
||||||
|
s.keepAliveStop = make(chan struct{}, 1)
|
||||||
|
|
||||||
|
s.engine.GET("/keep-alive", s.handleKeepAlive)
|
||||||
|
|
||||||
|
go s.watchKeepAlive()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleKeepAlive(c *gin.Context) {
|
||||||
|
if s.localPassword != "" {
|
||||||
|
provided := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||||
|
if provided != "" {
|
||||||
|
parts := strings.SplitN(provided, " ", 2)
|
||||||
|
if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") {
|
||||||
|
provided = parts[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if provided == "" {
|
||||||
|
provided = strings.TrimSpace(c.GetHeader("X-Local-Password"))
|
||||||
|
}
|
||||||
|
if subtle.ConstantTimeCompare([]byte(provided), []byte(s.localPassword)) != 1 {
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid password"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.signalKeepAlive()
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) signalKeepAlive() {
|
||||||
|
if !s.keepAliveEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case s.keepAliveHeartbeat <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) watchKeepAlive() {
|
||||||
|
if !s.keepAliveEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
timer := time.NewTimer(s.keepAliveTimeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
log.Warnf("keep-alive endpoint idle for %s, shutting down", s.keepAliveTimeout)
|
||||||
|
if s.keepAliveOnTimeout != nil {
|
||||||
|
s.keepAliveOnTimeout()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-s.keepAliveHeartbeat:
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timer.Reset(s.keepAliveTimeout)
|
||||||
|
case <-s.keepAliveStop:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// unifiedModelsHandler creates a unified handler for the /v1/models endpoint
|
// unifiedModelsHandler creates a unified handler for the /v1/models endpoint
|
||||||
// that routes to different handlers based on the User-Agent header.
|
// that routes to different handlers based on the User-Agent header.
|
||||||
// If User-Agent starts with "claude-cli", it routes to Claude handler,
|
// If User-Agent starts with "claude-cli", it routes to Claude handler,
|
||||||
@@ -394,6 +502,13 @@ func (s *Server) Start() error {
|
|||||||
func (s *Server) Stop(ctx context.Context) error {
|
func (s *Server) Stop(ctx context.Context) error {
|
||||||
log.Debug("Stopping API server...")
|
log.Debug("Stopping API server...")
|
||||||
|
|
||||||
|
if s.keepAliveEnabled {
|
||||||
|
select {
|
||||||
|
case s.keepAliveStop <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Shutdown the HTTP server.
|
// Shutdown the HTTP server.
|
||||||
if err := s.server.Shutdown(ctx); err != nil {
|
if err := s.server.Shutdown(ctx); err != nil {
|
||||||
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -23,19 +25,30 @@ import (
|
|||||||
// - configPath: The path to the configuration file
|
// - configPath: The path to the configuration file
|
||||||
// - localPassword: Optional password accepted for local management requests
|
// - localPassword: Optional password accepted for local management requests
|
||||||
func StartService(cfg *config.Config, configPath string, localPassword string) {
|
func StartService(cfg *config.Config, configPath string, localPassword string) {
|
||||||
service, err := cliproxy.NewBuilder().
|
builder := cliproxy.NewBuilder().
|
||||||
WithConfig(cfg).
|
WithConfig(cfg).
|
||||||
WithConfigPath(configPath).
|
WithConfigPath(configPath).
|
||||||
WithLocalManagementPassword(localPassword).
|
WithLocalManagementPassword(localPassword)
|
||||||
Build()
|
|
||||||
|
ctxSignal, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
runCtx := ctxSignal
|
||||||
|
if localPassword != "" {
|
||||||
|
var keepAliveCancel context.CancelFunc
|
||||||
|
runCtx, keepAliveCancel = context.WithCancel(ctxSignal)
|
||||||
|
builder = builder.WithServerOptions(api.WithKeepAliveEndpoint(10*time.Second, func() {
|
||||||
|
log.Warn("keep-alive endpoint idle for 10s, shutting down")
|
||||||
|
keepAliveCancel()
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := builder.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to build proxy service: %v", err)
|
log.Fatalf("failed to build proxy service: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
err = service.Run(runCtx)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err = service.Run(ctx)
|
|
||||||
if err != nil && !errors.Is(err, context.Canceled) {
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
log.Fatalf("proxy service exited with error: %v", err)
|
log.Fatalf("proxy service exited with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user