From 4a135f1986f9e5d7843d652dc83a16263fd3e414 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Thu, 4 Dec 2025 21:25:04 +0800 Subject: [PATCH] feat(amp): add hot-reload support for upstream URL and localhost restriction --- internal/api/modules/amp/amp.go | 76 +++++++++++++--- .../api/modules/amp/model_mapping_test.go | 6 +- internal/api/modules/amp/routes.go | 51 ++++++----- internal/api/modules/amp/routes_test.go | 89 ++++++++++++++++--- 4 files changed, 172 insertions(+), 50 deletions(-) diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index e3067a87..ee6c6b87 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -27,12 +27,17 @@ type Option func(*AmpModule) type AmpModule struct { secretSource SecretSource proxy *httputil.ReverseProxy + proxyMu sync.RWMutex // protects proxy for hot-reload accessManager *sdkaccess.Manager authMiddleware_ gin.HandlerFunc modelMapper *DefaultModelMapper enabled bool registerOnce sync.Once + // restrictToLocalhost controls localhost-only access for management routes (hot-reloadable) + restrictToLocalhost bool + restrictMu sync.RWMutex + // configMu protects lastConfig for partial reload comparison configMu sync.RWMutex lastConfig *config.AmpCode @@ -115,6 +120,9 @@ func (m *AmpModule) Register(ctx modules.Context) error { settingsCopy := settings m.lastConfig = &settingsCopy + // Initialize localhost restriction setting (hot-reloadable) + m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost) + // Always register provider aliases - these work without an upstream m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) @@ -139,13 +147,12 @@ func (m *AmpModule) Register(ctx modules.Context) error { return } - m.proxy = proxy + m.setProxy(proxy) m.enabled = true // Register management proxy routes (requires upstream) - // Restrict to localhost by default for security (prevents drive-by browser attacks) - handler := proxyHandler(proxy) - m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, settings.RestrictManagementToLocalhost) + // Uses dynamic middleware that checks m.IsRestrictedToLocalhost() for hot-reload support + m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler) log.Infof("amp upstream proxy enabled for: %s", upstreamURL) log.Debug("amp provider alias routes registered") @@ -172,7 +179,7 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { // OnConfigUpdated handles configuration updates with partial reload support. // Only updates components that have actually changed to avoid unnecessary work. -// URL changes still require restart (logged as warning). +// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost. func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { newSettings := cfg.AmpCode @@ -199,7 +206,7 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { } if m.enabled { - // Check upstream URL change (requires restart) + // Check upstream URL change - now supports hot-reload newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL) oldUpstreamURL := "" if oldSettings != nil { @@ -207,10 +214,19 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { } if newUpstreamURL == "" && oldUpstreamURL != "" { - log.Warn("amp upstream URL removed from config, restart required to disable") - } else if newUpstreamURL != oldUpstreamURL { - changes = append(changes, "upstream-url(restart required)") - log.Warnf("amp config: upstream-url changed (%s -> %s), restart required", oldUpstreamURL, newUpstreamURL) + log.Warn("amp upstream URL removed from config, proxy disabled until restart") + m.setProxy(nil) + changes = append(changes, "upstream-url(disabled)") + } else if newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" { + // Recreate proxy with new URL + proxy, err := createReverseProxy(newUpstreamURL, m.secretSource) + if err != nil { + log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err) + } else { + m.setProxy(proxy) + changes = append(changes, "upstream-url") + log.Infof("amp config partial reload: upstream URL updated (%s -> %s)", oldUpstreamURL, newUpstreamURL) + } } // Check API key change @@ -226,11 +242,15 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { } } - // Check restrict-management-to-localhost change (requires restart) + // Check restrict-management-to-localhost change - now supports hot-reload if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost { - changes = append(changes, "restrict-management-to-localhost(restart required)") - log.Warnf("amp config: restrict-management-to-localhost changed (%t -> %t), restart required", - oldSettings.RestrictManagementToLocalhost, newSettings.RestrictManagementToLocalhost) + m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost) + changes = append(changes, "restrict-management-to-localhost") + if newSettings.RestrictManagementToLocalhost { + log.Infof("amp config partial reload: management routes now restricted to localhost") + } else { + log.Warnf("amp config partial reload: management routes now accessible from any IP - this is insecure!") + } } } @@ -291,3 +311,31 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b func (m *AmpModule) GetModelMapper() *DefaultModelMapper { return m.modelMapper } + +// getProxy returns the current proxy instance (thread-safe for hot-reload). +func (m *AmpModule) getProxy() *httputil.ReverseProxy { + m.proxyMu.RLock() + defer m.proxyMu.RUnlock() + return m.proxy +} + +// setProxy updates the proxy instance (thread-safe for hot-reload). +func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) { + m.proxyMu.Lock() + defer m.proxyMu.Unlock() + m.proxy = proxy +} + +// IsRestrictedToLocalhost returns whether management routes are restricted to localhost. +func (m *AmpModule) IsRestrictedToLocalhost() bool { + m.restrictMu.RLock() + defer m.restrictMu.RUnlock() + return m.restrictToLocalhost +} + +// setRestrictToLocalhost updates the localhost restriction setting. +func (m *AmpModule) setRestrictToLocalhost(restrict bool) { + m.restrictMu.Lock() + defer m.restrictMu.Unlock() + m.restrictToLocalhost = restrict +} diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go index c11d61bd..4f4e5a8e 100644 --- a/internal/api/modules/amp/model_mapping_test.go +++ b/internal/api/modules/amp/model_mapping_test.go @@ -152,9 +152,9 @@ func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) { mapper := NewModelMapper(nil) mapper.UpdateMappings([]config.AmpModelMapping{ - {From: "", To: "model-b"}, // Invalid: empty from - {From: "model-a", To: ""}, // Invalid: empty to - {From: " ", To: "model-b"}, // Invalid: whitespace from + {From: "", To: "model-b"}, // Invalid: empty from + {From: "model-a", To: ""}, // Invalid: empty to + {From: " ", To: "model-b"}, // Invalid: whitespace from {From: "model-c", To: "model-d"}, // Valid }) diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index b7105a14..b986a53a 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -14,15 +14,16 @@ import ( log "github.com/sirupsen/logrus" ) -// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only. -// Returns 403 Forbidden for non-localhost clients. -// -// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent -// header spoofing attacks via X-Forwarded-For or similar headers. This means the -// middleware will not work correctly behind reverse proxies - users deploying behind -// nginx/Cloudflare should disable this feature and use firewall rules instead. -func localhostOnlyMiddleware() gin.HandlerFunc { +// localhostOnlyMiddleware returns a middleware that dynamically checks the module's +// localhost restriction setting. This allows hot-reload of the restriction without restarting. +func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc { return func(c *gin.Context) { + // Check current setting (hot-reloadable) + if !m.IsRestrictedToLocalhost() { + c.Next() + return + } + // Use actual TCP connection address (RemoteAddr) to prevent header spoofing // This cannot be forged by X-Forwarded-For or other client-controlled headers remoteAddr := c.Request.RemoteAddr @@ -79,21 +80,32 @@ func noCORSMiddleware() gin.HandlerFunc { // registerManagementRoutes registers Amp management proxy routes // These routes proxy through to the Amp control plane for OAuth, user management, etc. -// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1. -func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) { +// Uses dynamic middleware and proxy getter for hot-reload support. +func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) { ampAPI := engine.Group("/api") // Always disable CORS for management routes to prevent browser-based attacks ampAPI.Use(noCORSMiddleware()) - // Apply localhost-only restriction if configured - if restrictToLocalhost { - ampAPI.Use(localhostOnlyMiddleware()) + // Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost()) + ampAPI.Use(m.localhostOnlyMiddleware()) + + if m.IsRestrictedToLocalhost() { log.Info("amp management routes restricted to localhost only (CORS disabled)") } else { log.Warn("amp management routes are NOT restricted to localhost - this is insecure!") } + // Dynamic proxy handler that uses m.getProxy() for hot-reload support + proxyHandler := func(c *gin.Context) { + proxy := m.getProxy() + if proxy == nil { + c.JSON(503, gin.H{"error": "amp upstream proxy not available"}) + return + } + proxy.ServeHTTP(c.Writer, c.Request) + } + // Management routes - these are proxied directly to Amp upstream ampAPI.Any("/internal", proxyHandler) ampAPI.Any("/internal/*path", proxyHandler) @@ -114,11 +126,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha ampAPI.Any("/tab/*path", proxyHandler) // Root-level routes that AMP CLI expects without /api prefix - // These need the same security middleware as the /api/* routes - rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()} - if restrictToLocalhost { - rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware()) - } + // These need the same security middleware as the /api/* routes (dynamic for hot-reload) + rootMiddleware := []gin.HandlerFunc{noCORSMiddleware(), m.localhostOnlyMiddleware()} engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) // Root-level auth routes for CLI login flow @@ -134,7 +143,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) geminiBridge := createGeminiBridgeHandler(geminiHandlers) geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy { - return m.proxy + return m.getProxy() }) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) @@ -177,10 +186,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) // Create fallback handler wrapper that forwards to ampcode.com when provider not found - // Uses lazy evaluation to access proxy (which is created after routes are registered) + // Uses m.getProxy() for hot-reload support (proxy can be updated at runtime) // Also includes model mapping support for routing unavailable models to alternatives fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { - return m.proxy + return m.getProxy() }, m.modelMapper) // Provider-specific routes under /api/provider/:provider diff --git a/internal/api/modules/amp/routes_test.go b/internal/api/modules/amp/routes_test.go index 89e43506..a40852c0 100644 --- a/internal/api/modules/amp/routes_test.go +++ b/internal/api/modules/amp/routes_test.go @@ -13,16 +13,26 @@ func TestRegisterManagementRoutes(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() - // Spy to track if proxy handler was called - proxyCalled := false - proxyHandler := func(c *gin.Context) { - proxyCalled = true - c.String(200, "proxied") + // Create module with proxy for testing + m := &AmpModule{ + restrictToLocalhost: false, // disable localhost restriction for tests } - m := &AmpModule{} + // Create a mock proxy that tracks calls + proxyCalled := false + mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyCalled = true + w.WriteHeader(200) + w.Write([]byte("proxied")) + })) + defer mockProxy.Close() + + // Create real proxy to mock server + proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource("")) + m.setProxy(proxy) + base := &handlers.BaseAPIHandler{} - m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests + m.registerManagementRoutes(r, base) managementPaths := []struct { path string @@ -41,9 +51,9 @@ func TestRegisterManagementRoutes(t *testing.T) { {"/api/otel", http.MethodGet}, {"/api/tab", http.MethodGet}, {"/api/tab/some/path", http.MethodGet}, - {"/auth", http.MethodGet}, // Root-level auth route - {"/auth/cli-login", http.MethodGet}, // CLI login flow - {"/auth/callback", http.MethodGet}, // OAuth callback + {"/auth", http.MethodGet}, // Root-level auth route + {"/auth/cli-login", http.MethodGet}, // CLI login flow + {"/auth/callback", http.MethodGet}, // OAuth callback // Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST {"/api/provider/google/v1beta1/models", http.MethodGet}, {"/api/provider/google/v1beta1/models", http.MethodPost}, @@ -231,8 +241,13 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() - // Apply localhost-only middleware - r.Use(localhostOnlyMiddleware()) + // Create module with localhost restriction enabled + m := &AmpModule{ + restrictToLocalhost: true, + } + + // Apply dynamic localhost-only middleware + r.Use(m.localhostOnlyMiddleware()) r.GET("/test", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) @@ -305,3 +320,53 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) { }) } } + +func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + // Create module with localhost restriction initially enabled + m := &AmpModule{ + restrictToLocalhost: true, + } + + // Apply dynamic localhost-only middleware + r.Use(m.localhostOnlyMiddleware()) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + // Test 1: Remote IP should be blocked when restriction is enabled + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 when restriction enabled, got %d", w.Code) + } + + // Test 2: Hot-reload - disable restriction + m.setRestrictToLocalhost(false) + + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 after disabling restriction, got %d", w.Code) + } + + // Test 3: Hot-reload - re-enable restriction + m.setRestrictToLocalhost(true) + + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code) + } +}