Files
CLIProxyAPI/internal/api/modules/amp/routes_test.go
Ben Vargas db1119dd78 fix(amp): add /threads.rss root-level route for AMP CLI
AMP CLI requests /threads.rss at the root level, but the AMP module
only registered routes under /api/*. This caused a 404 error during
AMP CLI startup.

Add the missing root-level route with the same security middleware
(noCORS, optional localhost restriction) as other management routes.
2025-11-29 05:01:19 -07:00

303 lines
8.6 KiB
Go

package amp
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
)
// TestRegisterManagementRoutes checks that registerManagementRoutes wires every
// expected management path to the supplied proxy handler: each request must
// resolve to a route (anything other than 404) and must reach the proxy spy.
func TestRegisterManagementRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	// The spy flips this flag whenever the proxy handler is reached.
	var proxied bool
	spy := func(c *gin.Context) {
		proxied = true
		c.String(http.StatusOK, "proxied")
	}

	module := new(AmpModule)
	baseHandler := new(handlers.BaseAPIHandler)
	// The trailing false means: do not restrict to localhost in tests.
	module.registerManagementRoutes(engine, baseHandler, spy, false)

	type routeCase struct {
		method string
		path   string
	}
	cases := []routeCase{
		{method: http.MethodGet, path: "/api/internal"},
		{method: http.MethodGet, path: "/api/internal/some/path"},
		{method: http.MethodGet, path: "/api/user"},
		{method: http.MethodGet, path: "/api/user/profile"},
		{method: http.MethodGet, path: "/api/auth"},
		{method: http.MethodGet, path: "/api/auth/login"},
		{method: http.MethodGet, path: "/api/meta"},
		{method: http.MethodGet, path: "/api/telemetry"},
		{method: http.MethodGet, path: "/api/threads"},
		// Root-level route (no /api prefix).
		{method: http.MethodGet, path: "/threads.rss"},
		{method: http.MethodGet, path: "/api/otel"},
		// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST.
		{method: http.MethodGet, path: "/api/provider/google/v1beta1/models"},
		{method: http.MethodPost, path: "/api/provider/google/v1beta1/models"},
	}

	for i := range cases {
		rc := cases[i]
		t.Run(rc.path, func(t *testing.T) {
			// Reset the spy so a hit from an earlier subtest cannot leak into this one.
			proxied = false

			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, httptest.NewRequest(rc.method, rc.path, nil))

			switch {
			case rec.Code == http.StatusNotFound:
				t.Fatalf("route %s not registered", rc.path)
			case !proxied:
				t.Fatalf("proxy handler not called for %s", rc.path)
			}
		})
	}
}
// TestRegisterProviderAliases_AllProvidersRegistered confirms that the provider
// alias routes are registered and that requests to them pass through the auth
// middleware given to registerProviderAliases.
func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	// A zero-value base handler suffices: only routing is under test.
	baseHandler := new(handlers.BaseAPIHandler)

	// Records whether the auth stub ran for the current request.
	var sawAuth bool
	authStub := func(c *gin.Context) {
		sawAuth = true
		c.Header("X-Auth", "ok")
		// Abort with success so the actual handler (which needs full setup) is not called.
		c.AbortWithStatus(http.StatusOK)
	}

	module := &AmpModule{authMiddleware_: authStub}
	module.registerProviderAliases(engine, baseHandler, authStub)

	routes := []struct {
		method string
		path   string
	}{
		{method: http.MethodGet, path: "/api/provider/openai/models"},
		{method: http.MethodGet, path: "/api/provider/anthropic/models"},
		{method: http.MethodGet, path: "/api/provider/google/models"},
		{method: http.MethodGet, path: "/api/provider/groq/models"},
		{method: http.MethodPost, path: "/api/provider/openai/chat/completions"},
		{method: http.MethodPost, path: "/api/provider/anthropic/v1/messages"},
		{method: http.MethodGet, path: "/api/provider/google/v1beta/models"},
	}

	for _, route := range routes {
		t.Run(route.path, func(t *testing.T) {
			sawAuth = false

			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, httptest.NewRequest(route.method, route.path, nil))

			// Checks run in order; the first failing one ends the subtest.
			switch {
			case rec.Code == http.StatusNotFound:
				t.Fatalf("route %s %s not registered", route.method, route.path)
			case !sawAuth:
				t.Fatalf("auth middleware not executed for %s", route.path)
			case rec.Header().Get("X-Auth") != "ok":
				t.Fatalf("auth middleware header not set for %s", route.path)
			}
		})
	}
}
// TestRegisterProviderAliases_DynamicModelsHandler verifies that a GET on
// /api/provider/<provider>/models does not return 404 for each listed provider
// name.
func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()
	baseHandler := new(handlers.BaseAPIHandler)

	// Stub middleware that ends the request with 200.
	abortOK := func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }

	module := &AmpModule{authMiddleware_: abortOK}
	module.registerProviderAliases(engine, baseHandler, abortOK)

	for _, name := range []string{"openai", "anthropic", "google", "groq", "cerebras"} {
		t.Run(name, func(t *testing.T) {
			target := "/api/provider/" + name + "/models"

			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, target, nil))

			// Any status other than 404 means the route exists.
			if rec.Code != http.StatusNotFound {
				return
			}
			t.Fatalf("models route not found for provider: %s", name)
		})
	}
}
// TestRegisterProviderAliases_V1Routes verifies that the /v1-prefixed provider
// alias routes listed below are registered (i.e. do not answer 404).
func TestRegisterProviderAliases_V1Routes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()
	baseHandler := new(handlers.BaseAPIHandler)

	// Stub middleware that ends the request with 200.
	abortOK := func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }

	module := &AmpModule{authMiddleware_: abortOK}
	module.registerProviderAliases(engine, baseHandler, abortOK)

	type routeCase struct {
		method string
		path   string
	}
	cases := []routeCase{
		{method: http.MethodGet, path: "/api/provider/openai/v1/models"},
		{method: http.MethodPost, path: "/api/provider/openai/v1/chat/completions"},
		{method: http.MethodPost, path: "/api/provider/openai/v1/completions"},
		{method: http.MethodPost, path: "/api/provider/anthropic/v1/messages"},
		{method: http.MethodPost, path: "/api/provider/anthropic/v1/messages/count_tokens"},
	}

	for i := range cases {
		rc := cases[i]
		t.Run(rc.path, func(t *testing.T) {
			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, httptest.NewRequest(rc.method, rc.path, nil))

			if status := rec.Code; status == http.StatusNotFound {
				t.Fatalf("v1 route %s %s not registered", rc.method, rc.path)
			}
		})
	}
}
// TestRegisterProviderAliases_V1BetaRoutes verifies that the Google /v1beta
// alias routes listed below are registered (i.e. do not answer 404).
func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()
	baseHandler := new(handlers.BaseAPIHandler)

	// Stub middleware that ends the request with 200.
	abortOK := func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }

	module := &AmpModule{authMiddleware_: abortOK}
	module.registerProviderAliases(engine, baseHandler, abortOK)

	routes := []struct {
		method string
		path   string
	}{
		{method: http.MethodGet, path: "/api/provider/google/v1beta/models"},
		{method: http.MethodPost, path: "/api/provider/google/v1beta/models/generateContent"},
	}

	for _, route := range routes {
		t.Run(route.path, func(t *testing.T) {
			request := httptest.NewRequest(route.method, route.path, nil)
			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, request)

			if rec.Code != http.StatusNotFound {
				return
			}
			t.Fatalf("v1beta route %s %s not registered", route.method, route.path)
		})
	}
}
// TestRegisterProviderAliases_NoAuthMiddleware checks that provider alias
// routes are still registered when the module's authMiddleware_ field is nil.
func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()
	baseHandler := new(handlers.BaseAPIHandler)

	// The module itself carries no auth middleware; only the argument passed
	// to registerProviderAliases is a stub that ends the request with 200.
	module := &AmpModule{authMiddleware_: nil}
	abortOK := func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }
	module.registerProviderAliases(engine, baseHandler, abortOK)

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/provider/openai/models", nil))

	// Should still work (with fallback no-op auth).
	if status := rec.Code; status == http.StatusNotFound {
		t.Fatal("routes should register even without auth middleware")
	}
}
// TestLocalhostOnlyMiddleware_PreventsSpoofing exercises localhostOnlyMiddleware
// with a table of remote addresses and X-Forwarded-For values. Loopback remote
// addresses are expected to get 200; every other remote address is expected to
// get 403, including when X-Forwarded-For claims a loopback address.
func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	// Install the middleware under test ahead of the probe route.
	engine.Use(localhostOnlyMiddleware())
	engine.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	const (
		allowed = http.StatusOK
		denied  = http.StatusForbidden
	)

	type scenario struct {
		name string
		peer string // assigned to the request's RemoteAddr
		xff  string // X-Forwarded-For value; the header is omitted when empty
		want int    // expected response status
		why  string // prefix for the failure message
	}
	scenarios := []scenario{
		{
			name: "spoofed_header_remote_connection",
			peer: "192.168.1.100:12345",
			xff:  "127.0.0.1",
			want: denied,
			why:  "Spoofed X-Forwarded-For header should be ignored",
		},
		{
			name: "real_localhost_ipv4",
			peer: "127.0.0.1:54321",
			want: allowed,
			why:  "Real localhost IPv4 connection should work",
		},
		{
			name: "real_localhost_ipv6",
			peer: "[::1]:54321",
			want: allowed,
			why:  "Real localhost IPv6 connection should work",
		},
		{
			name: "remote_ipv4",
			peer: "203.0.113.42:8080",
			want: denied,
			why:  "Remote IPv4 connection should be blocked",
		},
		{
			name: "remote_ipv6",
			peer: "[2001:db8::1]:9090",
			want: denied,
			why:  "Remote IPv6 connection should be blocked",
		},
		{
			name: "spoofed_localhost_ipv6",
			peer: "203.0.113.42:8080",
			xff:  "::1",
			want: denied,
			why:  "Spoofed X-Forwarded-For with IPv6 localhost should be ignored",
		},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.name, func(t *testing.T) {
			request := httptest.NewRequest(http.MethodGet, "/test", nil)
			request.RemoteAddr = sc.peer
			if len(sc.xff) > 0 {
				request.Header.Set("X-Forwarded-For", sc.xff)
			}

			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, request)

			if got := rec.Code; got != sc.want {
				t.Errorf("%s: expected status %d, got %d", sc.why, sc.want, got)
			}
		})
	}
}