mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
1063 lines
36 KiB
Go
1063 lines
36 KiB
Go
// Package api provides the HTTP API server implementation for the CLI Proxy API.
|
|
// It includes the main server struct, routing setup, middleware for CORS and authentication,
|
|
// and integration with various AI API handlers (OpenAI, Claude, Gemini).
|
|
// The server supports hot-reloading of clients and configuration.
|
|
package api
|
|
|
|
import (
|
|
"context"
|
|
"crypto/subtle"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/access"
|
|
managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
|
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
log "github.com/sirupsen/logrus"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
const oauthCallbackSuccessHTML = `<html><head><meta charset="utf-8"><title>Authentication successful</title><script>setTimeout(function(){window.close();},5000);</script></head><body><h1>Authentication successful!</h1><p>You can close this window.</p><p>This window will close automatically in 5 seconds.</p></body></html>`
|
|
|
|
type serverOptionConfig struct {
|
|
extraMiddleware []gin.HandlerFunc
|
|
engineConfigurator func(*gin.Engine)
|
|
routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)
|
|
requestLoggerFactory func(*config.Config, string) logging.RequestLogger
|
|
localPassword string
|
|
keepAliveEnabled bool
|
|
keepAliveTimeout time.Duration
|
|
keepAliveOnTimeout func()
|
|
}
|
|
|
|
// ServerOption customises HTTP server construction.
|
|
type ServerOption func(*serverOptionConfig)
|
|
|
|
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
|
configDir := filepath.Dir(configPath)
|
|
if base := util.WritablePath(); base != "" {
|
|
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir)
|
|
}
|
|
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir)
|
|
}
|
|
|
|
// WithMiddleware appends additional Gin middleware during server construction.
|
|
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption {
|
|
return func(cfg *serverOptionConfig) {
|
|
cfg.extraMiddleware = append(cfg.extraMiddleware, mw...)
|
|
}
|
|
}
|
|
|
|
// WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup.
|
|
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
|
|
return func(cfg *serverOptionConfig) {
|
|
cfg.engineConfigurator = fn
|
|
}
|
|
}
|
|
|
|
// WithRouterConfigurator appends a callback after default routes are registered.
|
|
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
|
|
return func(cfg *serverOptionConfig) {
|
|
cfg.routerConfigurator = fn
|
|
}
|
|
}
|
|
|
|
// WithLocalManagementPassword stores a runtime-only management password accepted for localhost requests.
|
|
func WithLocalManagementPassword(password string) ServerOption {
|
|
return func(cfg *serverOptionConfig) {
|
|
cfg.localPassword = password
|
|
}
|
|
}
|
|
|
|
// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
|
|
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
|
|
return func(cfg *serverOptionConfig) {
|
|
if timeout <= 0 || onTimeout == nil {
|
|
return
|
|
}
|
|
cfg.keepAliveEnabled = true
|
|
cfg.keepAliveTimeout = timeout
|
|
cfg.keepAliveOnTimeout = onTimeout
|
|
}
|
|
}
|
|
|
|
// WithRequestLoggerFactory customises request logger creation.
|
|
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
|
|
return func(cfg *serverOptionConfig) {
|
|
cfg.requestLoggerFactory = factory
|
|
}
|
|
}
|
|
|
|
// Server represents the main API server.
|
|
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
|
type Server struct {
|
|
// engine is the Gin web framework engine instance.
|
|
engine *gin.Engine
|
|
|
|
// server is the underlying HTTP server.
|
|
server *http.Server
|
|
|
|
// handlers contains the API handlers for processing requests.
|
|
handlers *handlers.BaseAPIHandler
|
|
|
|
// cfg holds the current server configuration.
|
|
cfg *config.Config
|
|
|
|
// oldConfigYaml stores a YAML snapshot of the previous configuration for change detection.
|
|
// This prevents issues when the config object is modified in place by Management API.
|
|
oldConfigYaml []byte
|
|
|
|
// accessManager handles request authentication providers.
|
|
accessManager *sdkaccess.Manager
|
|
|
|
// requestLogger is the request logger instance for dynamic configuration updates.
|
|
requestLogger logging.RequestLogger
|
|
loggerToggle func(bool)
|
|
|
|
// configFilePath is the absolute path to the YAML config file for persistence.
|
|
configFilePath string
|
|
|
|
// currentPath is the absolute path to the current working directory.
|
|
currentPath string
|
|
|
|
// wsRoutes tracks registered websocket upgrade paths.
|
|
wsRouteMu sync.Mutex
|
|
wsRoutes map[string]struct{}
|
|
wsAuthChanged func(bool, bool)
|
|
wsAuthEnabled atomic.Bool
|
|
|
|
// management handler
|
|
mgmt *managementHandlers.Handler
|
|
|
|
// ampModule is the Amp routing module for model mapping hot-reload
|
|
ampModule *ampmodule.AmpModule
|
|
|
|
// managementRoutesRegistered tracks whether the management routes have been attached to the engine.
|
|
managementRoutesRegistered atomic.Bool
|
|
// managementRoutesEnabled controls whether management endpoints serve real handlers.
|
|
managementRoutesEnabled atomic.Bool
|
|
|
|
// envManagementSecret indicates whether MANAGEMENT_PASSWORD is configured.
|
|
envManagementSecret bool
|
|
|
|
localPassword string
|
|
|
|
keepAliveEnabled bool
|
|
keepAliveTimeout time.Duration
|
|
keepAliveOnTimeout func()
|
|
keepAliveHeartbeat chan struct{}
|
|
keepAliveStop chan struct{}
|
|
}
|
|
|
|
// NewServer creates and initializes a new API server instance.
|
|
// It sets up the Gin engine, middleware, routes, and handlers.
|
|
//
|
|
// Parameters:
|
|
// - cfg: The server configuration
|
|
// - authManager: core runtime auth manager
|
|
// - accessManager: request authentication manager
|
|
//
|
|
// Returns:
|
|
// - *Server: A new server instance
|
|
func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdkaccess.Manager, configFilePath string, opts ...ServerOption) *Server {
|
|
optionState := &serverOptionConfig{
|
|
requestLoggerFactory: defaultRequestLoggerFactory,
|
|
}
|
|
for i := range opts {
|
|
opts[i](optionState)
|
|
}
|
|
// Set gin mode
|
|
if !cfg.Debug {
|
|
gin.SetMode(gin.ReleaseMode)
|
|
}
|
|
|
|
// Create gin engine
|
|
engine := gin.New()
|
|
if optionState.engineConfigurator != nil {
|
|
optionState.engineConfigurator(engine)
|
|
}
|
|
|
|
// Add middleware
|
|
engine.Use(logging.GinLogrusLogger())
|
|
engine.Use(logging.GinLogrusRecovery())
|
|
for _, mw := range optionState.extraMiddleware {
|
|
engine.Use(mw)
|
|
}
|
|
|
|
// Add request logging middleware (positioned after recovery, before auth)
|
|
// Resolve logs directory relative to the configuration file directory.
|
|
var requestLogger logging.RequestLogger
|
|
var toggle func(bool)
|
|
if !cfg.CommercialMode {
|
|
if optionState.requestLoggerFactory != nil {
|
|
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
|
}
|
|
if requestLogger != nil {
|
|
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
|
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
|
toggle = setter.SetEnabled
|
|
}
|
|
}
|
|
}
|
|
|
|
engine.Use(corsMiddleware())
|
|
wd, err := os.Getwd()
|
|
if err != nil {
|
|
wd = configFilePath
|
|
}
|
|
|
|
envAdminPassword, envAdminPasswordSet := os.LookupEnv("MANAGEMENT_PASSWORD")
|
|
envAdminPassword = strings.TrimSpace(envAdminPassword)
|
|
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
|
|
|
// Create server instance
|
|
s := &Server{
|
|
engine: engine,
|
|
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
|
cfg: cfg,
|
|
accessManager: accessManager,
|
|
requestLogger: requestLogger,
|
|
loggerToggle: toggle,
|
|
configFilePath: configFilePath,
|
|
currentPath: wd,
|
|
envManagementSecret: envManagementSecret,
|
|
wsRoutes: make(map[string]struct{}),
|
|
}
|
|
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
|
// Save initial YAML snapshot
|
|
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
|
s.applyAccessConfig(nil, cfg)
|
|
if authManager != nil {
|
|
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
|
}
|
|
managementasset.SetCurrentConfig(cfg)
|
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
|
// Initialize management handler
|
|
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
|
if optionState.localPassword != "" {
|
|
s.mgmt.SetLocalPassword(optionState.localPassword)
|
|
}
|
|
logDir := filepath.Join(s.currentPath, "logs")
|
|
if base := util.WritablePath(); base != "" {
|
|
logDir = filepath.Join(base, "logs")
|
|
}
|
|
s.mgmt.SetLogDirectory(logDir)
|
|
s.localPassword = optionState.localPassword
|
|
|
|
// Setup routes
|
|
s.setupRoutes()
|
|
|
|
// Register Amp module using V2 interface with Context
|
|
s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager))
|
|
ctx := modules.Context{
|
|
Engine: engine,
|
|
BaseHandler: s.handlers,
|
|
Config: cfg,
|
|
AuthMiddleware: AuthMiddleware(accessManager),
|
|
}
|
|
if err := modules.RegisterModule(ctx, s.ampModule); err != nil {
|
|
log.Errorf("Failed to register Amp module: %v", err)
|
|
}
|
|
|
|
// Apply additional router configurators from options
|
|
if optionState.routerConfigurator != nil {
|
|
optionState.routerConfigurator(engine, s.handlers, cfg)
|
|
}
|
|
|
|
// Register management routes when configuration or environment secrets are available.
|
|
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
|
|
s.managementRoutesEnabled.Store(hasManagementSecret)
|
|
if hasManagementSecret {
|
|
s.registerManagementRoutes()
|
|
}
|
|
|
|
if optionState.keepAliveEnabled {
|
|
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
|
|
}
|
|
|
|
// Create HTTP server
|
|
s.server = &http.Server{
|
|
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
|
Handler: engine,
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
// setupRoutes configures the API routes for the server.
|
|
// It defines the endpoints and associates them with their respective handlers.
|
|
func (s *Server) setupRoutes() {
|
|
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
|
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
|
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
|
geminiCLIHandlers := gemini.NewGeminiCLIAPIHandler(s.handlers)
|
|
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(s.handlers)
|
|
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(s.handlers)
|
|
|
|
// OpenAI compatible API routes
|
|
v1 := s.engine.Group("/v1")
|
|
v1.Use(AuthMiddleware(s.accessManager))
|
|
{
|
|
v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
|
|
v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
|
|
v1.POST("/completions", openaiHandlers.Completions)
|
|
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
|
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
|
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
|
}
|
|
|
|
// Gemini compatible API routes
|
|
v1beta := s.engine.Group("/v1beta")
|
|
v1beta.Use(AuthMiddleware(s.accessManager))
|
|
{
|
|
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
|
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
|
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
|
}
|
|
|
|
// Root endpoint
|
|
s.engine.GET("/", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"message": "CLI Proxy API Server",
|
|
"endpoints": []string{
|
|
"POST /v1/chat/completions",
|
|
"POST /v1/completions",
|
|
"GET /v1/models",
|
|
},
|
|
})
|
|
})
|
|
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
|
|
|
// OAuth callback endpoints (reuse main server port)
|
|
// These endpoints receive provider redirects and persist
|
|
// the short-lived code/state for the waiting goroutine.
|
|
s.engine.GET("/anthropic/callback", func(c *gin.Context) {
|
|
code := c.Query("code")
|
|
state := c.Query("state")
|
|
errStr := c.Query("error")
|
|
if errStr == "" {
|
|
errStr = c.Query("error_description")
|
|
}
|
|
if state != "" {
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
|
|
}
|
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
|
})
|
|
|
|
s.engine.GET("/codex/callback", func(c *gin.Context) {
|
|
code := c.Query("code")
|
|
state := c.Query("state")
|
|
errStr := c.Query("error")
|
|
if errStr == "" {
|
|
errStr = c.Query("error_description")
|
|
}
|
|
if state != "" {
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
|
|
}
|
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
|
})
|
|
|
|
s.engine.GET("/google/callback", func(c *gin.Context) {
|
|
code := c.Query("code")
|
|
state := c.Query("state")
|
|
errStr := c.Query("error")
|
|
if errStr == "" {
|
|
errStr = c.Query("error_description")
|
|
}
|
|
if state != "" {
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
|
|
}
|
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
|
})
|
|
|
|
s.engine.GET("/iflow/callback", func(c *gin.Context) {
|
|
code := c.Query("code")
|
|
state := c.Query("state")
|
|
errStr := c.Query("error")
|
|
if errStr == "" {
|
|
errStr = c.Query("error_description")
|
|
}
|
|
if state != "" {
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
|
}
|
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
|
})
|
|
|
|
s.engine.GET("/antigravity/callback", func(c *gin.Context) {
|
|
code := c.Query("code")
|
|
state := c.Query("state")
|
|
errStr := c.Query("error")
|
|
if errStr == "" {
|
|
errStr = c.Query("error_description")
|
|
}
|
|
if state != "" {
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
|
|
}
|
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
|
})
|
|
|
|
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
|
}
|
|
|
|
// AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine.
|
|
// The handler is served as-is without additional middleware beyond the standard stack already configured.
|
|
func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
|
|
if s == nil || s.engine == nil || handler == nil {
|
|
return
|
|
}
|
|
trimmed := strings.TrimSpace(path)
|
|
if trimmed == "" {
|
|
trimmed = "/v1/ws"
|
|
}
|
|
if !strings.HasPrefix(trimmed, "/") {
|
|
trimmed = "/" + trimmed
|
|
}
|
|
s.wsRouteMu.Lock()
|
|
if _, exists := s.wsRoutes[trimmed]; exists {
|
|
s.wsRouteMu.Unlock()
|
|
return
|
|
}
|
|
s.wsRoutes[trimmed] = struct{}{}
|
|
s.wsRouteMu.Unlock()
|
|
|
|
authMiddleware := AuthMiddleware(s.accessManager)
|
|
conditionalAuth := func(c *gin.Context) {
|
|
if !s.wsAuthEnabled.Load() {
|
|
c.Next()
|
|
return
|
|
}
|
|
authMiddleware(c)
|
|
}
|
|
finalHandler := func(c *gin.Context) {
|
|
handler.ServeHTTP(c.Writer, c.Request)
|
|
c.Abort()
|
|
}
|
|
|
|
s.engine.GET(trimmed, conditionalAuth, finalHandler)
|
|
}
|
|
|
|
func (s *Server) registerManagementRoutes() {
|
|
if s == nil || s.engine == nil || s.mgmt == nil {
|
|
return
|
|
}
|
|
if !s.managementRoutesRegistered.CompareAndSwap(false, true) {
|
|
return
|
|
}
|
|
|
|
log.Info("management routes registered after secret key configuration")
|
|
|
|
mgmt := s.engine.Group("/v0/management")
|
|
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
|
{
|
|
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
|
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
|
|
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
|
|
mgmt.GET("/config", s.mgmt.GetConfig)
|
|
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
|
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
|
mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
|
|
|
|
mgmt.GET("/debug", s.mgmt.GetDebug)
|
|
mgmt.PUT("/debug", s.mgmt.PutDebug)
|
|
mgmt.PATCH("/debug", s.mgmt.PutDebug)
|
|
|
|
mgmt.GET("/logging-to-file", s.mgmt.GetLoggingToFile)
|
|
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
|
|
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
|
|
|
|
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
|
|
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
|
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
|
|
|
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
|
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
|
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
|
|
|
mgmt.GET("/proxy-url", s.mgmt.GetProxyURL)
|
|
mgmt.PUT("/proxy-url", s.mgmt.PutProxyURL)
|
|
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
|
|
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
|
|
|
|
mgmt.POST("/api-call", s.mgmt.APICall)
|
|
|
|
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
|
|
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
|
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
|
|
|
mgmt.GET("/quota-exceeded/switch-preview-model", s.mgmt.GetSwitchPreviewModel)
|
|
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
|
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
|
|
|
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
|
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
|
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
|
mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
|
|
|
|
mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys)
|
|
mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys)
|
|
mgmt.PATCH("/gemini-api-key", s.mgmt.PatchGeminiKey)
|
|
mgmt.DELETE("/gemini-api-key", s.mgmt.DeleteGeminiKey)
|
|
|
|
mgmt.GET("/logs", s.mgmt.GetLogs)
|
|
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
|
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
|
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
|
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
|
|
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
|
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
|
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
|
mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth)
|
|
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
|
|
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
|
|
|
|
mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
|
|
mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
|
|
mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
|
mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
|
mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
|
|
mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
|
|
mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
|
mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
|
mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
|
|
mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
|
|
mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
|
mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
|
mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
|
|
mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
|
|
mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
|
|
mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
|
|
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
|
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
|
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
|
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
|
|
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
|
|
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
|
|
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
|
|
|
|
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
|
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
|
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
|
|
mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval)
|
|
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
|
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
|
|
|
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
|
|
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
|
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
|
|
|
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
|
|
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
|
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
|
|
|
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
|
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
|
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
|
|
mgmt.DELETE("/claude-api-key", s.mgmt.DeleteClaudeKey)
|
|
|
|
mgmt.GET("/codex-api-key", s.mgmt.GetCodexKeys)
|
|
mgmt.PUT("/codex-api-key", s.mgmt.PutCodexKeys)
|
|
mgmt.PATCH("/codex-api-key", s.mgmt.PatchCodexKey)
|
|
mgmt.DELETE("/codex-api-key", s.mgmt.DeleteCodexKey)
|
|
|
|
mgmt.GET("/openai-compatibility", s.mgmt.GetOpenAICompat)
|
|
mgmt.PUT("/openai-compatibility", s.mgmt.PutOpenAICompat)
|
|
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
|
|
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
|
|
|
|
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
|
|
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
|
|
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
|
|
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
|
|
|
|
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
|
|
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
|
|
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
|
|
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
|
|
|
mgmt.GET("/oauth-model-alias", s.mgmt.GetOAuthModelAlias)
|
|
mgmt.PUT("/oauth-model-alias", s.mgmt.PutOAuthModelAlias)
|
|
mgmt.PATCH("/oauth-model-alias", s.mgmt.PatchOAuthModelAlias)
|
|
mgmt.DELETE("/oauth-model-alias", s.mgmt.DeleteOAuthModelAlias)
|
|
|
|
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
|
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
|
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
|
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
|
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
|
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
|
|
|
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
|
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
|
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
|
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
|
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
|
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
|
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
|
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
|
}
|
|
}
|
|
|
|
func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if !s.managementRoutesEnabled.Load() {
|
|
c.AbortWithStatus(http.StatusNotFound)
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
|
cfg := s.cfg
|
|
if cfg == nil || cfg.RemoteManagement.DisableControlPanel {
|
|
c.AbortWithStatus(http.StatusNotFound)
|
|
return
|
|
}
|
|
filePath := managementasset.FilePath(s.configFilePath)
|
|
if strings.TrimSpace(filePath) == "" {
|
|
c.AbortWithStatus(http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
if _, err := os.Stat(filePath); err != nil {
|
|
if os.IsNotExist(err) {
|
|
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
|
c.AbortWithStatus(http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
log.WithError(err).Error("failed to stat management control panel asset")
|
|
c.AbortWithStatus(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
c.File(filePath)
|
|
}
|
|
|
|
func (s *Server) enableKeepAlive(timeout time.Duration, onTimeout func()) {
|
|
if timeout <= 0 || onTimeout == nil {
|
|
return
|
|
}
|
|
|
|
s.keepAliveEnabled = true
|
|
s.keepAliveTimeout = timeout
|
|
s.keepAliveOnTimeout = onTimeout
|
|
s.keepAliveHeartbeat = make(chan struct{}, 1)
|
|
s.keepAliveStop = make(chan struct{}, 1)
|
|
|
|
s.engine.GET("/keep-alive", s.handleKeepAlive)
|
|
|
|
go s.watchKeepAlive()
|
|
}
|
|
|
|
func (s *Server) handleKeepAlive(c *gin.Context) {
|
|
if s.localPassword != "" {
|
|
provided := strings.TrimSpace(c.GetHeader("Authorization"))
|
|
if provided != "" {
|
|
parts := strings.SplitN(provided, " ", 2)
|
|
if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") {
|
|
provided = parts[1]
|
|
}
|
|
}
|
|
if provided == "" {
|
|
provided = strings.TrimSpace(c.GetHeader("X-Local-Password"))
|
|
}
|
|
if subtle.ConstantTimeCompare([]byte(provided), []byte(s.localPassword)) != 1 {
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid password"})
|
|
return
|
|
}
|
|
}
|
|
|
|
s.signalKeepAlive()
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
}
|
|
|
|
func (s *Server) signalKeepAlive() {
|
|
if !s.keepAliveEnabled {
|
|
return
|
|
}
|
|
select {
|
|
case s.keepAliveHeartbeat <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (s *Server) watchKeepAlive() {
|
|
if !s.keepAliveEnabled {
|
|
return
|
|
}
|
|
|
|
timer := time.NewTimer(s.keepAliveTimeout)
|
|
defer timer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-timer.C:
|
|
log.Warnf("keep-alive endpoint idle for %s, shutting down", s.keepAliveTimeout)
|
|
if s.keepAliveOnTimeout != nil {
|
|
s.keepAliveOnTimeout()
|
|
}
|
|
return
|
|
case <-s.keepAliveHeartbeat:
|
|
if !timer.Stop() {
|
|
select {
|
|
case <-timer.C:
|
|
default:
|
|
}
|
|
}
|
|
timer.Reset(s.keepAliveTimeout)
|
|
case <-s.keepAliveStop:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// unifiedModelsHandler creates a unified handler for the /v1/models endpoint
|
|
// that routes to different handlers based on the User-Agent header.
|
|
// If User-Agent starts with "claude-cli", it routes to Claude handler,
|
|
// otherwise it routes to OpenAI handler.
|
|
func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
userAgent := c.GetHeader("User-Agent")
|
|
|
|
// Route to Claude handler if User-Agent starts with "claude-cli"
|
|
if strings.HasPrefix(userAgent, "claude-cli") {
|
|
// log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent)
|
|
claudeHandler.ClaudeModels(c)
|
|
} else {
|
|
// log.Debugf("Routing /v1/models to OpenAI handler for User-Agent: %s", userAgent)
|
|
openaiHandler.OpenAIModels(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Start begins listening for and serving HTTP or HTTPS requests.
|
|
// It's a blocking call and will only return on an unrecoverable error.
|
|
//
|
|
// Returns:
|
|
// - error: An error if the server fails to start
|
|
func (s *Server) Start() error {
|
|
if s == nil || s.server == nil {
|
|
return fmt.Errorf("failed to start HTTP server: server not initialized")
|
|
}
|
|
|
|
useTLS := s.cfg != nil && s.cfg.TLS.Enable
|
|
if useTLS {
|
|
cert := strings.TrimSpace(s.cfg.TLS.Cert)
|
|
key := strings.TrimSpace(s.cfg.TLS.Key)
|
|
if cert == "" || key == "" {
|
|
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
|
|
}
|
|
log.Debugf("Starting API server on %s with TLS", s.server.Addr)
|
|
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
|
|
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
log.Debugf("Starting API server on %s", s.server.Addr)
|
|
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
|
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop gracefully shuts down the API server without interrupting any
|
|
// active connections.
|
|
//
|
|
// Parameters:
|
|
// - ctx: The context for graceful shutdown
|
|
//
|
|
// Returns:
|
|
// - error: An error if the server fails to stop
|
|
func (s *Server) Stop(ctx context.Context) error {
|
|
log.Debug("Stopping API server...")
|
|
|
|
if s.keepAliveEnabled {
|
|
select {
|
|
case s.keepAliveStop <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
// Shutdown the HTTP server.
|
|
if err := s.server.Shutdown(ctx); err != nil {
|
|
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
|
}
|
|
|
|
log.Debug("API server stopped")
|
|
return nil
|
|
}
|
|
|
|
// corsMiddleware returns a Gin middleware handler that adds CORS headers
|
|
// to every response, allowing cross-origin requests.
|
|
//
|
|
// Returns:
|
|
// - gin.HandlerFunc: The CORS middleware handler
|
|
func corsMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
c.Header("Access-Control-Allow-Origin", "*")
|
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
|
c.Header("Access-Control-Allow-Headers", "*")
|
|
|
|
if c.Request.Method == "OPTIONS" {
|
|
c.AbortWithStatus(http.StatusNoContent)
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) {
|
|
if s == nil || s.accessManager == nil || newCfg == nil {
|
|
return
|
|
}
|
|
if _, err := access.ApplyAccessProviders(s.accessManager, oldCfg, newCfg); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
// UpdateClients updates the server's client list and configuration.
|
|
// This method is called when the configuration or authentication tokens change.
|
|
//
|
|
// Parameters:
|
|
// - clients: The new slice of AI service clients
|
|
// - cfg: The new application configuration
|
|
func (s *Server) UpdateClients(cfg *config.Config) {
|
|
// Reconstruct old config from YAML snapshot to avoid reference sharing issues
|
|
var oldCfg *config.Config
|
|
if len(s.oldConfigYaml) > 0 {
|
|
_ = yaml.Unmarshal(s.oldConfigYaml, &oldCfg)
|
|
}
|
|
|
|
// Update request logger enabled state if it has changed
|
|
previousRequestLog := false
|
|
if oldCfg != nil {
|
|
previousRequestLog = oldCfg.RequestLog
|
|
}
|
|
if s.requestLogger != nil && (oldCfg == nil || previousRequestLog != cfg.RequestLog) {
|
|
if s.loggerToggle != nil {
|
|
s.loggerToggle(cfg.RequestLog)
|
|
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
|
|
toggler.SetEnabled(cfg.RequestLog)
|
|
}
|
|
if oldCfg != nil {
|
|
log.Debugf("request logging updated from %t to %t", previousRequestLog, cfg.RequestLog)
|
|
} else {
|
|
log.Debugf("request logging toggled to %t", cfg.RequestLog)
|
|
}
|
|
}
|
|
|
|
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
|
if err := logging.ConfigureLogOutput(cfg); err != nil {
|
|
log.Errorf("failed to reconfigure log output: %v", err)
|
|
} else {
|
|
if oldCfg == nil {
|
|
log.Debug("log output configuration refreshed")
|
|
} else {
|
|
if oldCfg.LoggingToFile != cfg.LoggingToFile {
|
|
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
|
}
|
|
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
|
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
|
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
|
if oldCfg != nil {
|
|
log.Debugf("usage_statistics_enabled updated from %t to %t", oldCfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
|
|
} else {
|
|
log.Debugf("usage_statistics_enabled toggled to %t", cfg.UsageStatisticsEnabled)
|
|
}
|
|
}
|
|
|
|
if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
|
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
|
if oldCfg != nil {
|
|
log.Debugf("disable_cooling updated from %t to %t", oldCfg.DisableCooling, cfg.DisableCooling)
|
|
} else {
|
|
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
|
|
}
|
|
}
|
|
if s.handlers != nil && s.handlers.AuthManager != nil {
|
|
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
|
}
|
|
|
|
// Update log level dynamically when debug flag changes
|
|
if oldCfg == nil || oldCfg.Debug != cfg.Debug {
|
|
util.SetLogLevel(cfg)
|
|
if oldCfg != nil {
|
|
log.Debugf("debug mode updated from %t to %t", oldCfg.Debug, cfg.Debug)
|
|
} else {
|
|
log.Debugf("debug mode toggled to %t", cfg.Debug)
|
|
}
|
|
}
|
|
|
|
prevSecretEmpty := true
|
|
if oldCfg != nil {
|
|
prevSecretEmpty = oldCfg.RemoteManagement.SecretKey == ""
|
|
}
|
|
newSecretEmpty := cfg.RemoteManagement.SecretKey == ""
|
|
if s.envManagementSecret {
|
|
s.registerManagementRoutes()
|
|
if s.managementRoutesEnabled.CompareAndSwap(false, true) {
|
|
log.Info("management routes enabled via MANAGEMENT_PASSWORD")
|
|
} else {
|
|
s.managementRoutesEnabled.Store(true)
|
|
}
|
|
} else {
|
|
switch {
|
|
case prevSecretEmpty && !newSecretEmpty:
|
|
s.registerManagementRoutes()
|
|
if s.managementRoutesEnabled.CompareAndSwap(false, true) {
|
|
log.Info("management routes enabled after secret key update")
|
|
} else {
|
|
s.managementRoutesEnabled.Store(true)
|
|
}
|
|
case !prevSecretEmpty && newSecretEmpty:
|
|
if s.managementRoutesEnabled.CompareAndSwap(true, false) {
|
|
log.Info("management routes disabled after secret key removal")
|
|
} else {
|
|
s.managementRoutesEnabled.Store(false)
|
|
}
|
|
default:
|
|
s.managementRoutesEnabled.Store(!newSecretEmpty)
|
|
}
|
|
}
|
|
|
|
s.applyAccessConfig(oldCfg, cfg)
|
|
s.cfg = cfg
|
|
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
|
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
|
|
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
|
|
}
|
|
managementasset.SetCurrentConfig(cfg)
|
|
// Save YAML snapshot for next comparison
|
|
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
|
|
|
s.handlers.UpdateClients(&cfg.SDKConfig)
|
|
|
|
if !cfg.RemoteManagement.DisableControlPanel {
|
|
staticDir := managementasset.StaticDir(s.configFilePath)
|
|
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
|
}
|
|
if s.mgmt != nil {
|
|
s.mgmt.SetConfig(cfg)
|
|
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
|
}
|
|
|
|
// Notify Amp module of config changes (for model mapping hot-reload)
|
|
if s.ampModule != nil {
|
|
log.Debugf("triggering amp module config update")
|
|
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
|
log.Errorf("failed to update Amp module config: %v", err)
|
|
}
|
|
} else {
|
|
log.Warnf("amp module is nil, skipping config update")
|
|
}
|
|
|
|
// Count client sources from configuration and auth store.
|
|
tokenStore := sdkAuth.GetTokenStore()
|
|
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
|
|
dirSetter.SetBaseDir(cfg.AuthDir)
|
|
}
|
|
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
|
|
geminiAPIKeyCount := len(cfg.GeminiKey)
|
|
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
|
codexAPIKeyCount := len(cfg.CodexKey)
|
|
vertexAICompatCount := len(cfg.VertexCompatAPIKey)
|
|
openAICompatCount := 0
|
|
for i := range cfg.OpenAICompatibility {
|
|
entry := cfg.OpenAICompatibility[i]
|
|
openAICompatCount += len(entry.APIKeyEntries)
|
|
}
|
|
|
|
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
|
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
|
total,
|
|
authEntries,
|
|
geminiAPIKeyCount,
|
|
claudeAPIKeyCount,
|
|
codexAPIKeyCount,
|
|
vertexAICompatCount,
|
|
openAICompatCount,
|
|
)
|
|
}
|
|
|
|
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.wsAuthChanged = fn
|
|
}
|
|
|
|
// (management handlers moved to internal/api/handlers/management)
|
|
|
|
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
|
// using the configured authentication providers. When no providers are available,
|
|
// it allows all requests (legacy behaviour).
|
|
func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if manager == nil {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
result, err := manager.Authenticate(c.Request.Context(), c.Request)
|
|
if err == nil {
|
|
if result != nil {
|
|
c.Set("apiKey", result.Principal)
|
|
c.Set("accessProvider", result.Provider)
|
|
if len(result.Metadata) > 0 {
|
|
c.Set("accessMetadata", result.Metadata)
|
|
}
|
|
}
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
switch {
|
|
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
|
|
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
|
default:
|
|
log.Errorf("authentication middleware error: %v", err)
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
|
|
}
|
|
}
|
|
}
|