feat(amp): add hot-reload support for upstream URL and localhost restriction

This commit is contained in:
hkfires
2025-12-04 21:25:04 +08:00
parent c4c02f4ad0
commit 4a135f1986
4 changed files with 172 additions and 50 deletions

View File

@@ -27,12 +27,17 @@ type Option func(*AmpModule)
type AmpModule struct { type AmpModule struct {
secretSource SecretSource secretSource SecretSource
proxy *httputil.ReverseProxy proxy *httputil.ReverseProxy
proxyMu sync.RWMutex // protects proxy for hot-reload
accessManager *sdkaccess.Manager accessManager *sdkaccess.Manager
authMiddleware_ gin.HandlerFunc authMiddleware_ gin.HandlerFunc
modelMapper *DefaultModelMapper modelMapper *DefaultModelMapper
enabled bool enabled bool
registerOnce sync.Once registerOnce sync.Once
// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
restrictToLocalhost bool
restrictMu sync.RWMutex
// configMu protects lastConfig for partial reload comparison // configMu protects lastConfig for partial reload comparison
configMu sync.RWMutex configMu sync.RWMutex
lastConfig *config.AmpCode lastConfig *config.AmpCode
@@ -115,6 +120,9 @@ func (m *AmpModule) Register(ctx modules.Context) error {
settingsCopy := settings settingsCopy := settings
m.lastConfig = &settingsCopy m.lastConfig = &settingsCopy
// Initialize localhost restriction setting (hot-reloadable)
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
// Always register provider aliases - these work without an upstream // Always register provider aliases - these work without an upstream
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
@@ -139,13 +147,12 @@ func (m *AmpModule) Register(ctx modules.Context) error {
return return
} }
m.proxy = proxy m.setProxy(proxy)
m.enabled = true m.enabled = true
// Register management proxy routes (requires upstream) // Register management proxy routes (requires upstream)
// Restrict to localhost by default for security (prevents drive-by browser attacks) // Uses dynamic middleware that checks m.IsRestrictedToLocalhost() for hot-reload support
handler := proxyHandler(proxy) m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, settings.RestrictManagementToLocalhost)
log.Infof("amp upstream proxy enabled for: %s", upstreamURL) log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
log.Debug("amp provider alias routes registered") log.Debug("amp provider alias routes registered")
@@ -172,7 +179,7 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
// OnConfigUpdated handles configuration updates with partial reload support. // OnConfigUpdated handles configuration updates with partial reload support.
// Only updates components that have actually changed to avoid unnecessary work. // Only updates components that have actually changed to avoid unnecessary work.
// URL changes still require restart (logged as warning). // Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
newSettings := cfg.AmpCode newSettings := cfg.AmpCode
@@ -199,7 +206,7 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
} }
if m.enabled { if m.enabled {
// Check upstream URL change (requires restart) // Check upstream URL change - now supports hot-reload
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL) newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
oldUpstreamURL := "" oldUpstreamURL := ""
if oldSettings != nil { if oldSettings != nil {
@@ -207,10 +214,19 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
} }
if newUpstreamURL == "" && oldUpstreamURL != "" { if newUpstreamURL == "" && oldUpstreamURL != "" {
log.Warn("amp upstream URL removed from config, restart required to disable") log.Warn("amp upstream URL removed from config, proxy disabled until restart")
} else if newUpstreamURL != oldUpstreamURL { m.setProxy(nil)
changes = append(changes, "upstream-url(restart required)") changes = append(changes, "upstream-url(disabled)")
log.Warnf("amp config: upstream-url changed (%s -> %s), restart required", oldUpstreamURL, newUpstreamURL) } else if newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
// Recreate proxy with new URL
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
if err != nil {
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
} else {
m.setProxy(proxy)
changes = append(changes, "upstream-url")
log.Infof("amp config partial reload: upstream URL updated (%s -> %s)", oldUpstreamURL, newUpstreamURL)
}
} }
// Check API key change // Check API key change
@@ -226,11 +242,15 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
} }
} }
// Check restrict-management-to-localhost change (requires restart) // Check restrict-management-to-localhost change - now supports hot-reload
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost { if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
changes = append(changes, "restrict-management-to-localhost(restart required)") m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
log.Warnf("amp config: restrict-management-to-localhost changed (%t -> %t), restart required", changes = append(changes, "restrict-management-to-localhost")
oldSettings.RestrictManagementToLocalhost, newSettings.RestrictManagementToLocalhost) if newSettings.RestrictManagementToLocalhost {
log.Infof("amp config partial reload: management routes now restricted to localhost")
} else {
log.Warnf("amp config partial reload: management routes now accessible from any IP - this is insecure!")
}
} }
} }
@@ -291,3 +311,31 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
func (m *AmpModule) GetModelMapper() *DefaultModelMapper { func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper return m.modelMapper
} }
// getProxy returns the current proxy instance (thread-safe for hot-reload).
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
m.proxyMu.RLock()
defer m.proxyMu.RUnlock()
return m.proxy
}
// setProxy updates the proxy instance (thread-safe for hot-reload).
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
m.proxyMu.Lock()
defer m.proxyMu.Unlock()
m.proxy = proxy
}
// IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
func (m *AmpModule) IsRestrictedToLocalhost() bool {
m.restrictMu.RLock()
defer m.restrictMu.RUnlock()
return m.restrictToLocalhost
}
// setRestrictToLocalhost updates the localhost restriction setting.
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
m.restrictMu.Lock()
defer m.restrictMu.Unlock()
m.restrictToLocalhost = restrict
}

