Implements unified model routing

Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach.

This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings.
It provides a more flexible and maintainable routing mechanism by centralizing routing logic.

The changes include:
- Introducing new `routing` package with core routing logic.
- Creating characterization tests to capture existing behavior.
- Implementing model extraction and rewriting.
- Updating AMP module routes to utilize the new routing wrapper.
- Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
This commit is contained in:
이대희
2026-02-01 16:58:06 +09:00
parent 527a269799
commit 9299897e04
14 changed files with 2105 additions and 31 deletions

View File

@@ -86,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
//
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
// This type is kept for backward compatibility and test purposes.
type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
@@ -94,6 +98,8 @@ type FallbackHandler struct {
// NewFallbackHandler creates a new fallback handler wrapper
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
//
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
return &FallbackHandler{
getProxy: getProxy,
@@ -102,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
}
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
//
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
if forceModelMappings == nil {
forceModelMappings = func() bool { return false }

View File

@@ -0,0 +1,326 @@
package amp
import (
"bytes"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil"
"github.com/stretchr/testify/assert"
)
// Characterization tests for fallback_handlers.go using testutil recorders
// These tests capture existing behavior before refactoring to routing layer
func TestCharacterization_LocalProvider(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{
{ID: "test-model-local"},
})
defer reg.UnregisterClient("char-test-local")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with proxy recorder
// Create a test server to act as the proxy target
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
// Create a reverse proxy that forwards to our test server
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
// Assert: local handler called once
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
// Assert: request body model unchanged
assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged")
}
func TestCharacterization_ModelMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the TARGET model (the mapped-to model)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{
{ID: "gpt-4-local"},
})
defer reg.UnregisterClient("char-test-mapped")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create model mapper with a mapping
mapper := NewModelMapper([]config.AmpModelMapping{
{From: "gpt-4-turbo", To: "gpt-4-local"},
})
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Request with original model that gets mapped
body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with mapper
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
}, mapper, func() bool { return false })
// Execute - use handler that returns model in response for rewriter to work
wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel())
wrapped(c)
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping")
// Assert: local handler called once
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
// Assert: request body model was rewritten to mapped model
assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model")
assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model")
// Assert: context has mapped_model key set
mappedModel, exists := handlerRecorder.GetContextKey("mapped_model")
assert.True(t, exists, "context should have mapped_model key")
assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model")
// Assert: response body model rewritten back to original
// The response writer should rewrite model names in the response
responseBody := w.Body.String()
assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name")
}
func TestCharacterization_AmpCreditsProxy(t *testing.T) {
gin.SetMode(gin.TestMode)
// Setup recorders - NO local provider registered, NO mapping configured
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context with CloseNotifier support (required for ReverseProxy)
w := testutil.NewCloseNotifierRecorder()
c, _ := gin.CreateTestContext(w)
// Request with a model that has no local provider and no mapping
body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: proxy called once
assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping")
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
// Assert: local handler NOT called
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy")
// Assert: body forwarded to proxy is original (no rewrite)
assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying")
}
func TestCharacterization_BodyRestore(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{
{ID: "test-model-body"},
})
defer reg.UnregisterClient("char-test-body")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create a complex request body that will be read by the wrapper for model extraction
originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with proxy recorder
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: local handler called (not proxy, since we have a local provider)
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
// Assert: handler receives complete original body
// This verifies that the body was properly restored after the wrapper read it for model extraction
assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction")
}
// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler
// This is a characterization test for the route gating logic in routes.go
func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model (Gemini format uses path-based model extraction)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{
{ID: "gemini-pro"},
})
defer reg.UnregisterClient("char-test-gemini")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create a test server for the proxy
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
// Create fallback handler
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Create the Gemini bridge handler (simulating what routes.go does)
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
// Create router with the same gating logic as routes.go
r := gin.New()
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
// POST with /models/ path -> use Gemini bridge with fallback handler
geminiV1Beta1Handler(c)
return
}
}
// Non-POST or no /models/ in path -> proxy upstream
proxyRecorder.ServeHTTP(c.Writer, c.Request)
})
// Execute: POST request with /models/ in path
body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// Assert: local Gemini handler called
assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/")
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path")
}
// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy
// This is a characterization test for the route gating logic in routes.go
func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) {
gin.SetMode(gin.TestMode)
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create a test server for the proxy
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
// Create fallback handler
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Create the Gemini bridge handler
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
// Create router with the same gating logic as routes.go
r := gin.New()
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
geminiV1Beta1Handler(c)
return
}
}
proxyRecorder.ServeHTTP(c.Writer, c.Request)
})
// Execute: GET request (even with /models/ in path)
req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// Assert: proxy called
assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests")
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
// Assert: local handler NOT called
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests")
}

