mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
AMP CLI requests /threads.rss at the root level, but the AMP module only registered routes under /api/*. This caused a 404 error during AMP CLI startup. Add the missing root-level route with the same security middleware (noCORS, optional localhost restriction) as other management routes.
303 lines
8.6 KiB
Go
303 lines
8.6 KiB
Go
package amp
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
|
)
|
|
|
|
func TestRegisterManagementRoutes(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
|
|
// Spy to track if proxy handler was called
|
|
proxyCalled := false
|
|
proxyHandler := func(c *gin.Context) {
|
|
proxyCalled = true
|
|
c.String(200, "proxied")
|
|
}
|
|
|
|
m := &AmpModule{}
|
|
base := &handlers.BaseAPIHandler{}
|
|
m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests
|
|
|
|
managementPaths := []struct {
|
|
path string
|
|
method string
|
|
}{
|
|
{"/api/internal", http.MethodGet},
|
|
{"/api/internal/some/path", http.MethodGet},
|
|
{"/api/user", http.MethodGet},
|
|
{"/api/user/profile", http.MethodGet},
|
|
{"/api/auth", http.MethodGet},
|
|
{"/api/auth/login", http.MethodGet},
|
|
{"/api/meta", http.MethodGet},
|
|
{"/api/telemetry", http.MethodGet},
|
|
{"/api/threads", http.MethodGet},
|
|
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
|
|
{"/api/otel", http.MethodGet},
|
|
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
|
|
{"/api/provider/google/v1beta1/models", http.MethodGet},
|
|
{"/api/provider/google/v1beta1/models", http.MethodPost},
|
|
}
|
|
|
|
for _, path := range managementPaths {
|
|
t.Run(path.path, func(t *testing.T) {
|
|
proxyCalled = false
|
|
req := httptest.NewRequest(path.method, path.path, nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code == http.StatusNotFound {
|
|
t.Fatalf("route %s not registered", path.path)
|
|
}
|
|
if !proxyCalled {
|
|
t.Fatalf("proxy handler not called for %s", path.path)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
|
|
// Minimal base handler setup (no need to initialize, just check routing)
|
|
base := &handlers.BaseAPIHandler{}
|
|
|
|
// Track if auth middleware was called
|
|
authCalled := false
|
|
authMiddleware := func(c *gin.Context) {
|
|
authCalled = true
|
|
c.Header("X-Auth", "ok")
|
|
// Abort with success to avoid calling the actual handler (which needs full setup)
|
|
c.AbortWithStatus(http.StatusOK)
|
|
}
|
|
|
|
m := &AmpModule{authMiddleware_: authMiddleware}
|
|
m.registerProviderAliases(r, base, authMiddleware)
|
|
|
|
paths := []struct {
|
|
path string
|
|
method string
|
|
}{
|
|
{"/api/provider/openai/models", http.MethodGet},
|
|
{"/api/provider/anthropic/models", http.MethodGet},
|
|
{"/api/provider/google/models", http.MethodGet},
|
|
{"/api/provider/groq/models", http.MethodGet},
|
|
{"/api/provider/openai/chat/completions", http.MethodPost},
|
|
{"/api/provider/anthropic/v1/messages", http.MethodPost},
|
|
{"/api/provider/google/v1beta/models", http.MethodGet},
|
|
}
|
|
|
|
for _, tc := range paths {
|
|
t.Run(tc.path, func(t *testing.T) {
|
|
authCalled = false
|
|
req := httptest.NewRequest(tc.method, tc.path, nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code == http.StatusNotFound {
|
|
t.Fatalf("route %s %s not registered", tc.method, tc.path)
|
|
}
|
|
if !authCalled {
|
|
t.Fatalf("auth middleware not executed for %s", tc.path)
|
|
}
|
|
if w.Header().Get("X-Auth") != "ok" {
|
|
t.Fatalf("auth middleware header not set for %s", tc.path)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
|
|
base := &handlers.BaseAPIHandler{}
|
|
|
|
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
|
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
|
|
|
providers := []string{"openai", "anthropic", "google", "groq", "cerebras"}
|
|
|
|
for _, provider := range providers {
|
|
t.Run(provider, func(t *testing.T) {
|
|
path := "/api/provider/" + provider + "/models"
|
|
req := httptest.NewRequest(http.MethodGet, path, nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
// Should not 404
|
|
if w.Code == http.StatusNotFound {
|
|
t.Fatalf("models route not found for provider: %s", provider)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegisterProviderAliases_V1Routes(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
|
|
base := &handlers.BaseAPIHandler{}
|
|
|
|
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
|
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
|
|
|
v1Paths := []struct {
|
|
path string
|
|
method string
|
|
}{
|
|
{"/api/provider/openai/v1/models", http.MethodGet},
|
|
{"/api/provider/openai/v1/chat/completions", http.MethodPost},
|
|
{"/api/provider/openai/v1/completions", http.MethodPost},
|
|
{"/api/provider/anthropic/v1/messages", http.MethodPost},
|
|
{"/api/provider/anthropic/v1/messages/count_tokens", http.MethodPost},
|
|
}
|
|
|
|
for _, tc := range v1Paths {
|
|
t.Run(tc.path, func(t *testing.T) {
|
|
req := httptest.NewRequest(tc.method, tc.path, nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code == http.StatusNotFound {
|
|
t.Fatalf("v1 route %s %s not registered", tc.method, tc.path)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
|
|
base := &handlers.BaseAPIHandler{}
|
|
|
|
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
|
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
|
|
|
v1betaPaths := []struct {
|
|
path string
|
|
method string
|
|
}{
|
|
{"/api/provider/google/v1beta/models", http.MethodGet},
|
|
{"/api/provider/google/v1beta/models/generateContent", http.MethodPost},
|
|
}
|
|
|
|
for _, tc := range v1betaPaths {
|
|
t.Run(tc.path, func(t *testing.T) {
|
|
req := httptest.NewRequest(tc.method, tc.path, nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code == http.StatusNotFound {
|
|
t.Fatalf("v1beta route %s %s not registered", tc.method, tc.path)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) {
|
|
// Test that routes still register even if auth middleware is nil (fallback behavior)
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
|
|
base := &handlers.BaseAPIHandler{}
|
|
|
|
m := &AmpModule{authMiddleware_: nil} // No auth middleware
|
|
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/provider/openai/models", nil)
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
// Should still work (with fallback no-op auth)
|
|
if w.Code == http.StatusNotFound {
|
|
t.Fatal("routes should register even without auth middleware")
|
|
}
|
|
}
|
|
|
|
func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
r := gin.New()
|
|
|
|
// Apply localhost-only middleware
|
|
r.Use(localhostOnlyMiddleware())
|
|
r.GET("/test", func(c *gin.Context) {
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
remoteAddr string
|
|
forwardedFor string
|
|
expectedStatus int
|
|
description string
|
|
}{
|
|
{
|
|
name: "spoofed_header_remote_connection",
|
|
remoteAddr: "192.168.1.100:12345",
|
|
forwardedFor: "127.0.0.1",
|
|
expectedStatus: http.StatusForbidden,
|
|
description: "Spoofed X-Forwarded-For header should be ignored",
|
|
},
|
|
{
|
|
name: "real_localhost_ipv4",
|
|
remoteAddr: "127.0.0.1:54321",
|
|
forwardedFor: "",
|
|
expectedStatus: http.StatusOK,
|
|
description: "Real localhost IPv4 connection should work",
|
|
},
|
|
{
|
|
name: "real_localhost_ipv6",
|
|
remoteAddr: "[::1]:54321",
|
|
forwardedFor: "",
|
|
expectedStatus: http.StatusOK,
|
|
description: "Real localhost IPv6 connection should work",
|
|
},
|
|
{
|
|
name: "remote_ipv4",
|
|
remoteAddr: "203.0.113.42:8080",
|
|
forwardedFor: "",
|
|
expectedStatus: http.StatusForbidden,
|
|
description: "Remote IPv4 connection should be blocked",
|
|
},
|
|
{
|
|
name: "remote_ipv6",
|
|
remoteAddr: "[2001:db8::1]:9090",
|
|
forwardedFor: "",
|
|
expectedStatus: http.StatusForbidden,
|
|
description: "Remote IPv6 connection should be blocked",
|
|
},
|
|
{
|
|
name: "spoofed_localhost_ipv6",
|
|
remoteAddr: "203.0.113.42:8080",
|
|
forwardedFor: "::1",
|
|
expectedStatus: http.StatusForbidden,
|
|
description: "Spoofed X-Forwarded-For with IPv6 localhost should be ignored",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = tt.remoteAddr
|
|
if tt.forwardedFor != "" {
|
|
req.Header.Set("X-Forwarded-For", tt.forwardedFor)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
r.ServeHTTP(w, req)
|
|
|
|
if w.Code != tt.expectedStatus {
|
|
t.Errorf("%s: expected status %d, got %d", tt.description, tt.expectedStatus, w.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|