View File

@@ -152,9 +152,9 @@ func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
mapper := NewModelMapper(nil) mapper := NewModelMapper(nil)
mapper.UpdateMappings([]config.AmpModelMapping{ mapper.UpdateMappings([]config.AmpModelMapping{
{From: "", To: "model-b"}, // Invalid: empty from {From: "", To: "model-b"}, // Invalid: empty from
{From: "model-a", To: ""}, // Invalid: empty to {From: "model-a", To: ""}, // Invalid: empty to
{From: " ", To: "model-b"}, // Invalid: whitespace from {From: " ", To: "model-b"}, // Invalid: whitespace from
{From: "model-c", To: "model-d"}, // Valid {From: "model-c", To: "model-d"}, // Valid
}) })

View File

@@ -14,15 +14,16 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only. // localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// Returns 403 Forbidden for non-localhost clients. // localhost restriction setting. This allows hot-reload of the restriction without restarting.
// func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent
// header spoofing attacks via X-Forwarded-For or similar headers. This means the
// middleware will not work correctly behind reverse proxies - users deploying behind
// nginx/Cloudflare should disable this feature and use firewall rules instead.
func localhostOnlyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Check current setting (hot-reloadable)
if !m.IsRestrictedToLocalhost() {
c.Next()
return
}
// Use actual TCP connection address (RemoteAddr) to prevent header spoofing // Use actual TCP connection address (RemoteAddr) to prevent header spoofing
// This cannot be forged by X-Forwarded-For or other client-controlled headers // This cannot be forged by X-Forwarded-For or other client-controlled headers
remoteAddr := c.Request.RemoteAddr remoteAddr := c.Request.RemoteAddr
@@ -79,21 +80,32 @@ func noCORSMiddleware() gin.HandlerFunc {
// registerManagementRoutes registers Amp management proxy routes // registerManagementRoutes registers Amp management proxy routes
// These routes proxy through to the Amp control plane for OAuth, user management, etc. // These routes proxy through to the Amp control plane for OAuth, user management, etc.
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1. // Uses dynamic middleware and proxy getter for hot-reload support.
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) { func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
ampAPI := engine.Group("/api") ampAPI := engine.Group("/api")
// Always disable CORS for management routes to prevent browser-based attacks // Always disable CORS for management routes to prevent browser-based attacks
ampAPI.Use(noCORSMiddleware()) ampAPI.Use(noCORSMiddleware())
// Apply localhost-only restriction if configured // Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
if restrictToLocalhost { ampAPI.Use(m.localhostOnlyMiddleware())
ampAPI.Use(localhostOnlyMiddleware())
if m.IsRestrictedToLocalhost() {
log.Info("amp management routes restricted to localhost only (CORS disabled)") log.Info("amp management routes restricted to localhost only (CORS disabled)")
} else { } else {
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!") log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
} }
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) {
proxy := m.getProxy()
if proxy == nil {
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
return
}
proxy.ServeHTTP(c.Writer, c.Request)
}
// Management routes - these are proxied directly to Amp upstream // Management routes - these are proxied directly to Amp upstream
ampAPI.Any("/internal", proxyHandler) ampAPI.Any("/internal", proxyHandler)
ampAPI.Any("/internal/*path", proxyHandler) ampAPI.Any("/internal/*path", proxyHandler)
@@ -114,11 +126,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
ampAPI.Any("/tab/*path", proxyHandler) ampAPI.Any("/tab/*path", proxyHandler)
// Root-level routes that AMP CLI expects without /api prefix // Root-level routes that AMP CLI expects without /api prefix
// These need the same security middleware as the /api/* routes // These need the same security middleware as the /api/* routes (dynamic for hot-reload)
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()} rootMiddleware := []gin.HandlerFunc{noCORSMiddleware(), m.localhostOnlyMiddleware()}
if restrictToLocalhost {
rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware())
}
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
// Root-level auth routes for CLI login flow // Root-level auth routes for CLI login flow
@@ -134,7 +143,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers) geminiBridge := createGeminiBridgeHandler(geminiHandlers)
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy { geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
return m.proxy return m.getProxy()
}) })
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
@@ -177,10 +186,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
// Create fallback handler wrapper that forwards to ampcode.com when provider not found // Create fallback handler wrapper that forwards to ampcode.com when provider not found
// Uses lazy evaluation to access proxy (which is created after routes are registered) // Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
// Also includes model mapping support for routing unavailable models to alternatives // Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.proxy return m.getProxy()
}, m.modelMapper) }, m.modelMapper)
// Provider-specific routes under /api/provider/:provider // Provider-specific routes under /api/provider/:provider

