mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat: Add Amp CLI integration with comprehensive documentation
Add full Amp CLI support to enable routing AI model requests through the proxy
while maintaining Amp-specific features like thread management, user info, and
telemetry. Includes complete documentation and pull bot configuration.
Features:
- Modular architecture with RouteModule interface for clean integration
- Reverse proxy for Amp management routes (thread/user/meta/ads/telemetry)
- Provider-specific route aliases (/api/provider/{provider}/*)
- Secret management with precedence: config > env > file
- 5-minute secret caching to reduce file I/O
- Automatic gzip decompression for responses
- Proper connection cleanup to prevent leaks
- Localhost-only restriction for management routes (configurable)
- CORS protection for management endpoints
Documentation:
- Complete setup guide (USING_WITH_FACTORY_AND_AMP.md)
- OAuth setup for OpenAI (ChatGPT Plus/Pro) and Anthropic (Claude Pro/Max)
- Factory CLI config examples with all model variants
- Amp CLI/IDE configuration examples
- tmux setup for remote server deployment
- Screenshots and diagrams
Configuration:
- Pull bot disabled for this repo (manual rebase workflow)
- Config fields: AmpUpstreamURL, AmpUpstreamAPIKey, AmpRestrictManagementToLocalhost
- Compatible with upstream DisableCooling and other features
Technical details:
- internal/api/modules/amp/: Complete Amp routing module
- sdk/api/httpx/: HTTP utilities for gzip/transport
- 94.6% test coverage with 34 comprehensive test cases
- Clean integration minimizes merge conflict risk
Security:
- Management routes restricted to localhost by default
- Configurable via amp-restrict-management-to-localhost
- Prevents drive-by browser attacks on user data
This provides a production-ready foundation for Amp CLI integration while
maintaining clean separation from upstream code for easy rebasing.
Amp-Thread-ID: https://ampcode.com/threads/T-9e2befc5-f969-41c6-890c-5b779d58cf18
This commit is contained in:
185
internal/api/modules/amp/amp.go
Normal file
185
internal/api/modules/amp/amp.go
Normal file
@@ -0,0 +1,185 @@
|
||||
// Package amp implements the Amp CLI routing module, providing OAuth-based
|
||||
// integration with Amp CLI for ChatGPT and Anthropic subscriptions.
|
||||
package amp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Option configures the AmpModule.
|
||||
type Option func(*AmpModule)
|
||||
|
||||
// AmpModule implements the RouteModuleV2 interface for Amp CLI integration.
|
||||
// It provides:
|
||||
// - Reverse proxy to Amp control plane for OAuth/management
|
||||
// - Provider-specific route aliases (/api/provider/{provider}/...)
|
||||
// - Automatic gzip decompression for misconfigured upstreams
|
||||
type AmpModule struct {
|
||||
secretSource SecretSource
|
||||
proxy *httputil.ReverseProxy
|
||||
accessManager *sdkaccess.Manager
|
||||
authMiddleware_ gin.HandlerFunc
|
||||
enabled bool
|
||||
registerOnce sync.Once
|
||||
}
|
||||
|
||||
// New creates a new Amp routing module with the given options.
|
||||
// This is the preferred constructor using the Option pattern.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ampModule := amp.New(
|
||||
// amp.WithAccessManager(accessManager),
|
||||
// amp.WithAuthMiddleware(authMiddleware),
|
||||
// amp.WithSecretSource(customSecret),
|
||||
// )
|
||||
func New(opts ...Option) *AmpModule {
|
||||
m := &AmpModule{
|
||||
secretSource: nil, // Will be created on demand if not provided
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// NewLegacy creates a new Amp routing module using the legacy constructor signature.
|
||||
// This is provided for backwards compatibility.
|
||||
//
|
||||
// DEPRECATED: Use New with options instead.
|
||||
func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule {
|
||||
return New(
|
||||
WithAccessManager(accessManager),
|
||||
WithAuthMiddleware(authMiddleware),
|
||||
)
|
||||
}
|
||||
|
||||
// WithSecretSource sets a custom secret source for the module.
|
||||
func WithSecretSource(source SecretSource) Option {
|
||||
return func(m *AmpModule) {
|
||||
m.secretSource = source
|
||||
}
|
||||
}
|
||||
|
||||
// WithAccessManager sets the access manager for the module.
|
||||
func WithAccessManager(am *sdkaccess.Manager) Option {
|
||||
return func(m *AmpModule) {
|
||||
m.accessManager = am
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthMiddleware sets the authentication middleware for provider routes.
|
||||
func WithAuthMiddleware(middleware gin.HandlerFunc) Option {
|
||||
return func(m *AmpModule) {
|
||||
m.authMiddleware_ = middleware
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the module identifier
|
||||
func (m *AmpModule) Name() string {
|
||||
return "amp-routing"
|
||||
}
|
||||
|
||||
// Register sets up Amp routes if configured.
|
||||
// This implements the RouteModuleV2 interface with Context.
|
||||
// Routes are registered only once via sync.Once for idempotent behavior.
|
||||
func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
upstreamURL := strings.TrimSpace(ctx.Config.AmpUpstreamURL)
|
||||
|
||||
// Determine auth middleware (from module or context)
|
||||
auth := m.getAuthMiddleware(ctx)
|
||||
|
||||
// Use registerOnce to ensure routes are only registered once
|
||||
var regErr error
|
||||
m.registerOnce.Do(func() {
|
||||
// Always register provider aliases - these work without an upstream
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
log.Debug("Amp upstream proxy disabled (no upstream URL configured)")
|
||||
log.Debug("Amp provider alias routes registered")
|
||||
m.enabled = false
|
||||
return
|
||||
}
|
||||
|
||||
// Create secret source with precedence: config > env > file
|
||||
// Cache secrets for 5 minutes to reduce file I/O
|
||||
if m.secretSource == nil {
|
||||
m.secretSource = NewMultiSourceSecret(ctx.Config.AmpUpstreamAPIKey, 0 /* default 5min */)
|
||||
}
|
||||
|
||||
// Create reverse proxy with gzip handling via ModifyResponse
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
if err != nil {
|
||||
regErr = fmt.Errorf("failed to create amp proxy: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.proxy = proxy
|
||||
m.enabled = true
|
||||
|
||||
// Register management proxy routes (requires upstream)
|
||||
// Restrict to localhost by default for security (prevents drive-by browser attacks)
|
||||
handler := proxyHandler(proxy)
|
||||
m.registerManagementRoutes(ctx.Engine, handler, ctx.Config.AmpRestrictManagementToLocalhost)
|
||||
|
||||
log.Infof("Amp upstream proxy enabled for: %s", upstreamURL)
|
||||
log.Debug("Amp provider alias routes registered")
|
||||
})
|
||||
|
||||
return regErr
|
||||
}
|
||||
|
||||
// getAuthMiddleware returns the authentication middleware, preferring the
|
||||
// module's configured middleware, then the context middleware, then a fallback.
|
||||
func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
|
||||
if m.authMiddleware_ != nil {
|
||||
return m.authMiddleware_
|
||||
}
|
||||
if ctx.AuthMiddleware != nil {
|
||||
return ctx.AuthMiddleware
|
||||
}
|
||||
// Fallback: no authentication (should not happen in production)
|
||||
log.Warn("Amp module: no auth middleware provided, allowing all requests")
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OnConfigUpdated handles configuration updates.
|
||||
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
|
||||
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
if !m.enabled {
|
||||
log.Debug("Amp routing not enabled, skipping config update")
|
||||
return nil
|
||||
}
|
||||
|
||||
upstreamURL := strings.TrimSpace(cfg.AmpUpstreamURL)
|
||||
if upstreamURL == "" {
|
||||
log.Warn("Amp upstream URL removed from config, restart required to disable")
|
||||
return nil
|
||||
}
|
||||
|
||||
// If API key changed, invalidate the cache
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.InvalidateCache()
|
||||
log.Debug("Amp secret cache invalidated due to config update")
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Amp config updated (restart required for URL changes)")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
303
internal/api/modules/amp/amp_test.go
Normal file
303
internal/api/modules/amp/amp_test.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
)
|
||||
|
||||
func TestAmpModule_Name(t *testing.T) {
|
||||
m := New()
|
||||
if m.Name() != "amp-routing" {
|
||||
t.Fatalf("want amp-routing, got %s", m.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_New(t *testing.T) {
|
||||
accessManager := sdkaccess.NewManager()
|
||||
authMiddleware := func(c *gin.Context) { c.Next() }
|
||||
|
||||
m := NewLegacy(accessManager, authMiddleware)
|
||||
|
||||
if m.accessManager != accessManager {
|
||||
t.Fatal("accessManager not set")
|
||||
}
|
||||
if m.authMiddleware_ == nil {
|
||||
t.Fatal("authMiddleware not set")
|
||||
}
|
||||
if m.enabled {
|
||||
t.Fatal("enabled should be false initially")
|
||||
}
|
||||
if m.proxy != nil {
|
||||
t.Fatal("proxy should be nil initially")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_Register_WithUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Fake upstream to ensure URL is valid
|
||||
upstream := httptest.NewServer(nil)
|
||||
defer upstream.Close()
|
||||
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpUpstreamURL: upstream.URL,
|
||||
AmpUpstreamAPIKey: "test-key",
|
||||
}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err != nil {
|
||||
t.Fatalf("register error: %v", err)
|
||||
}
|
||||
|
||||
if !m.enabled {
|
||||
t.Fatal("module should be enabled with upstream URL")
|
||||
}
|
||||
if m.proxy == nil {
|
||||
t.Fatal("proxy should be initialized")
|
||||
}
|
||||
if m.secretSource == nil {
|
||||
t.Fatal("secretSource should be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_Register_WithoutUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpUpstreamURL: "", // No upstream
|
||||
}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err != nil {
|
||||
t.Fatalf("register should not error without upstream: %v", err)
|
||||
}
|
||||
|
||||
if m.enabled {
|
||||
t.Fatal("module should be disabled without upstream URL")
|
||||
}
|
||||
if m.proxy != nil {
|
||||
t.Fatal("proxy should not be initialized without upstream")
|
||||
}
|
||||
|
||||
// But provider aliases should still be registered
|
||||
req := httptest.NewRequest("GET", "/api/provider/openai/models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 404 {
|
||||
t.Fatal("provider aliases should be registered even without upstream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_Register_InvalidUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpUpstreamURL: "://invalid-url",
|
||||
}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err == nil {
|
||||
t.Fatal("expected error for invalid upstream URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if ms.cache == nil {
|
||||
t.Fatal("expected cache to be set")
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpUpstreamURL: "http://x"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if ms.cache != nil {
|
||||
t.Fatal("expected cache to be invalidated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_OnConfigUpdated_NotEnabled(t *testing.T) {
|
||||
m := &AmpModule{enabled: false}
|
||||
|
||||
// Should not error or panic when disabled
|
||||
if err := m.OnConfigUpdated(&config.Config{}); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_OnConfigUpdated_URLRemoved(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecret("", 0)
|
||||
m.secretSource = ms
|
||||
|
||||
// Config update with empty URL - should log warning but not error
|
||||
cfg := &config.Config{AmpUpstreamURL: ""}
|
||||
|
||||
if err := m.OnConfigUpdated(cfg); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_OnConfigUpdated_NonMultiSourceSecret(t *testing.T) {
|
||||
// Test that OnConfigUpdated doesn't panic with StaticSecretSource
|
||||
m := &AmpModule{enabled: true}
|
||||
m.secretSource = NewStaticSecretSource("static-key")
|
||||
|
||||
cfg := &config.Config{AmpUpstreamURL: "http://example.com"}
|
||||
|
||||
// Should not error or panic
|
||||
if err := m.OnConfigUpdated(cfg); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_AuthMiddleware_Fallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Create module with no auth middleware
|
||||
m := &AmpModule{authMiddleware_: nil}
|
||||
|
||||
// Get the fallback middleware via getAuthMiddleware
|
||||
ctx := modules.Context{Engine: r, AuthMiddleware: nil}
|
||||
middleware := m.getAuthMiddleware(ctx)
|
||||
|
||||
if middleware == nil {
|
||||
t.Fatal("getAuthMiddleware should return a fallback, not nil")
|
||||
}
|
||||
|
||||
// Test that it works
|
||||
called := false
|
||||
r.GET("/test", middleware, func(c *gin.Context) {
|
||||
called = true
|
||||
c.String(200, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if !called {
|
||||
t.Fatal("fallback middleware should allow requests through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_SecretSource_FromConfig(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
upstream := httptest.NewServer(nil)
|
||||
defer upstream.Close()
|
||||
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
// Config with explicit API key
|
||||
cfg := &config.Config{
|
||||
AmpUpstreamURL: upstream.URL,
|
||||
AmpUpstreamAPIKey: "config-key",
|
||||
}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err != nil {
|
||||
t.Fatalf("register error: %v", err)
|
||||
}
|
||||
|
||||
// Secret source should be MultiSourceSecret with config key
|
||||
if m.secretSource == nil {
|
||||
t.Fatal("secretSource should be set")
|
||||
}
|
||||
|
||||
// Verify it returns the config key
|
||||
key, err := m.secretSource.Get(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Get error: %v", err)
|
||||
}
|
||||
if key != "config-key" {
|
||||
t.Fatalf("want config-key, got %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
configURL string
|
||||
}{
|
||||
{"with_upstream", "http://example.com"},
|
||||
{"without_upstream", ""},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
t.Run(scenario.name, func(t *testing.T) {
|
||||
r := gin.New()
|
||||
accessManager := sdkaccess.NewManager()
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
||||
|
||||
cfg := &config.Config{AmpUpstreamURL: scenario.configURL}
|
||||
|
||||
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
||||
if err := m.Register(ctx); err != nil && scenario.configURL != "" {
|
||||
t.Fatalf("register error: %v", err)
|
||||
}
|
||||
|
||||
// Provider aliases should always be available
|
||||
req := httptest.NewRequest("GET", "/api/provider/openai/models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 404 {
|
||||
t.Fatal("provider aliases should be registered")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
176
internal/api/modules/amp/proxy.go
Normal file
176
internal/api/modules/amp/proxy.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// readCloser wraps a reader and forwards Close to a separate closer.
|
||||
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
||||
type readCloser struct {
|
||||
r io.Reader
|
||||
c io.Closer
|
||||
}
|
||||
|
||||
func (rc *readCloser) Read(p []byte) (int, error) { return rc.r.Read(p) }
|
||||
func (rc *readCloser) Close() error { return rc.c.Close() }
|
||||
|
||||
// createReverseProxy creates a reverse proxy handler for Amp upstream
|
||||
// with automatic gzip decompression via ModifyResponse
|
||||
func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputil.ReverseProxy, error) {
|
||||
parsed, err := url.Parse(upstreamURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid amp upstream url: %w", err)
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(parsed)
|
||||
originalDirector := proxy.Director
|
||||
|
||||
// Modify outgoing requests to inject API key and fix routing
|
||||
proxy.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
req.Host = parsed.Host
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
// Could generate one here if needed
|
||||
}
|
||||
|
||||
// Inject API key from secret source (precedence: config > env > file)
|
||||
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
||||
req.Header.Set("X-Api-Key", key)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
} else if err != nil {
|
||||
log.Warnf("amp secret source error (continuing without auth): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Modify incoming responses to handle gzip without Content-Encoding
|
||||
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Only process successful responses
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip if already marked as gzip (Content-Encoding set)
|
||||
if resp.Header.Get("Content-Encoding") != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip streaming responses (SSE, chunked)
|
||||
if isStreamingResponse(resp) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save reference to original upstream body for proper cleanup
|
||||
originalBody := resp.Body
|
||||
|
||||
// Peek at first 2 bytes to detect gzip magic bytes
|
||||
header := make([]byte, 2)
|
||||
n, _ := io.ReadFull(originalBody, header)
|
||||
|
||||
// Check for gzip magic bytes (0x1f 0x8b)
|
||||
// If n < 2, we didn't get enough bytes, so it's not gzip
|
||||
if n >= 2 && header[0] == 0x1f && header[1] == 0x8b {
|
||||
// It's gzip - read the rest of the body
|
||||
rest, err := io.ReadAll(originalBody)
|
||||
if err != nil {
|
||||
// Restore what we read and return original body (preserve Close behavior)
|
||||
resp.Body = &readCloser{
|
||||
r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
|
||||
c: originalBody,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reconstruct complete gzipped data
|
||||
gzippedData := append(header[:n], rest...)
|
||||
|
||||
// Decompress
|
||||
gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedData))
|
||||
if err != nil {
|
||||
log.Warnf("amp proxy: gzip header detected but decompress failed: %v", err)
|
||||
// Close original body and return in-memory copy
|
||||
_ = originalBody.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
|
||||
return nil
|
||||
}
|
||||
|
||||
decompressed, err := io.ReadAll(gzipReader)
|
||||
_ = gzipReader.Close()
|
||||
if err != nil {
|
||||
log.Warnf("amp proxy: gzip decompress error: %v", err)
|
||||
// Close original body and return in-memory copy
|
||||
_ = originalBody.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close original body since we're replacing with in-memory decompressed content
|
||||
_ = originalBody.Close()
|
||||
|
||||
// Replace body with decompressed content
|
||||
resp.Body = io.NopCloser(bytes.NewReader(decompressed))
|
||||
resp.ContentLength = int64(len(decompressed))
|
||||
|
||||
// Update headers to reflect decompressed state
|
||||
resp.Header.Del("Content-Encoding") // No longer compressed
|
||||
resp.Header.Del("Content-Length") // Remove stale compressed length
|
||||
resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length
|
||||
|
||||
log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed))
|
||||
} else {
|
||||
// Not gzip - restore peeked bytes while preserving Close behavior
|
||||
// Handle edge cases: n might be 0, 1, or 2 depending on EOF
|
||||
resp.Body = &readCloser{
|
||||
r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
|
||||
c: originalBody,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error handler for proxy failures
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// isStreamingResponse detects if the response is streaming (SSE only)
|
||||
// Note: We only treat text/event-stream as streaming. Chunked transfer encoding
|
||||
// is a transport-level detail and doesn't mean we can't decompress the full response.
|
||||
// Many JSON APIs use chunked encoding for normal responses.
|
||||
func isStreamingResponse(resp *http.Response) bool {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
|
||||
// Only Server-Sent Events are true streaming responses
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// proxyHandler converts httputil.ReverseProxy to gin.HandlerFunc
|
||||
func proxyHandler(proxy *httputil.ReverseProxy) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
439
internal/api/modules/amp/proxy_test.go
Normal file
439
internal/api/modules/amp/proxy_test.go
Normal file
@@ -0,0 +1,439 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Helper: compress data with gzip
|
||||
func gzipBytes(b []byte) []byte {
|
||||
var buf bytes.Buffer
|
||||
zw := gzip.NewWriter(&buf)
|
||||
zw.Write(b)
|
||||
zw.Close()
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// Helper: create a mock http.Response
|
||||
func mkResp(status int, hdr http.Header, body []byte) *http.Response {
|
||||
if hdr == nil {
|
||||
hdr = http.Header{}
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: hdr,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
ContentLength: int64(len(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateReverseProxy_ValidURL(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key"))
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if proxy == nil {
|
||||
t.Fatal("expected proxy to be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateReverseProxy_InvalidURL(t *testing.T) {
|
||||
_, err := createReverseProxy("://invalid", NewStaticSecretSource("key"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifyResponse_GzipScenarios(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goodJSON := []byte(`{"ok":true}`)
|
||||
good := gzipBytes(goodJSON)
|
||||
truncated := good[:10]
|
||||
corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
body []byte
|
||||
status int
|
||||
wantBody []byte
|
||||
wantCE string
|
||||
}{
|
||||
{
|
||||
name: "decompresses_valid_gzip_no_header",
|
||||
header: http.Header{},
|
||||
body: good,
|
||||
status: 200,
|
||||
wantBody: goodJSON,
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "skips_when_ce_present",
|
||||
header: http.Header{"Content-Encoding": []string{"gzip"}},
|
||||
body: good,
|
||||
status: 200,
|
||||
wantBody: good,
|
||||
wantCE: "gzip",
|
||||
},
|
||||
{
|
||||
name: "passes_truncated_unchanged",
|
||||
header: http.Header{},
|
||||
body: truncated,
|
||||
status: 200,
|
||||
wantBody: truncated,
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "passes_corrupted_unchanged",
|
||||
header: http.Header{},
|
||||
body: corrupted,
|
||||
status: 200,
|
||||
wantBody: corrupted,
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "non_gzip_unchanged",
|
||||
header: http.Header{},
|
||||
body: []byte("plain"),
|
||||
status: 200,
|
||||
wantBody: []byte("plain"),
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "empty_body",
|
||||
header: http.Header{},
|
||||
body: []byte{},
|
||||
status: 200,
|
||||
wantBody: []byte{},
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "single_byte_body",
|
||||
header: http.Header{},
|
||||
body: []byte{0x1f},
|
||||
status: 200,
|
||||
wantBody: []byte{0x1f},
|
||||
wantCE: "",
|
||||
},
|
||||
{
|
||||
name: "skips_non_2xx_status",
|
||||
header: http.Header{},
|
||||
body: good,
|
||||
status: 404,
|
||||
wantBody: good,
|
||||
wantCE: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := mkResp(tc.status, tc.header, tc.body)
|
||||
if err := proxy.ModifyResponse(resp); err != nil {
|
||||
t.Fatalf("ModifyResponse error: %v", err)
|
||||
}
|
||||
got, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, tc.wantBody) {
|
||||
t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got)
|
||||
}
|
||||
if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE {
|
||||
t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goodJSON := []byte(`{"message":"test response"}`)
|
||||
gzipped := gzipBytes(goodJSON)
|
||||
|
||||
// Simulate upstream response with gzip body AND Content-Length header
|
||||
// (this is the scenario the bot flagged - stale Content-Length after decompression)
|
||||
resp := mkResp(200, http.Header{
|
||||
"Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, // Compressed size
|
||||
}, gzipped)
|
||||
|
||||
if err := proxy.ModifyResponse(resp); err != nil {
|
||||
t.Fatalf("ModifyResponse error: %v", err)
|
||||
}
|
||||
|
||||
// Verify body is decompressed
|
||||
got, _ := io.ReadAll(resp.Body)
|
||||
if !bytes.Equal(got, goodJSON) {
|
||||
t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON)
|
||||
}
|
||||
|
||||
// Verify Content-Length header is updated to decompressed size
|
||||
wantCL := fmt.Sprintf("%d", len(goodJSON))
|
||||
gotCL := resp.Header.Get("Content-Length")
|
||||
if gotCL != wantCL {
|
||||
t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL)
|
||||
}
|
||||
|
||||
// Verify struct field also matches
|
||||
if resp.ContentLength != int64(len(goodJSON)) {
|
||||
t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifyResponse_SkipsStreamingResponses(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goodJSON := []byte(`{"ok":true}`)
|
||||
gzipped := gzipBytes(goodJSON)
|
||||
|
||||
t.Run("sse_skips_decompression", func(t *testing.T) {
|
||||
resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped)
|
||||
if err := proxy.ModifyResponse(resp); err != nil {
|
||||
t.Fatalf("ModifyResponse error: %v", err)
|
||||
}
|
||||
// SSE should NOT be decompressed
|
||||
got, _ := io.ReadAll(resp.Body)
|
||||
if !bytes.Equal(got, gzipped) {
|
||||
t.Fatal("SSE response should not be decompressed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) {
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goodJSON := []byte(`{"ok":true}`)
|
||||
gzipped := gzipBytes(goodJSON)
|
||||
|
||||
t.Run("chunked_json_decompresses", func(t *testing.T) {
|
||||
// Chunked JSON responses (like thread APIs) should be decompressed
|
||||
resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped)
|
||||
if err := proxy.ModifyResponse(resp); err != nil {
|
||||
t.Fatalf("ModifyResponse error: %v", err)
|
||||
}
|
||||
// Should decompress because it's not SSE
|
||||
got, _ := io.ReadAll(resp.Body)
|
||||
if !bytes.Equal(got, goodJSON) {
|
||||
t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReverseProxy_InjectsHeaders(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "secret" {
|
||||
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer secret" {
|
||||
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_EmptySecret(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource(""))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
// Should NOT inject headers when secret is empty
|
||||
if hdr.Get("X-Api-Key") != "" {
|
||||
t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " {
|
||||
t.Fatalf("Authorization should not be set, got: %q", authVal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
// Point proxy to a non-routable address to trigger error
|
||||
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/any")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("want 502, got %d", res.StatusCode)
|
||||
}
|
||||
if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) {
|
||||
t.Fatalf("unexpected body: %s", body)
|
||||
}
|
||||
if ct := res.Header.Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("content-type: want application/json, got %s", ct)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
||||
// Upstream returns gzipped JSON without Content-Encoding header
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write(gzipBytes([]byte(`{"upstream":"ok"}`)))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
expected := []byte(`{"upstream":"ok"}`)
|
||||
if !bytes.Equal(body, expected) {
|
||||
t.Fatalf("want decompressed JSON, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) {
|
||||
// Upstream returns plain JSON
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`{"plain":"json"}`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
expected := []byte(`{"plain":"json"}`)
|
||||
if !bytes.Equal(body, expected) {
|
||||
t.Fatalf("want plain JSON unchanged, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsStreamingResponse(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "sse",
|
||||
header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "chunked_not_streaming",
|
||||
header: http.Header{"Transfer-Encoding": []string{"chunked"}},
|
||||
want: false, // Chunked is transport-level, not streaming
|
||||
},
|
||||
{
|
||||
name: "normal_json",
|
||||
header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
header: http.Header{},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := &http.Response{Header: tc.header}
|
||||
got := isStreamingResponse(resp)
|
||||
if got != tc.want {
|
||||
t.Fatalf("want %v, got %v", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
166
internal/api/modules/amp/routes.go
Normal file
166
internal/api/modules/amp/routes.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only.
|
||||
// Returns 403 Forbidden for non-localhost clients.
|
||||
func localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// Parse the IP to handle both IPv4 and IPv6
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip == nil {
|
||||
log.Warnf("Amp management: invalid client IP %s, denying access", clientIP)
|
||||
c.AbortWithStatusJSON(403, gin.H{
|
||||
"error": "Access denied: management routes restricted to localhost",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if IP is loopback (127.0.0.1 or ::1)
|
||||
if !ip.IsLoopback() {
|
||||
log.Warnf("Amp management: non-localhost IP %s attempted access, denying", clientIP)
|
||||
c.AbortWithStatusJSON(403, gin.H{
|
||||
"error": "Access denied: management routes restricted to localhost",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// noCORSMiddleware disables CORS for management routes to prevent browser-based attacks.
|
||||
// This overwrites any global CORS headers set by the server.
|
||||
func noCORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Remove CORS headers to prevent cross-origin access from browsers
|
||||
c.Header("Access-Control-Allow-Origin", "")
|
||||
c.Header("Access-Control-Allow-Methods", "")
|
||||
c.Header("Access-Control-Allow-Headers", "")
|
||||
c.Header("Access-Control-Allow-Credentials", "")
|
||||
|
||||
// For OPTIONS preflight, deny with 403
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(403)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
ampAPI.Use(noCORSMiddleware())
|
||||
|
||||
// Apply localhost-only restriction if configured
|
||||
if restrictToLocalhost {
|
||||
ampAPI.Use(localhostOnlyMiddleware())
|
||||
log.Info("Amp management routes restricted to localhost only (CORS disabled)")
|
||||
} else {
|
||||
log.Warn("⚠️ Amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
}
|
||||
|
||||
// Management routes - these are proxied directly to Amp upstream
|
||||
ampAPI.Any("/internal", proxyHandler)
|
||||
ampAPI.Any("/internal/*path", proxyHandler)
|
||||
ampAPI.Any("/user", proxyHandler)
|
||||
ampAPI.Any("/user/*path", proxyHandler)
|
||||
ampAPI.Any("/auth", proxyHandler)
|
||||
ampAPI.Any("/auth/*path", proxyHandler)
|
||||
ampAPI.Any("/meta", proxyHandler)
|
||||
ampAPI.Any("/meta/*path", proxyHandler)
|
||||
ampAPI.Any("/ads", proxyHandler)
|
||||
ampAPI.Any("/telemetry", proxyHandler)
|
||||
ampAPI.Any("/telemetry/*path", proxyHandler)
|
||||
ampAPI.Any("/threads", proxyHandler)
|
||||
ampAPI.Any("/threads/*path", proxyHandler)
|
||||
ampAPI.Any("/otel", proxyHandler)
|
||||
ampAPI.Any("/otel/*path", proxyHandler)
|
||||
|
||||
// Google v1beta1 passthrough (Gemini native API)
|
||||
ampAPI.Any("/provider/google/v1beta1/*path", proxyHandler)
|
||||
}
|
||||
|
||||
// registerProviderAliases registers /api/provider/{provider}/... routes
|
||||
// These allow Amp CLI to route requests like:
|
||||
//
|
||||
// /api/provider/openai/v1/chat/completions
|
||||
// /api/provider/anthropic/v1/messages
|
||||
// /api/provider/google/v1beta/models
|
||||
func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
||||
// Create handler instances for different providers
|
||||
openaiHandlers := openai.NewOpenAIAPIHandler(baseHandler)
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||
|
||||
// Provider-specific routes under /api/provider/:provider
|
||||
ampProviders := engine.Group("/api/provider")
|
||||
if auth != nil {
|
||||
ampProviders.Use(auth)
|
||||
}
|
||||
|
||||
provider := ampProviders.Group("/:provider")
|
||||
|
||||
// Dynamic models handler - routes to appropriate provider based on path parameter
|
||||
ampModelsHandler := func(c *gin.Context) {
|
||||
providerName := strings.ToLower(c.Param("provider"))
|
||||
|
||||
switch providerName {
|
||||
case "anthropic":
|
||||
claudeCodeHandlers.ClaudeModels(c)
|
||||
case "google":
|
||||
geminiHandlers.GeminiModels(c)
|
||||
default:
|
||||
// Default to OpenAI-compatible (works for openai, groq, cerebras, etc.)
|
||||
openaiHandlers.OpenAIModels(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
||||
provider.GET("/models", ampModelsHandler)
|
||||
provider.POST("/chat/completions", openaiHandlers.ChatCompletions)
|
||||
provider.POST("/completions", openaiHandlers.Completions)
|
||||
provider.POST("/responses", openaiResponsesHandlers.Responses)
|
||||
|
||||
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
||||
v1Amp := provider.Group("/v1")
|
||||
{
|
||||
v1Amp.GET("/models", ampModelsHandler)
|
||||
|
||||
// OpenAI-compatible endpoints
|
||||
v1Amp.POST("/chat/completions", openaiHandlers.ChatCompletions)
|
||||
v1Amp.POST("/completions", openaiHandlers.Completions)
|
||||
v1Amp.POST("/responses", openaiResponsesHandlers.Responses)
|
||||
|
||||
// Claude/Anthropic-compatible endpoints
|
||||
v1Amp.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
v1Amp.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||
}
|
||||
|
||||
// /v1beta routes (Gemini native API)
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
216
internal/api/modules/amp/routes_test.go
Normal file
216
internal/api/modules/amp/routes_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
)
|
||||
|
||||
func TestRegisterManagementRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Spy to track if proxy handler was called
|
||||
proxyCalled := false
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
proxyCalled = true
|
||||
c.String(200, "proxied")
|
||||
}
|
||||
|
||||
m := &AmpModule{}
|
||||
m.registerManagementRoutes(r, proxyHandler, false) // false = don't restrict to localhost in tests
|
||||
|
||||
managementPaths := []string{
|
||||
"/api/internal",
|
||||
"/api/internal/some/path",
|
||||
"/api/user",
|
||||
"/api/user/profile",
|
||||
"/api/auth",
|
||||
"/api/auth/login",
|
||||
"/api/meta",
|
||||
"/api/telemetry",
|
||||
"/api/threads",
|
||||
"/api/otel",
|
||||
"/api/provider/google/v1beta1/models",
|
||||
}
|
||||
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
t.Fatalf("proxy handler not called for %s", path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Minimal base handler setup (no need to initialize, just check routing)
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
// Track if auth middleware was called
|
||||
authCalled := false
|
||||
authMiddleware := func(c *gin.Context) {
|
||||
authCalled = true
|
||||
c.Header("X-Auth", "ok")
|
||||
// Abort with success to avoid calling the actual handler (which needs full setup)
|
||||
c.AbortWithStatus(http.StatusOK)
|
||||
}
|
||||
|
||||
m := &AmpModule{authMiddleware_: authMiddleware}
|
||||
m.registerProviderAliases(r, base, authMiddleware)
|
||||
|
||||
paths := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/api/provider/openai/models", http.MethodGet},
|
||||
{"/api/provider/anthropic/models", http.MethodGet},
|
||||
{"/api/provider/google/models", http.MethodGet},
|
||||
{"/api/provider/groq/models", http.MethodGet},
|
||||
{"/api/provider/openai/chat/completions", http.MethodPost},
|
||||
{"/api/provider/anthropic/v1/messages", http.MethodPost},
|
||||
{"/api/provider/google/v1beta/models", http.MethodGet},
|
||||
}
|
||||
|
||||
for _, tc := range paths {
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
authCalled = false
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("route %s %s not registered", tc.method, tc.path)
|
||||
}
|
||||
if !authCalled {
|
||||
t.Fatalf("auth middleware not executed for %s", tc.path)
|
||||
}
|
||||
if w.Header().Get("X-Auth") != "ok" {
|
||||
t.Fatalf("auth middleware header not set for %s", tc.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
||||
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
||||
|
||||
providers := []string{"openai", "anthropic", "google", "groq", "cerebras"}
|
||||
|
||||
for _, provider := range providers {
|
||||
t.Run(provider, func(t *testing.T) {
|
||||
path := "/api/provider/" + provider + "/models"
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Should not 404
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("models route not found for provider: %s", provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_V1Routes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
||||
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
||||
|
||||
v1Paths := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/api/provider/openai/v1/models", http.MethodGet},
|
||||
{"/api/provider/openai/v1/chat/completions", http.MethodPost},
|
||||
{"/api/provider/openai/v1/completions", http.MethodPost},
|
||||
{"/api/provider/anthropic/v1/messages", http.MethodPost},
|
||||
{"/api/provider/anthropic/v1/messages/count_tokens", http.MethodPost},
|
||||
}
|
||||
|
||||
for _, tc := range v1Paths {
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("v1 route %s %s not registered", tc.method, tc.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
||||
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
||||
|
||||
v1betaPaths := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/api/provider/google/v1beta/models", http.MethodGet},
|
||||
{"/api/provider/google/v1beta/models/generateContent", http.MethodPost},
|
||||
}
|
||||
|
||||
for _, tc := range v1betaPaths {
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatalf("v1beta route %s %s not registered", tc.method, tc.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) {
|
||||
// Test that routes still register even if auth middleware is nil (fallback behavior)
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
|
||||
m := &AmpModule{authMiddleware_: nil} // No auth middleware
|
||||
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/provider/openai/models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Should still work (with fallback no-op auth)
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Fatal("routes should register even without auth middleware")
|
||||
}
|
||||
}
|
||||
155
internal/api/modules/amp/secret.go
Normal file
155
internal/api/modules/amp/secret.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecretSource provides Amp API keys with configurable precedence and caching
|
||||
type SecretSource interface {
|
||||
Get(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// cachedSecret holds a secret value with expiration
|
||||
type cachedSecret struct {
|
||||
value string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// MultiSourceSecret implements precedence-based secret lookup:
|
||||
// 1. Explicit config value (highest priority)
|
||||
// 2. Environment variable AMP_API_KEY
|
||||
// 3. File-based secret (lowest priority)
|
||||
type MultiSourceSecret struct {
|
||||
explicitKey string
|
||||
envKey string
|
||||
filePath string
|
||||
cacheTTL time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
cache *cachedSecret
|
||||
}
|
||||
|
||||
// NewMultiSourceSecret creates a secret source with precedence and caching
|
||||
func NewMultiSourceSecret(explicitKey string, cacheTTL time.Duration) *MultiSourceSecret {
|
||||
if cacheTTL == 0 {
|
||||
cacheTTL = 5 * time.Minute // Default 5 minute cache
|
||||
}
|
||||
|
||||
home, _ := os.UserHomeDir()
|
||||
filePath := filepath.Join(home, ".local", "share", "amp", "secrets.json")
|
||||
|
||||
return &MultiSourceSecret{
|
||||
explicitKey: strings.TrimSpace(explicitKey),
|
||||
envKey: "AMP_API_KEY",
|
||||
filePath: filePath,
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMultiSourceSecretWithPath creates a secret source with a custom file path (for testing)
|
||||
func NewMultiSourceSecretWithPath(explicitKey string, filePath string, cacheTTL time.Duration) *MultiSourceSecret {
|
||||
if cacheTTL == 0 {
|
||||
cacheTTL = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &MultiSourceSecret{
|
||||
explicitKey: strings.TrimSpace(explicitKey),
|
||||
envKey: "AMP_API_KEY",
|
||||
filePath: filePath,
|
||||
cacheTTL: cacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves the Amp API key using precedence: config > env > file
|
||||
// Results are cached for cacheTTL duration to avoid excessive file reads
|
||||
func (s *MultiSourceSecret) Get(ctx context.Context) (string, error) {
|
||||
// Precedence 1: Explicit config key (highest priority, no caching needed)
|
||||
if s.explicitKey != "" {
|
||||
return s.explicitKey, nil
|
||||
}
|
||||
|
||||
// Precedence 2: Environment variable
|
||||
if envValue := strings.TrimSpace(os.Getenv(s.envKey)); envValue != "" {
|
||||
return envValue, nil
|
||||
}
|
||||
|
||||
// Precedence 3: File-based secret (lowest priority, cached)
|
||||
// Check cache first
|
||||
s.mu.RLock()
|
||||
if s.cache != nil && time.Now().Before(s.cache.expiresAt) {
|
||||
value := s.cache.value
|
||||
s.mu.RUnlock()
|
||||
return value, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// Cache miss or expired - read from file
|
||||
key, err := s.readFromFile()
|
||||
if err != nil {
|
||||
// Cache empty result to avoid repeated file reads on missing files
|
||||
s.updateCache("")
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
s.updateCache(key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// readFromFile reads the Amp API key from the secrets file
|
||||
func (s *MultiSourceSecret) readFromFile() (string, error) {
|
||||
content, err := os.ReadFile(s.filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil // Missing file is not an error, just no key available
|
||||
}
|
||||
return "", fmt.Errorf("failed to read amp secrets from %s: %w", s.filePath, err)
|
||||
}
|
||||
|
||||
var secrets map[string]string
|
||||
if err := json.Unmarshal(content, &secrets); err != nil {
|
||||
return "", fmt.Errorf("failed to parse amp secrets from %s: %w", s.filePath, err)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(secrets["apiKey@https://ampcode.com/"])
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// updateCache updates the cached secret value
|
||||
func (s *MultiSourceSecret) updateCache(value string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cache = &cachedSecret{
|
||||
value: value,
|
||||
expiresAt: time.Now().Add(s.cacheTTL),
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache clears the cached secret, forcing a fresh read on next Get
|
||||
func (s *MultiSourceSecret) InvalidateCache() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cache = nil
|
||||
}
|
||||
|
||||
// StaticSecretSource returns a fixed API key (for testing)
|
||||
type StaticSecretSource struct {
|
||||
key string
|
||||
}
|
||||
|
||||
// NewStaticSecretSource creates a secret source with a fixed key
|
||||
func NewStaticSecretSource(key string) *StaticSecretSource {
|
||||
return &StaticSecretSource{key: strings.TrimSpace(key)}
|
||||
}
|
||||
|
||||
// Get returns the static API key
|
||||
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
280
internal/api/modules/amp/secret_test.go
Normal file
280
internal/api/modules/amp/secret_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
configKey string
|
||||
envKey string
|
||||
fileJSON string
|
||||
want string
|
||||
}{
|
||||
{"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"},
|
||||
{"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
|
||||
{"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
|
||||
{"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
|
||||
{"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
|
||||
{"missing_file_returns_empty", "", "", "", ""},
|
||||
{"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretsPath := filepath.Join(tmpDir, "secrets.json")
|
||||
|
||||
if tc.fileJSON != "" {
|
||||
if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Setenv("AMP_API_KEY", tc.envKey)
|
||||
|
||||
s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond)
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Fatalf("want %q, got %q", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSourceSecret_CacheBehavior(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
|
||||
// Initial value
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 50*time.Millisecond)
|
||||
|
||||
// First read - should return v1
|
||||
got1, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Get failed: %v", err)
|
||||
}
|
||||
if got1 != "v1" {
|
||||
t.Fatalf("expected v1, got %s", got1)
|
||||
}
|
||||
|
||||
// Change file; within TTL we should still see v1 (cached)
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v2"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got2, _ := s.Get(ctx)
|
||||
if got2 != "v1" {
|
||||
t.Fatalf("cache hit expected v1, got %s", got2)
|
||||
}
|
||||
|
||||
// After TTL expires, should see v2
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
got3, _ := s.Get(ctx)
|
||||
if got3 != "v2" {
|
||||
t.Fatalf("cache miss expected v2, got %s", got3)
|
||||
}
|
||||
|
||||
// Invalidate forces re-read immediately
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v3"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s.InvalidateCache()
|
||||
got4, _ := s.Get(ctx)
|
||||
if got4 != "v3" {
|
||||
t.Fatalf("invalidate expected v3, got %s", got4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiSourceSecret_FileHandling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("missing_file_no_error", func(t *testing.T) {
|
||||
s := NewMultiSourceSecretWithPath("", "/nonexistent/path/secrets.json", 100*time.Millisecond)
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for missing file, got: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{invalid json`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
||||
_, err := s.Get(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing_key_in_json", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{"other":"value"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty string for missing key, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_key_value", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":" "}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
||||
got, _ := s.Get(ctx)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty after trim, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSourceSecret_Concurrency(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "secrets.json")
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 5*time.Second)
|
||||
ctx := context.Background()
|
||||
|
||||
// Spawn many goroutines calling Get concurrently
|
||||
const goroutines = 50
|
||||
const iterations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
val, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
if val != "concurrent" {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
for err := range errors {
|
||||
t.Errorf("concurrency error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticSecretSource(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("returns_provided_key", func(t *testing.T) {
|
||||
s := NewStaticSecretSource("test-key-123")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "test-key-123" {
|
||||
t.Fatalf("want test-key-123, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("trims_whitespace", func(t *testing.T) {
|
||||
s := NewStaticSecretSource(" test-key ")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "test-key" {
|
||||
t.Fatalf("want test-key, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_string", func(t *testing.T) {
|
||||
s := NewStaticSecretSource("")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Fatalf("want empty string, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
||||
// Test that missing file results are cached to avoid repeated file reads
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "nonexistent.json")
|
||||
|
||||
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
||||
ctx := context.Background()
|
||||
|
||||
// First call - file doesn't exist, should cache empty result
|
||||
got1, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for missing file, got: %v", err)
|
||||
}
|
||||
if got1 != "" {
|
||||
t.Fatalf("expected empty string, got %q", got1)
|
||||
}
|
||||
|
||||
// Create the file now
|
||||
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Second call - should still return empty (cached), not read the new file
|
||||
got2, _ := s.Get(ctx)
|
||||
if got2 != "" {
|
||||
t.Fatalf("cache should return empty, got %q", got2)
|
||||
}
|
||||
|
||||
// After TTL expires, should see the new value
|
||||
time.Sleep(110 * time.Millisecond)
|
||||
got3, _ := s.Get(ctx)
|
||||
if got3 != "new-value" {
|
||||
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
||||
}
|
||||
}
|
||||
92
internal/api/modules/modules.go
Normal file
92
internal/api/modules/modules.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Package modules provides a pluggable routing module system for extending
|
||||
// the API server with optional features without modifying core routing logic.
|
||||
package modules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
)
|
||||
|
||||
// Context encapsulates the dependencies exposed to routing modules during
|
||||
// registration. Modules can use the Gin engine to attach routes, the shared
|
||||
// BaseAPIHandler for constructing SDK-specific handlers, and the resolved
|
||||
// authentication middleware for protecting routes that require API keys.
|
||||
type Context struct {
|
||||
Engine *gin.Engine
|
||||
BaseHandler *handlers.BaseAPIHandler
|
||||
Config *config.Config
|
||||
AuthMiddleware gin.HandlerFunc
|
||||
}
|
||||
|
||||
// RouteModule represents a pluggable routing module that can register routes
|
||||
// and handle configuration updates independently of the core server.
|
||||
//
|
||||
// DEPRECATED: Use RouteModuleV2 for new modules. This interface is kept for
|
||||
// backwards compatibility and will be removed in a future version.
|
||||
type RouteModule interface {
|
||||
// Name returns a human-readable identifier for the module
|
||||
Name() string
|
||||
|
||||
// Register sets up routes and handlers for this module.
|
||||
// It receives the Gin engine, base handlers, and current configuration.
|
||||
// Returns an error if registration fails (errors are logged but don't stop the server).
|
||||
Register(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, cfg *config.Config) error
|
||||
|
||||
// OnConfigUpdated is called when the configuration is reloaded.
|
||||
// Modules can respond to configuration changes here.
|
||||
// Returns an error if the update cannot be applied.
|
||||
OnConfigUpdated(cfg *config.Config) error
|
||||
}
|
||||
|
||||
// RouteModuleV2 represents a pluggable bundle of routes that can integrate with
|
||||
// the API server without modifying its core routing logic. Implementations can
|
||||
// attach routes during Register and react to configuration updates via
|
||||
// OnConfigUpdated.
|
||||
//
|
||||
// This is the preferred interface for new modules. It uses Context for cleaner
|
||||
// dependency injection and supports idempotent registration.
|
||||
type RouteModuleV2 interface {
|
||||
// Name returns a unique identifier for logging and diagnostics.
|
||||
Name() string
|
||||
|
||||
// Register wires the module's routes into the provided Gin engine. Modules
|
||||
// should treat multiple calls as idempotent and avoid duplicate route
|
||||
// registration when invoked more than once.
|
||||
Register(ctx Context) error
|
||||
|
||||
// OnConfigUpdated notifies the module when the server configuration changes
|
||||
// via hot reload. Implementations can refresh cached state or emit warnings.
|
||||
OnConfigUpdated(cfg *config.Config) error
|
||||
}
|
||||
|
||||
// RegisterModule is a helper that registers a module using either the V1 or V2
|
||||
// interface. This allows gradual migration from V1 to V2 without breaking
|
||||
// existing modules.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// ctx := modules.Context{
|
||||
// Engine: engine,
|
||||
// BaseHandler: baseHandler,
|
||||
// Config: cfg,
|
||||
// AuthMiddleware: authMiddleware,
|
||||
// }
|
||||
// if err := modules.RegisterModule(ctx, ampModule); err != nil {
|
||||
// log.Errorf("Failed to register module: %v", err)
|
||||
// }
|
||||
func RegisterModule(ctx Context, mod interface{}) error {
|
||||
// Try V2 interface first (preferred)
|
||||
if v2, ok := mod.(RouteModuleV2); ok {
|
||||
return v2.Register(ctx)
|
||||
}
|
||||
|
||||
// Fall back to V1 interface for backwards compatibility
|
||||
if v1, ok := mod.(RouteModule); ok {
|
||||
return v1.Register(ctx.Engine, ctx.BaseHandler, ctx.Config)
|
||||
}
|
||||
|
||||
return fmt.Errorf("unsupported module type %T (must implement RouteModule or RouteModuleV2)", mod)
|
||||
}
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/access"
|
||||
managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
@@ -261,6 +263,20 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
|
||||
// Setup routes
|
||||
s.setupRoutes()
|
||||
|
||||
// Register Amp module using V2 interface with Context
|
||||
ampModule := ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager))
|
||||
ctx := modules.Context{
|
||||
Engine: engine,
|
||||
BaseHandler: s.handlers,
|
||||
Config: cfg,
|
||||
AuthMiddleware: AuthMiddleware(accessManager),
|
||||
}
|
||||
if err := modules.RegisterModule(ctx, ampModule); err != nil {
|
||||
log.Errorf("Failed to register Amp module: %v", err)
|
||||
}
|
||||
|
||||
// Apply additional router configurators from options
|
||||
if optionState.routerConfigurator != nil {
|
||||
optionState.routerConfigurator(engine, s.handlers, cfg)
|
||||
}
|
||||
|
||||
111
internal/api/server_test.go
Normal file
111
internal/api/server_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
gin "github.com/gin-gonic/gin"
|
||||
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func newTestServer(t *testing.T) *Server {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o700); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
|
||||
cfg := &proxyconfig.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"test-key"},
|
||||
},
|
||||
Port: 0,
|
||||
AuthDir: authDir,
|
||||
Debug: true,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
}
|
||||
|
||||
authManager := auth.NewManager(nil, nil, nil)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
return NewServer(cfg, authManager, accessManager, configPath)
|
||||
}
|
||||
|
||||
func TestAmpProviderModelRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
wantContains string
|
||||
}{
|
||||
{
|
||||
name: "openai root models",
|
||||
path: "/api/provider/openai/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"object":"list"`,
|
||||
},
|
||||
{
|
||||
name: "groq root models",
|
||||
path: "/api/provider/groq/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"object":"list"`,
|
||||
},
|
||||
{
|
||||
name: "openai models",
|
||||
path: "/api/provider/openai/v1/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"object":"list"`,
|
||||
},
|
||||
{
|
||||
name: "anthropic models",
|
||||
path: "/api/provider/anthropic/v1/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"data"`,
|
||||
},
|
||||
{
|
||||
name: "google models v1",
|
||||
path: "/api/provider/google/v1/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"models"`,
|
||||
},
|
||||
{
|
||||
name: "google models v1beta",
|
||||
path: "/api/provider/google/v1beta/models",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContains: `"models"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
req.Header.Set("Authorization", "Bearer test-key")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tc.wantStatus {
|
||||
t.Fatalf("unexpected status code for %s: got %d want %d; body=%s", tc.path, rr.Code, tc.wantStatus, rr.Body.String())
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, tc.wantContains) {
|
||||
t.Fatalf("response body for %s missing %q: %s", tc.path, tc.wantContains, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user