mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(amp): add hot-reload support for upstream URL and localhost restriction
This commit is contained in:
@@ -27,12 +27,17 @@ type Option func(*AmpModule)
|
|||||||
type AmpModule struct {
|
type AmpModule struct {
|
||||||
secretSource SecretSource
|
secretSource SecretSource
|
||||||
proxy *httputil.ReverseProxy
|
proxy *httputil.ReverseProxy
|
||||||
|
proxyMu sync.RWMutex // protects proxy for hot-reload
|
||||||
accessManager *sdkaccess.Manager
|
accessManager *sdkaccess.Manager
|
||||||
authMiddleware_ gin.HandlerFunc
|
authMiddleware_ gin.HandlerFunc
|
||||||
modelMapper *DefaultModelMapper
|
modelMapper *DefaultModelMapper
|
||||||
enabled bool
|
enabled bool
|
||||||
registerOnce sync.Once
|
registerOnce sync.Once
|
||||||
|
|
||||||
|
// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
|
||||||
|
restrictToLocalhost bool
|
||||||
|
restrictMu sync.RWMutex
|
||||||
|
|
||||||
// configMu protects lastConfig for partial reload comparison
|
// configMu protects lastConfig for partial reload comparison
|
||||||
configMu sync.RWMutex
|
configMu sync.RWMutex
|
||||||
lastConfig *config.AmpCode
|
lastConfig *config.AmpCode
|
||||||
@@ -115,6 +120,9 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
|||||||
settingsCopy := settings
|
settingsCopy := settings
|
||||||
m.lastConfig = &settingsCopy
|
m.lastConfig = &settingsCopy
|
||||||
|
|
||||||
|
// Initialize localhost restriction setting (hot-reloadable)
|
||||||
|
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
||||||
|
|
||||||
// Always register provider aliases - these work without an upstream
|
// Always register provider aliases - these work without an upstream
|
||||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||||
|
|
||||||
@@ -139,13 +147,12 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.proxy = proxy
|
m.setProxy(proxy)
|
||||||
m.enabled = true
|
m.enabled = true
|
||||||
|
|
||||||
// Register management proxy routes (requires upstream)
|
// Register management proxy routes (requires upstream)
|
||||||
// Restrict to localhost by default for security (prevents drive-by browser attacks)
|
// Uses dynamic middleware that checks m.IsRestrictedToLocalhost() for hot-reload support
|
||||||
handler := proxyHandler(proxy)
|
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
||||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, settings.RestrictManagementToLocalhost)
|
|
||||||
|
|
||||||
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
||||||
log.Debug("amp provider alias routes registered")
|
log.Debug("amp provider alias routes registered")
|
||||||
@@ -172,7 +179,7 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
|
|||||||
|
|
||||||
// OnConfigUpdated handles configuration updates with partial reload support.
|
// OnConfigUpdated handles configuration updates with partial reload support.
|
||||||
// Only updates components that have actually changed to avoid unnecessary work.
|
// Only updates components that have actually changed to avoid unnecessary work.
|
||||||
// URL changes still require restart (logged as warning).
|
// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
|
||||||
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||||
newSettings := cfg.AmpCode
|
newSettings := cfg.AmpCode
|
||||||
|
|
||||||
@@ -199,7 +206,7 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if m.enabled {
|
if m.enabled {
|
||||||
// Check upstream URL change (requires restart)
|
// Check upstream URL change - now supports hot-reload
|
||||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||||
oldUpstreamURL := ""
|
oldUpstreamURL := ""
|
||||||
if oldSettings != nil {
|
if oldSettings != nil {
|
||||||
@@ -207,10 +214,19 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
||||||
log.Warn("amp upstream URL removed from config, restart required to disable")
|
log.Warn("amp upstream URL removed from config, proxy disabled until restart")
|
||||||
} else if newUpstreamURL != oldUpstreamURL {
|
m.setProxy(nil)
|
||||||
changes = append(changes, "upstream-url(restart required)")
|
changes = append(changes, "upstream-url(disabled)")
|
||||||
log.Warnf("amp config: upstream-url changed (%s -> %s), restart required", oldUpstreamURL, newUpstreamURL)
|
} else if newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
|
||||||
|
// Recreate proxy with new URL
|
||||||
|
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
|
||||||
|
} else {
|
||||||
|
m.setProxy(proxy)
|
||||||
|
changes = append(changes, "upstream-url")
|
||||||
|
log.Infof("amp config partial reload: upstream URL updated (%s -> %s)", oldUpstreamURL, newUpstreamURL)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check API key change
|
// Check API key change
|
||||||
@@ -226,11 +242,15 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check restrict-management-to-localhost change (requires restart)
|
// Check restrict-management-to-localhost change - now supports hot-reload
|
||||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||||
changes = append(changes, "restrict-management-to-localhost(restart required)")
|
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||||
log.Warnf("amp config: restrict-management-to-localhost changed (%t -> %t), restart required",
|
changes = append(changes, "restrict-management-to-localhost")
|
||||||
oldSettings.RestrictManagementToLocalhost, newSettings.RestrictManagementToLocalhost)
|
if newSettings.RestrictManagementToLocalhost {
|
||||||
|
log.Infof("amp config partial reload: management routes now restricted to localhost")
|
||||||
|
} else {
|
||||||
|
log.Warnf("amp config partial reload: management routes now accessible from any IP - this is insecure!")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,3 +311,31 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
|
|||||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||||
return m.modelMapper
|
return m.modelMapper
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getProxy returns the current proxy instance (thread-safe for hot-reload).
|
||||||
|
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
|
||||||
|
m.proxyMu.RLock()
|
||||||
|
defer m.proxyMu.RUnlock()
|
||||||
|
return m.proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
// setProxy updates the proxy instance (thread-safe for hot-reload).
|
||||||
|
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
|
||||||
|
m.proxyMu.Lock()
|
||||||
|
defer m.proxyMu.Unlock()
|
||||||
|
m.proxy = proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
|
||||||
|
func (m *AmpModule) IsRestrictedToLocalhost() bool {
|
||||||
|
m.restrictMu.RLock()
|
||||||
|
defer m.restrictMu.RUnlock()
|
||||||
|
return m.restrictToLocalhost
|
||||||
|
}
|
||||||
|
|
||||||
|
// setRestrictToLocalhost updates the localhost restriction setting.
|
||||||
|
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
|
||||||
|
m.restrictMu.Lock()
|
||||||
|
defer m.restrictMu.Unlock()
|
||||||
|
m.restrictToLocalhost = restrict
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,15 +14,16 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only.
|
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
||||||
// Returns 403 Forbidden for non-localhost clients.
|
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
||||||
//
|
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
||||||
// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent
|
|
||||||
// header spoofing attacks via X-Forwarded-For or similar headers. This means the
|
|
||||||
// middleware will not work correctly behind reverse proxies - users deploying behind
|
|
||||||
// nginx/Cloudflare should disable this feature and use firewall rules instead.
|
|
||||||
func localhostOnlyMiddleware() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
// Check current setting (hot-reloadable)
|
||||||
|
if !m.IsRestrictedToLocalhost() {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Use actual TCP connection address (RemoteAddr) to prevent header spoofing
|
// Use actual TCP connection address (RemoteAddr) to prevent header spoofing
|
||||||
// This cannot be forged by X-Forwarded-For or other client-controlled headers
|
// This cannot be forged by X-Forwarded-For or other client-controlled headers
|
||||||
remoteAddr := c.Request.RemoteAddr
|
remoteAddr := c.Request.RemoteAddr
|
||||||
@@ -79,21 +80,32 @@ func noCORSMiddleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
// registerManagementRoutes registers Amp management proxy routes
|
// registerManagementRoutes registers Amp management proxy routes
|
||||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||||
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1.
|
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) {
|
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
|
||||||
ampAPI := engine.Group("/api")
|
ampAPI := engine.Group("/api")
|
||||||
|
|
||||||
// Always disable CORS for management routes to prevent browser-based attacks
|
// Always disable CORS for management routes to prevent browser-based attacks
|
||||||
ampAPI.Use(noCORSMiddleware())
|
ampAPI.Use(noCORSMiddleware())
|
||||||
|
|
||||||
// Apply localhost-only restriction if configured
|
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||||
if restrictToLocalhost {
|
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||||
ampAPI.Use(localhostOnlyMiddleware())
|
|
||||||
|
if m.IsRestrictedToLocalhost() {
|
||||||
log.Info("amp management routes restricted to localhost only (CORS disabled)")
|
log.Info("amp management routes restricted to localhost only (CORS disabled)")
|
||||||
} else {
|
} else {
|
||||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||||
|
proxyHandler := func(c *gin.Context) {
|
||||||
|
proxy := m.getProxy()
|
||||||
|
if proxy == nil {
|
||||||
|
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
proxy.ServeHTTP(c.Writer, c.Request)
|
||||||
|
}
|
||||||
|
|
||||||
// Management routes - these are proxied directly to Amp upstream
|
// Management routes - these are proxied directly to Amp upstream
|
||||||
ampAPI.Any("/internal", proxyHandler)
|
ampAPI.Any("/internal", proxyHandler)
|
||||||
ampAPI.Any("/internal/*path", proxyHandler)
|
ampAPI.Any("/internal/*path", proxyHandler)
|
||||||
@@ -114,11 +126,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
ampAPI.Any("/tab/*path", proxyHandler)
|
ampAPI.Any("/tab/*path", proxyHandler)
|
||||||
|
|
||||||
// Root-level routes that AMP CLI expects without /api prefix
|
// Root-level routes that AMP CLI expects without /api prefix
|
||||||
// These need the same security middleware as the /api/* routes
|
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||||
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()}
|
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||||
if restrictToLocalhost {
|
|
||||||
rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware())
|
|
||||||
}
|
|
||||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||||
|
|
||||||
// Root-level auth routes for CLI login flow
|
// Root-level auth routes for CLI login flow
|
||||||
@@ -134,7 +143,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
|
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
|
||||||
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
|
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
return m.proxy
|
return m.getProxy()
|
||||||
})
|
})
|
||||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||||
|
|
||||||
@@ -177,10 +186,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||||
|
|
||||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
||||||
// Uses lazy evaluation to access proxy (which is created after routes are registered)
|
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
|
||||||
// Also includes model mapping support for routing unavailable models to alternatives
|
// Also includes model mapping support for routing unavailable models to alternatives
|
||||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||||
return m.proxy
|
return m.getProxy()
|
||||||
}, m.modelMapper)
|
}, m.modelMapper)
|
||||||
|
|
||||||
// Provider-specific routes under /api/provider/:provider
|
// Provider-specific routes under /api/provider/:provider
|
||||||
|
|||||||
@@ -13,16 +13,26 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
|||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
|
|
||||||
// Spy to track if proxy handler was called
|
// Create module with proxy for testing
|
||||||
proxyCalled := false
|
m := &AmpModule{
|
||||||
proxyHandler := func(c *gin.Context) {
|
restrictToLocalhost: false, // disable localhost restriction for tests
|
||||||
proxyCalled = true
|
|
||||||
c.String(200, "proxied")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m := &AmpModule{}
|
// Create a mock proxy that tracks calls
|
||||||
|
proxyCalled := false
|
||||||
|
mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
proxyCalled = true
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte("proxied"))
|
||||||
|
}))
|
||||||
|
defer mockProxy.Close()
|
||||||
|
|
||||||
|
// Create real proxy to mock server
|
||||||
|
proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource(""))
|
||||||
|
m.setProxy(proxy)
|
||||||
|
|
||||||
base := &handlers.BaseAPIHandler{}
|
base := &handlers.BaseAPIHandler{}
|
||||||
m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests
|
m.registerManagementRoutes(r, base)
|
||||||
|
|
||||||
managementPaths := []struct {
|
managementPaths := []struct {
|
||||||
path string
|
path string
|
||||||
@@ -231,8 +241,13 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
|||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
|
|
||||||
// Apply localhost-only middleware
|
// Create module with localhost restriction enabled
|
||||||
r.Use(localhostOnlyMiddleware())
|
m := &AmpModule{
|
||||||
|
restrictToLocalhost: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply dynamic localhost-only middleware
|
||||||
|
r.Use(m.localhostOnlyMiddleware())
|
||||||
r.GET("/test", func(c *gin.Context) {
|
r.GET("/test", func(c *gin.Context) {
|
||||||
c.String(http.StatusOK, "ok")
|
c.String(http.StatusOK, "ok")
|
||||||
})
|
})
|
||||||
@@ -305,3 +320,53 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
|
||||||
|
// Create module with localhost restriction initially enabled
|
||||||
|
m := &AmpModule{
|
||||||
|
restrictToLocalhost: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply dynamic localhost-only middleware
|
||||||
|
r.Use(m.localhostOnlyMiddleware())
|
||||||
|
r.GET("/test", func(c *gin.Context) {
|
||||||
|
c.String(http.StatusOK, "ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 1: Remote IP should be blocked when restriction is enabled
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Expected 403 when restriction enabled, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 2: Hot-reload - disable restriction
|
||||||
|
m.setRestrictToLocalhost(false)
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200 after disabling restriction, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 3: Hot-reload - re-enable restriction
|
||||||
|
m.setRestrictToLocalhost(true)
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user