View File

@@ -276,6 +276,22 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
return result
}
// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice.
// Safe for concurrent use.
func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]config.AmpModelMapping, 0, len(m.mappings))
for from, to := range m.mappings {
result = append(result, config.AmpModelMapping{
From: from,
To: to,
})
}
return result
}
type regexMapping struct {
re *regexp.Regexp
to string

View File

@@ -5,11 +5,12 @@ import (
"errors"
"net"
"net/http"
"net/http/httputil"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
@@ -234,19 +235,20 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
// If no local OAuth is available, falls back to ampcode.com proxy.
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
}, m.modelMapper, m.forceModelMappings)
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
// Route POST model calls through Gemini bridge with FallbackHandler.
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
// T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper
// Create a dedicated routing wrapper for the Gemini bridge
geminiBridgeWrapper := m.createModelRoutingWrapper()
geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge)
// Route POST model calls through Gemini bridge with ModelRoutingWrapper.
// ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically.
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
// POST with /models/ path -> use Gemini bridge with fallback handler
// FallbackHandler will check provider/mapping and proxy if needed
// POST with /models/ path -> use Gemini bridge with unified routing wrapper
// ModelRoutingWrapper will check provider/mapping and proxy if needed
geminiV1Beta1Handler(c)
return
}
@@ -256,6 +258,41 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
})
}
// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing.
// This is used for testing the new routing implementation (T-021 onwards).
func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper {
// Create a registry - in production this would be populated with actual providers
registry := routing.NewRegistry()
// Create a minimal config with just AmpCode settings
// The Router only needs AmpCode.ModelMappings and OAuthModelAlias
cfg := &config.Config{
AmpCode: func() config.AmpCode {
if m.modelMapper != nil {
return config.AmpCode{
ModelMappings: m.modelMapper.GetMappingsAsConfig(),
}
}
return config.AmpCode{}
}(),
}
// Create router with registry and config
router := routing.NewRouter(registry, cfg)
// Create wrapper with proxy function
proxyFunc := func(c *gin.Context) {
proxy := m.getProxy()
if proxy != nil {
proxy.ServeHTTP(c.Writer, c.Request)
} else {
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
}
}
return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc)
}
// registerProviderAliases registers /api/provider/{provider}/... routes
// These allow Amp CLI to route requests like:
//
@@ -269,12 +306,9 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
// Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
}, m.modelMapper, m.forceModelMappings)
// Create unified routing wrapper (T-021 onwards)
// Replaces FallbackHandler with Router-based unified routing
routingWrapper := m.createModelRoutingWrapper()
// Provider-specific routes under /api/provider/:provider
ampProviders := engine.Group("/api/provider")
@@ -302,33 +336,36 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
}
// Root-level routes (for providers that omit /v1, like groq/cerebras)
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
// T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
// /v1 routes (OpenAI/Claude-compatible endpoints)
v1Amp := provider.Group("/v1")
{
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
// OpenAI-compatible endpoints with fallback
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
// OpenAI-compatible endpoints with ModelRoutingWrapper
// T-021, T-022: Migrated to unified routing wrapper
v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
// Claude/Anthropic-compatible endpoints with fallback
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
// Claude/Anthropic-compatible endpoints with ModelRoutingWrapper
// T-023: Migrated Claude routes to unified routing wrapper
v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages))
v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens))
}
// /v1beta routes (Gemini native API)
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
// T-024: Migrated Gemini v1beta routes to unified routing wrapper
v1betaAmp := provider.Group("/v1beta")
{
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler))
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
}
}