View File

@@ -13,16 +13,26 @@ func TestRegisterManagementRoutes(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
// Spy to track if proxy handler was called // Create module with proxy for testing
proxyCalled := false m := &AmpModule{
proxyHandler := func(c *gin.Context) { restrictToLocalhost: false, // disable localhost restriction for tests
proxyCalled = true
c.String(200, "proxied")
} }
m := &AmpModule{} // Create a mock proxy that tracks calls
proxyCalled := false
mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyCalled = true
w.WriteHeader(200)
w.Write([]byte("proxied"))
}))
defer mockProxy.Close()
// Create real proxy to mock server
proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource(""))
m.setProxy(proxy)
base := &handlers.BaseAPIHandler{} base := &handlers.BaseAPIHandler{}
m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests m.registerManagementRoutes(r, base)
managementPaths := []struct { managementPaths := []struct {
path string path string
@@ -41,9 +51,9 @@ func TestRegisterManagementRoutes(t *testing.T) {
{"/api/otel", http.MethodGet}, {"/api/otel", http.MethodGet},
{"/api/tab", http.MethodGet}, {"/api/tab", http.MethodGet},
{"/api/tab/some/path", http.MethodGet}, {"/api/tab/some/path", http.MethodGet},
{"/auth", http.MethodGet}, // Root-level auth route {"/auth", http.MethodGet}, // Root-level auth route
{"/auth/cli-login", http.MethodGet}, // CLI login flow {"/auth/cli-login", http.MethodGet}, // CLI login flow
{"/auth/callback", http.MethodGet}, // OAuth callback {"/auth/callback", http.MethodGet}, // OAuth callback
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST // Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
{"/api/provider/google/v1beta1/models", http.MethodGet}, {"/api/provider/google/v1beta1/models", http.MethodGet},
{"/api/provider/google/v1beta1/models", http.MethodPost}, {"/api/provider/google/v1beta1/models", http.MethodPost},
@@ -231,8 +241,13 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
// Apply localhost-only middleware // Create module with localhost restriction enabled
r.Use(localhostOnlyMiddleware()) m := &AmpModule{
restrictToLocalhost: true,
}
// Apply dynamic localhost-only middleware
r.Use(m.localhostOnlyMiddleware())
r.GET("/test", func(c *gin.Context) { r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok") c.String(http.StatusOK, "ok")
}) })
@@ -305,3 +320,53 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
}) })
} }
} }
func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
// Create module with localhost restriction initially enabled
m := &AmpModule{
restrictToLocalhost: true,
}
// Apply dynamic localhost-only middleware
r.Use(m.localhostOnlyMiddleware())
r.GET("/test", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
// Test 1: Remote IP should be blocked when restriction is enabled
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Expected 403 when restriction enabled, got %d", w.Code)
}
// Test 2: Hot-reload - disable restriction
m.setRestrictToLocalhost(false)
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected 200 after disabling restriction, got %d", w.Code)
}
// Test 3: Hot-reload - re-enable restriction
m.setRestrictToLocalhost(true)
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code)
}
}