mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
Implements unified model routing
Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach. This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings. It provides a more flexible and maintainable routing mechanism by centralizing routing logic. The changes include: - Introducing new `routing` package with core routing logic. - Creating characterization tests to capture existing behavior. - Implementing model extraction and rewriting. - Updating AMP module routes to utilize the new routing wrapper. - Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
This commit is contained in:
@@ -86,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
|||||||
|
|
||||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||||
// when the model's provider is not available in CLIProxyAPI
|
// when the model's provider is not available in CLIProxyAPI
|
||||||
|
//
|
||||||
|
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
|
||||||
|
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
|
||||||
|
// This type is kept for backward compatibility and test purposes.
|
||||||
type FallbackHandler struct {
|
type FallbackHandler struct {
|
||||||
getProxy func() *httputil.ReverseProxy
|
getProxy func() *httputil.ReverseProxy
|
||||||
modelMapper ModelMapper
|
modelMapper ModelMapper
|
||||||
@@ -94,6 +98,8 @@ type FallbackHandler struct {
|
|||||||
|
|
||||||
// NewFallbackHandler creates a new fallback handler wrapper
|
// NewFallbackHandler creates a new fallback handler wrapper
|
||||||
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
||||||
|
//
|
||||||
|
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
||||||
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
||||||
return &FallbackHandler{
|
return &FallbackHandler{
|
||||||
getProxy: getProxy,
|
getProxy: getProxy,
|
||||||
@@ -102,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||||
|
//
|
||||||
|
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
||||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
||||||
if forceModelMappings == nil {
|
if forceModelMappings == nil {
|
||||||
forceModelMappings = func() bool { return false }
|
forceModelMappings = func() bool { return false }
|
||||||
|
|||||||
@@ -0,0 +1,326 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Characterization tests for fallback_handlers.go using testutil recorders
|
||||||
|
// These tests capture existing behavior before refactoring to routing layer
|
||||||
|
|
||||||
|
func TestCharacterization_LocalProvider(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Register a mock provider for the test model
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{
|
||||||
|
{ID: "test-model-local"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("char-test-local")
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create gin context
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Create fallback handler with proxy recorder
|
||||||
|
// Create a test server to act as the proxy target
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
// Create a reverse proxy that forwards to our test server
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||||
|
wrapped(c)
|
||||||
|
|
||||||
|
// Assert: proxy NOT called
|
||||||
|
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
|
||||||
|
|
||||||
|
// Assert: local handler called once
|
||||||
|
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||||
|
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
|
||||||
|
|
||||||
|
// Assert: request body model unchanged
|
||||||
|
assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCharacterization_ModelMapping(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Register a mock provider for the TARGET model (the mapped-to model)
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{
|
||||||
|
{ID: "gpt-4-local"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("char-test-mapped")
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create model mapper with a mapping
|
||||||
|
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||||
|
{From: "gpt-4-turbo", To: "gpt-4-local"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create gin context
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
// Request with original model that gets mapped
|
||||||
|
body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Create fallback handler with mapper
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
}, mapper, func() bool { return false })
|
||||||
|
|
||||||
|
// Execute - use handler that returns model in response for rewriter to work
|
||||||
|
wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel())
|
||||||
|
wrapped(c)
|
||||||
|
|
||||||
|
// Assert: proxy NOT called
|
||||||
|
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping")
|
||||||
|
|
||||||
|
// Assert: local handler called once
|
||||||
|
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||||
|
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
|
||||||
|
|
||||||
|
// Assert: request body model was rewritten to mapped model
|
||||||
|
assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model")
|
||||||
|
assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model")
|
||||||
|
|
||||||
|
// Assert: context has mapped_model key set
|
||||||
|
mappedModel, exists := handlerRecorder.GetContextKey("mapped_model")
|
||||||
|
assert.True(t, exists, "context should have mapped_model key")
|
||||||
|
assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model")
|
||||||
|
|
||||||
|
// Assert: response body model rewritten back to original
|
||||||
|
// The response writer should rewrite model names in the response
|
||||||
|
responseBody := w.Body.String()
|
||||||
|
assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCharacterization_AmpCreditsProxy(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Setup recorders - NO local provider registered, NO mapping configured
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create gin context with CloseNotifier support (required for ReverseProxy)
|
||||||
|
w := testutil.NewCloseNotifierRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
// Request with a model that has no local provider and no mapping
|
||||||
|
body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Create fallback handler
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||||
|
wrapped(c)
|
||||||
|
|
||||||
|
// Assert: proxy called once
|
||||||
|
assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping")
|
||||||
|
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
|
||||||
|
|
||||||
|
// Assert: local handler NOT called
|
||||||
|
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy")
|
||||||
|
|
||||||
|
// Assert: body forwarded to proxy is original (no rewrite)
|
||||||
|
assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCharacterization_BodyRestore(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Register a mock provider for the test model
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{
|
||||||
|
{ID: "test-model-body"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("char-test-body")
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create gin context
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
// Create a complex request body that will be read by the wrapper for model extraction
|
||||||
|
originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Create fallback handler with proxy recorder
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||||
|
wrapped(c)
|
||||||
|
|
||||||
|
// Assert: local handler called (not proxy, since we have a local provider)
|
||||||
|
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||||
|
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
|
||||||
|
|
||||||
|
// Assert: handler receives complete original body
|
||||||
|
// This verifies that the body was properly restored after the wrapper read it for model extraction
|
||||||
|
assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler
|
||||||
|
// This is a characterization test for the route gating logic in routes.go
|
||||||
|
func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Register a mock provider for the test model (Gemini format uses path-based model extraction)
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{
|
||||||
|
{ID: "gemini-pro"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("char-test-gemini")
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create a test server for the proxy
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
// Create fallback handler
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create the Gemini bridge handler (simulating what routes.go does)
|
||||||
|
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
|
||||||
|
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
|
||||||
|
|
||||||
|
// Create router with the same gating logic as routes.go
|
||||||
|
r := gin.New()
|
||||||
|
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||||
|
if c.Request.Method == "POST" {
|
||||||
|
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||||
|
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||||
|
geminiV1Beta1Handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Non-POST or no /models/ in path -> proxy upstream
|
||||||
|
proxyRecorder.ServeHTTP(c.Writer, c.Request)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute: POST request with /models/ in path
|
||||||
|
body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Assert: local Gemini handler called
|
||||||
|
assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/")
|
||||||
|
|
||||||
|
// Assert: proxy NOT called
|
||||||
|
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy
|
||||||
|
// This is a characterization test for the route gating logic in routes.go
|
||||||
|
func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create a test server for the proxy
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
// Create fallback handler
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create the Gemini bridge handler
|
||||||
|
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
|
||||||
|
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
|
||||||
|
|
||||||
|
// Create router with the same gating logic as routes.go
|
||||||
|
r := gin.New()
|
||||||
|
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||||
|
if c.Request.Method == "POST" {
|
||||||
|
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||||
|
geminiV1Beta1Handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
proxyRecorder.ServeHTTP(c.Writer, c.Request)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute: GET request (even with /models/ in path)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Assert: proxy called
|
||||||
|
assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests")
|
||||||
|
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
|
||||||
|
|
||||||
|
// Assert: local handler NOT called
|
||||||
|
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests")
|
||||||
|
}
|
||||||
@@ -276,6 +276,22 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice.
|
||||||
|
// Safe for concurrent use.
|
||||||
|
func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]config.AmpModelMapping, 0, len(m.mappings))
|
||||||
|
for from, to := range m.mappings {
|
||||||
|
result = append(result, config.AmpModelMapping{
|
||||||
|
From: from,
|
||||||
|
To: to,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
type regexMapping struct {
|
type regexMapping struct {
|
||||||
re *regexp.Regexp
|
re *regexp.Regexp
|
||||||
to string
|
to string
|
||||||
|
|||||||
@@ -5,11 +5,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||||
@@ -234,19 +235,20 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
||||||
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
|
||||||
return m.getProxy()
|
|
||||||
}, m.modelMapper, m.forceModelMappings)
|
|
||||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
|
||||||
|
|
||||||
// Route POST model calls through Gemini bridge with FallbackHandler.
|
// T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper
|
||||||
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
|
// Create a dedicated routing wrapper for the Gemini bridge
|
||||||
|
geminiBridgeWrapper := m.createModelRoutingWrapper()
|
||||||
|
geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge)
|
||||||
|
|
||||||
|
// Route POST model calls through Gemini bridge with ModelRoutingWrapper.
|
||||||
|
// ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically.
|
||||||
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
||||||
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||||
if c.Request.Method == "POST" {
|
if c.Request.Method == "POST" {
|
||||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||||
// POST with /models/ path -> use Gemini bridge with fallback handler
|
// POST with /models/ path -> use Gemini bridge with unified routing wrapper
|
||||||
// FallbackHandler will check provider/mapping and proxy if needed
|
// ModelRoutingWrapper will check provider/mapping and proxy if needed
|
||||||
geminiV1Beta1Handler(c)
|
geminiV1Beta1Handler(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -256,6 +258,41 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing.
|
||||||
|
// This is used for testing the new routing implementation (T-021 onwards).
|
||||||
|
func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper {
|
||||||
|
// Create a registry - in production this would be populated with actual providers
|
||||||
|
registry := routing.NewRegistry()
|
||||||
|
|
||||||
|
// Create a minimal config with just AmpCode settings
|
||||||
|
// The Router only needs AmpCode.ModelMappings and OAuthModelAlias
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: func() config.AmpCode {
|
||||||
|
if m.modelMapper != nil {
|
||||||
|
return config.AmpCode{
|
||||||
|
ModelMappings: m.modelMapper.GetMappingsAsConfig(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return config.AmpCode{}
|
||||||
|
}(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create router with registry and config
|
||||||
|
router := routing.NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Create wrapper with proxy function
|
||||||
|
proxyFunc := func(c *gin.Context) {
|
||||||
|
proxy := m.getProxy()
|
||||||
|
if proxy != nil {
|
||||||
|
proxy.ServeHTTP(c.Writer, c.Request)
|
||||||
|
} else {
|
||||||
|
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc)
|
||||||
|
}
|
||||||
|
|
||||||
// registerProviderAliases registers /api/provider/{provider}/... routes
|
// registerProviderAliases registers /api/provider/{provider}/... routes
|
||||||
// These allow Amp CLI to route requests like:
|
// These allow Amp CLI to route requests like:
|
||||||
//
|
//
|
||||||
@@ -269,12 +306,9 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
||||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||||
|
|
||||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
// Create unified routing wrapper (T-021 onwards)
|
||||||
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
|
// Replaces FallbackHandler with Router-based unified routing
|
||||||
// Also includes model mapping support for routing unavailable models to alternatives
|
routingWrapper := m.createModelRoutingWrapper()
|
||||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
|
||||||
return m.getProxy()
|
|
||||||
}, m.modelMapper, m.forceModelMappings)
|
|
||||||
|
|
||||||
// Provider-specific routes under /api/provider/:provider
|
// Provider-specific routes under /api/provider/:provider
|
||||||
ampProviders := engine.Group("/api/provider")
|
ampProviders := engine.Group("/api/provider")
|
||||||
@@ -302,33 +336,36 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
||||||
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
|
// T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing
|
||||||
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
|
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
|
||||||
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
|
||||||
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
|
||||||
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
|
||||||
|
|
||||||
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
||||||
v1Amp := provider.Group("/v1")
|
v1Amp := provider.Group("/v1")
|
||||||
{
|
{
|
||||||
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
|
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
|
||||||
|
|
||||||
// OpenAI-compatible endpoints with fallback
|
// OpenAI-compatible endpoints with ModelRoutingWrapper
|
||||||
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
// T-021, T-022: Migrated to unified routing wrapper
|
||||||
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
|
||||||
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
|
||||||
|
v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
|
||||||
|
|
||||||
// Claude/Anthropic-compatible endpoints with fallback
|
// Claude/Anthropic-compatible endpoints with ModelRoutingWrapper
|
||||||
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
|
// T-023: Migrated Claude routes to unified routing wrapper
|
||||||
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
|
v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages))
|
||||||
|
v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
// /v1beta routes (Gemini native API)
|
// /v1beta routes (Gemini native API)
|
||||||
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
|
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
|
||||||
|
// T-024: Migrated Gemini v1beta routes to unified routing wrapper
|
||||||
v1betaAmp := provider.Group("/v1beta")
|
v1betaAmp := provider.Group("/v1beta")
|
||||||
{
|
{
|
||||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler))
|
||||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
59
internal/routing/extractor.go
Normal file
59
internal/routing/extractor.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelExtractor extracts model names from request data.
|
||||||
|
type ModelExtractor interface {
|
||||||
|
// Extract returns the model name from the request body and gin parameters.
|
||||||
|
// The ginParams map contains route parameters like "action" and "path".
|
||||||
|
Extract(body []byte, ginParams map[string]string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModelExtractor is the standard implementation of ModelExtractor.
|
||||||
|
type DefaultModelExtractor struct{}
|
||||||
|
|
||||||
|
// NewModelExtractor creates a new DefaultModelExtractor.
|
||||||
|
func NewModelExtractor() *DefaultModelExtractor {
|
||||||
|
return &DefaultModelExtractor{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract extracts the model name from the request.
|
||||||
|
// It checks in order:
|
||||||
|
// 1. JSON body "model" field (OpenAI, Claude format)
|
||||||
|
// 2. "action" parameter for Gemini standard format (e.g., "gemini-pro:generateContent")
|
||||||
|
// 3. "path" parameter for AMP CLI Gemini format (e.g., "/publishers/google/models/gemini-3-pro:streamGenerateContent")
|
||||||
|
func (e *DefaultModelExtractor) Extract(body []byte, ginParams map[string]string) (string, error) {
|
||||||
|
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||||
|
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
||||||
|
return result.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For Gemini requests, model is in the URL path
|
||||||
|
// Standard format: /models/{model}:generateContent -> :action parameter
|
||||||
|
if action, ok := ginParams["action"]; ok && action != "" {
|
||||||
|
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
|
||||||
|
parts := strings.Split(action, ":")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
return parts[0], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
|
||||||
|
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||||
|
if path, ok := ginParams["path"]; ok && path != "" {
|
||||||
|
// Look for /models/{model}:method pattern
|
||||||
|
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
||||||
|
modelPart := path[idx+8:] // Skip "/models/"
|
||||||
|
// Split by colon to get model name
|
||||||
|
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
||||||
|
return modelPart[:colonIdx], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
214
internal/routing/extractor_test.go
Normal file
214
internal/routing/extractor_test.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModelExtractor_ExtractFromJSONBody(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extract from JSON body with model field",
|
||||||
|
body: []byte(`{"model":"gpt-4.1"}`),
|
||||||
|
want: "gpt-4.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract claude model from JSON body",
|
||||||
|
body: []byte(`{"model":"claude-3-5-sonnet-20241022"}`),
|
||||||
|
want: "claude-3-5-sonnet-20241022",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract with additional fields",
|
||||||
|
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`),
|
||||||
|
want: "gpt-4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty body returns empty",
|
||||||
|
body: []byte{},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no model field returns empty",
|
||||||
|
body: []byte(`{"messages":[]}`),
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model is not string returns empty",
|
||||||
|
body: []byte(`{"model":123}`),
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := extractor.Extract(tt.body, nil)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelExtractor_ExtractFromGeminiActionParam(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
ginParams map[string]string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extract from action parameter - gemini-pro",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"action": "gemini-pro:generateContent"},
|
||||||
|
want: "gemini-pro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract from action parameter - gemini-ultra",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"action": "gemini-ultra:chat"},
|
||||||
|
want: "gemini-ultra",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty action returns empty",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"action": ""},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "action without colon returns full value",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"action": "gemini-model"},
|
||||||
|
want: "gemini-model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelExtractor_ExtractFromGeminiV1Beta1Path(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
ginParams map[string]string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extract from v1beta1 path - gemini-3-pro",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro:streamGenerateContent"},
|
||||||
|
want: "gemini-3-pro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract from v1beta1 path with preview",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro-preview:generateContent"},
|
||||||
|
want: "gemini-3-pro-preview",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path without models segment returns empty",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": "/publishers/google/gemini-3-pro:streamGenerateContent"},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty path returns empty",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": ""},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path with /models/ but no colon returns empty",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro"},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelExtractor_ExtractPriority(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
// JSON body takes priority over gin params
|
||||||
|
t.Run("JSON body takes priority over action param", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"gpt-4"}`)
|
||||||
|
params := map[string]string{"action": "gemini-pro:generateContent"}
|
||||||
|
got, err := extractor.Extract(body, params)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "gpt-4", got)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Action param takes priority over path param
|
||||||
|
t.Run("action param takes priority over path param", func(t *testing.T) {
|
||||||
|
body := []byte(`{}`)
|
||||||
|
params := map[string]string{
|
||||||
|
"action": "gemini-action:generate",
|
||||||
|
"path": "/publishers/google/models/gemini-path:streamGenerateContent",
|
||||||
|
}
|
||||||
|
got, err := extractor.Extract(body, params)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "gemini-action", got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelExtractor_NoModelFound(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
ginParams map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty body and no params",
|
||||||
|
body: []byte{},
|
||||||
|
ginParams: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "body without model and no params",
|
||||||
|
body: []byte(`{"messages":[]}`),
|
||||||
|
ginParams: map[string]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "irrelevant params only",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"other": "value"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
159
internal/routing/rewriter.go
Normal file
159
internal/routing/rewriter.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelRewriter handles model name rewriting in requests and responses.
|
||||||
|
type ModelRewriter interface {
|
||||||
|
// RewriteRequestBody rewrites the model field in a JSON request body.
|
||||||
|
// Returns the modified body or the original if no rewrite was needed.
|
||||||
|
RewriteRequestBody(body []byte, newModel string) ([]byte, error)
|
||||||
|
|
||||||
|
// WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response.
|
||||||
|
// Returns the wrapped writer and a cleanup function that must be called after the response is complete.
|
||||||
|
WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModelRewriter is the standard implementation of ModelRewriter.
|
||||||
|
type DefaultModelRewriter struct{}
|
||||||
|
|
||||||
|
// NewModelRewriter creates a new DefaultModelRewriter.
|
||||||
|
func NewModelRewriter() *DefaultModelRewriter {
|
||||||
|
return &DefaultModelRewriter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteRequestBody replaces the model name in a JSON request body.
|
||||||
|
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
|
||||||
|
if !gjson.GetBytes(body, "model").Exists() {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
result, err := sjson.SetBytes(body, "model", newModel)
|
||||||
|
if err != nil {
|
||||||
|
return body, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapResponseWriter wraps a response writer to rewrite model names.
|
||||||
|
// The cleanup function must be called after the handler completes to flush any buffered data.
|
||||||
|
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
|
||||||
|
rw := &responseRewriter{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: &bytes.Buffer{},
|
||||||
|
requestedModel: requestedModel,
|
||||||
|
resolvedModel: resolvedModel,
|
||||||
|
}
|
||||||
|
return rw, func() { rw.flush() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
|
||||||
|
type responseRewriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
body *bytes.Buffer
|
||||||
|
requestedModel string
|
||||||
|
resolvedModel string
|
||||||
|
isStreaming bool
|
||||||
|
wroteHeader bool
|
||||||
|
flushed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write intercepts response writes and buffers them for model name replacement.
|
||||||
|
func (rw *responseRewriter) Write(data []byte) (int, error) {
|
||||||
|
// Ensure header is written
|
||||||
|
if !rw.wroteHeader {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect streaming on first write
|
||||||
|
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||||
|
contentType := rw.Header().Get("Content-Type")
|
||||||
|
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||||
|
strings.Contains(contentType, "stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rw.isStreaming {
|
||||||
|
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||||
|
if err == nil {
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return rw.body.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||||
|
func (rw *responseRewriter) WriteHeader(code int) {
|
||||||
|
if !rw.wroteHeader {
|
||||||
|
rw.wroteHeader = true
|
||||||
|
rw.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush writes the buffered response with model names rewritten.
|
||||||
|
func (rw *responseRewriter) flush() {
|
||||||
|
if rw.flushed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.flushed = true
|
||||||
|
|
||||||
|
if rw.isStreaming {
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rw.body.Len() > 0 {
|
||||||
|
data := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||||
|
if _, err := rw.ResponseWriter.Write(data); err != nil {
|
||||||
|
log.Warnf("response rewriter: failed to write rewritten response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelFieldPaths lists all JSON paths where model name may appear.
|
||||||
|
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||||
|
|
||||||
|
// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model.
|
||||||
|
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
|
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range modelFieldPaths {
|
||||||
|
if gjson.GetBytes(data, path).Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, path, rw.requestedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteStreamChunk rewrites model names in SSE stream chunks.
|
||||||
|
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||||
|
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE format: "data: {json}\n\n"
|
||||||
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
|
for i, line := range lines {
|
||||||
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
// Rewrite JSON in the data line
|
||||||
|
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||||
|
lines[i] = append([]byte("data: "), rewritten...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Join(lines, []byte("\n"))
|
||||||
|
}
|
||||||
342
internal/routing/rewriter_test.go
Normal file
342
internal/routing/rewriter_test.go
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModelRewriter_RewriteRequestBody(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
newModel string
|
||||||
|
wantModel string
|
||||||
|
wantChange bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "rewrites model field in JSON body",
|
||||||
|
body: []byte(`{"model":"gpt-4.1","messages":[]}`),
|
||||||
|
newModel: "claude-local",
|
||||||
|
wantModel: "claude-local",
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rewrites with empty body returns empty",
|
||||||
|
body: []byte{},
|
||||||
|
newModel: "gpt-4",
|
||||||
|
wantModel: "",
|
||||||
|
wantChange: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles missing model field gracefully",
|
||||||
|
body: []byte(`{"messages":[{"role":"user"}]}`),
|
||||||
|
newModel: "gpt-4",
|
||||||
|
wantModel: "",
|
||||||
|
wantChange: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preserves other fields when rewriting",
|
||||||
|
body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`),
|
||||||
|
newModel: "new-model",
|
||||||
|
wantModel: "new-model",
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles nested JSON structure",
|
||||||
|
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`),
|
||||||
|
newModel: "claude-3-opus",
|
||||||
|
wantModel: "claude-3-opus",
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if tt.wantChange {
|
||||||
|
assert.NotEqual(t, string(tt.body), string(result), "body should have been modified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantModel != "" {
|
||||||
|
// Parse result and check model field
|
||||||
|
model, _ := NewModelExtractor().Extract(result, nil)
|
||||||
|
assert.Equal(t, tt.wantModel, model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRewriter_WrapResponseWriter(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
t.Run("response writer wraps without error", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
require.NotNil(t, wrapped)
|
||||||
|
require.NotNil(t, cleanup)
|
||||||
|
defer cleanup()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rewrites model in non-streaming response", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
// Write a response with the resolved model
|
||||||
|
response := []byte(`{"model":"claude-local","content":"hello"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := wrapped.Write(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Cleanup triggers the rewrite
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
// Check the response was rewritten to the requested model
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||||
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no-op when requested equals resolved", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"gpt-4","content":"hello"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := wrapped.Write(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rewrites modelVersion field", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"modelVersion":"claude-local","content":"hello"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := wrapped.Write(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"modelVersion":"gpt-4"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles streaming responses", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
// Set streaming content type
|
||||||
|
wrapped.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
|
// Write SSE chunks with resolved model
|
||||||
|
chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n")
|
||||||
|
_, err := wrapped.Write(chunk1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n")
|
||||||
|
_, err = wrapped.Write(chunk2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
// For streaming, data is written immediately with rewrites
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||||
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty body handled gracefully", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
// Don't write anything
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Empty(t, body)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves other JSON fields", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := wrapped.Write(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"temperature":0.7`)
|
||||||
|
assert.Contains(t, string(body), `"prompt_tokens":10`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_ImplementsInterfaces(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Should implement http.ResponseWriter
|
||||||
|
assert.Implements(t, (*http.ResponseWriter)(nil), wrapped)
|
||||||
|
|
||||||
|
// Should preserve header access
|
||||||
|
wrapped.Header().Set("X-Custom", "value")
|
||||||
|
assert.Equal(t, "value", recorder.Header().Get("X-Custom"))
|
||||||
|
|
||||||
|
// Should write status
|
||||||
|
wrapped.WriteHeader(http.StatusCreated)
|
||||||
|
assert.Equal(t, http.StatusCreated, recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_Flush(t *testing.T) {
|
||||||
|
t.Run("flush writes buffered content", func(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"claude-local","content":"test"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
wrapped.Write(response)
|
||||||
|
|
||||||
|
// Before cleanup, response should be empty (buffered)
|
||||||
|
assert.Empty(t, recorder.Body.Bytes())
|
||||||
|
|
||||||
|
// After cleanup, response should be written
|
||||||
|
cleanup()
|
||||||
|
assert.NotEmpty(t, recorder.Body.Bytes())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple flush calls are safe", func(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"claude-local"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
wrapped.Write(response)
|
||||||
|
|
||||||
|
// First cleanup
|
||||||
|
cleanup()
|
||||||
|
firstBody := recorder.Body.Bytes()
|
||||||
|
|
||||||
|
// Second cleanup should not write again
|
||||||
|
cleanup()
|
||||||
|
secondBody := recorder.Body.Bytes()
|
||||||
|
|
||||||
|
assert.Equal(t, firstBody, secondBody)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_StreamingWithDataLines(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
wrapped.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
|
// SSE format with multiple data lines
|
||||||
|
chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n")
|
||||||
|
wrapped.Write(chunk)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
// Both data lines should have model rewritten
|
||||||
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||||
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRewriter_RoundTrip(t *testing.T) {
|
||||||
|
// Simulate a full request -> response cycle with model rewriting
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
// Step 1: Rewrite request body
|
||||||
|
originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify request was rewritten
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
requestModel, _ := extractor.Extract(rewrittenRequest, nil)
|
||||||
|
assert.Equal(t, "claude-local", requestModel)
|
||||||
|
|
||||||
|
// Step 2: Simulate response with resolved model
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
wrapped.Write(response)
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
// Verify response was rewritten back
|
||||||
|
body, _ := io.ReadAll(recorder.Result().Body)
|
||||||
|
responseModel, _ := extractor.Extract(body, nil)
|
||||||
|
assert.Equal(t, "gpt-4", responseModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRewriter_NonJSONBody(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
// Binary/non-JSON body should be returned unchanged
|
||||||
|
body := []byte{0x00, 0x01, 0x02, 0x03}
|
||||||
|
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, body, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRewriter_InvalidJSON(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
// Invalid JSON without model field should be returned unchanged
|
||||||
|
body := []byte(`not valid json`)
|
||||||
|
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, body, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_StatusCodePreserved(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
wrapped.WriteHeader(http.StatusAccepted)
|
||||||
|
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusAccepted, recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_HeaderFlushed(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
wrapped.Header().Set("X-Request-ID", "abc123")
|
||||||
|
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
result := recorder.Result()
|
||||||
|
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
|
||||||
|
assert.Equal(t, "abc123", result.Header.Get("X-Request-ID"))
|
||||||
|
}
|
||||||
@@ -31,15 +31,17 @@ func NewRouter(registry *Registry, cfg *config.Config) *Router {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoutingDecision contains the resolved routing information.
|
// LegacyRoutingDecision contains the resolved routing information.
|
||||||
type RoutingDecision struct {
|
// Deprecated: Will be replaced by RoutingDecision from types.go in T-013.
|
||||||
|
type LegacyRoutingDecision struct {
|
||||||
RequestedModel string // Original model from request
|
RequestedModel string // Original model from request
|
||||||
ResolvedModel string // After model-mappings
|
ResolvedModel string // After model-mappings
|
||||||
Candidates []ProviderCandidate // Ordered list of providers to try
|
Candidates []ProviderCandidate // Ordered list of providers to try
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve determines the routing decision for the requested model.
|
// Resolve determines the routing decision for the requested model.
|
||||||
func (r *Router) Resolve(requestedModel string) *RoutingDecision {
|
// Deprecated: Will be updated to use RoutingRequest and return *RoutingDecision in T-013.
|
||||||
|
func (r *Router) Resolve(requestedModel string) *LegacyRoutingDecision {
|
||||||
// 1. Extract thinking suffix
|
// 1. Extract thinking suffix
|
||||||
suffixResult := thinking.ParseSuffix(requestedModel)
|
suffixResult := thinking.ParseSuffix(requestedModel)
|
||||||
baseModel := suffixResult.ModelName
|
baseModel := suffixResult.ModelName
|
||||||
@@ -60,13 +62,151 @@ func (r *Router) Resolve(requestedModel string) *RoutingDecision {
|
|||||||
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
||||||
})
|
})
|
||||||
|
|
||||||
return &RoutingDecision{
|
return &LegacyRoutingDecision{
|
||||||
RequestedModel: requestedModel,
|
RequestedModel: requestedModel,
|
||||||
ResolvedModel: targetModel,
|
ResolvedModel: targetModel,
|
||||||
Candidates: candidates,
|
Candidates: candidates,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveV2 determines the routing decision for a routing request.
|
||||||
|
// It uses the new RoutingRequest and RoutingDecision types.
|
||||||
|
func (r *Router) ResolveV2(req RoutingRequest) *RoutingDecision {
|
||||||
|
// 1. Extract thinking suffix
|
||||||
|
suffixResult := thinking.ParseSuffix(req.RequestedModel)
|
||||||
|
baseModel := suffixResult.ModelName
|
||||||
|
thinkingSuffix := ""
|
||||||
|
if suffixResult.HasSuffix {
|
||||||
|
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check for local providers
|
||||||
|
localCandidates := r.findLocalCandidates(baseModel, suffixResult)
|
||||||
|
|
||||||
|
// 3. Apply model-mappings if needed
|
||||||
|
mappedModel := r.applyMappings(baseModel)
|
||||||
|
mappingCandidates := r.findLocalCandidates(mappedModel, suffixResult)
|
||||||
|
|
||||||
|
// 4. Determine route type based on preferences and availability
|
||||||
|
var decision *RoutingDecision
|
||||||
|
|
||||||
|
if req.ForceModelMapping && mappedModel != baseModel && len(mappingCandidates) > 0 {
|
||||||
|
// FORCE MODE: Use mapping even if local provider exists
|
||||||
|
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
|
||||||
|
} else if req.PreferLocalProvider && len(localCandidates) > 0 {
|
||||||
|
// DEFAULT MODE with local preference: Use local provider first
|
||||||
|
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
|
||||||
|
} else if len(localCandidates) > 0 {
|
||||||
|
// DEFAULT MODE: Local provider available
|
||||||
|
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
|
||||||
|
} else if mappedModel != baseModel && len(mappingCandidates) > 0 {
|
||||||
|
// DEFAULT MODE: No local provider, but mapping available
|
||||||
|
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
|
||||||
|
} else {
|
||||||
|
// No local provider, no mapping - use amp credits proxy
|
||||||
|
decision = &RoutingDecision{
|
||||||
|
RouteType: RouteTypeAmpCredits,
|
||||||
|
ResolvedModel: req.RequestedModel,
|
||||||
|
ShouldProxy: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return decision
|
||||||
|
}
|
||||||
|
|
||||||
|
// findLocalCandidates finds local provider candidates for a model.
|
||||||
|
func (r *Router) findLocalCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
|
||||||
|
var candidates []ProviderCandidate
|
||||||
|
|
||||||
|
for _, p := range r.registry.All() {
|
||||||
|
if !p.SupportsModel(model) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply thinking suffix if needed
|
||||||
|
actualModel := model
|
||||||
|
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
|
||||||
|
actualModel = model + "(" + suffixResult.RawSuffix + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Available(actualModel) {
|
||||||
|
candidates = append(candidates, ProviderCandidate{
|
||||||
|
Provider: p,
|
||||||
|
Model: actualModel,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by priority
|
||||||
|
sort.Slice(candidates, func(i, j int) bool {
|
||||||
|
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
||||||
|
})
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildLocalProviderDecision creates a decision for local provider routing.
|
||||||
|
func (r *Router) buildLocalProviderDecision(requestedModel string, candidates []ProviderCandidate, thinkingSuffix string) *RoutingDecision {
|
||||||
|
resolvedModel := requestedModel
|
||||||
|
if thinkingSuffix != "" {
|
||||||
|
// Ensure thinking suffix is preserved
|
||||||
|
sr := thinking.ParseSuffix(requestedModel)
|
||||||
|
if !sr.HasSuffix {
|
||||||
|
resolvedModel = requestedModel + thinkingSuffix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fallbackModels []string
|
||||||
|
if len(candidates) > 1 {
|
||||||
|
for _, c := range candidates[1:] {
|
||||||
|
fallbackModels = append(fallbackModels, c.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RoutingDecision{
|
||||||
|
RouteType: RouteTypeLocalProvider,
|
||||||
|
ResolvedModel: resolvedModel,
|
||||||
|
ProviderName: candidates[0].Provider.Name(),
|
||||||
|
FallbackModels: fallbackModels,
|
||||||
|
ShouldProxy: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildMappingDecision creates a decision for model mapping routing.
|
||||||
|
func (r *Router) buildMappingDecision(requestedModel, mappedModel string, candidates []ProviderCandidate, thinkingSuffix string, fallbackCandidates []ProviderCandidate) *RoutingDecision {
|
||||||
|
// Apply thinking suffix to resolved model if needed
|
||||||
|
resolvedModel := mappedModel
|
||||||
|
if thinkingSuffix != "" {
|
||||||
|
sr := thinking.ParseSuffix(mappedModel)
|
||||||
|
if !sr.HasSuffix {
|
||||||
|
resolvedModel = mappedModel + thinkingSuffix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fallbackModels []string
|
||||||
|
for _, c := range fallbackCandidates {
|
||||||
|
fallbackModels = append(fallbackModels, c.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also add oauth aliases as fallbacks
|
||||||
|
baseMapped := thinking.ParseSuffix(mappedModel).ModelName
|
||||||
|
for _, alias := range r.oauthAliases[strings.ToLower(baseMapped)] {
|
||||||
|
// Check if this alias has providers
|
||||||
|
aliasCandidates := r.findLocalCandidates(alias, thinking.SuffixResult{ModelName: alias})
|
||||||
|
for _, c := range aliasCandidates {
|
||||||
|
fallbackModels = append(fallbackModels, c.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RoutingDecision{
|
||||||
|
RouteType: RouteTypeModelMapping,
|
||||||
|
ResolvedModel: resolvedModel,
|
||||||
|
ProviderName: candidates[0].Provider.Name(),
|
||||||
|
FallbackModels: fallbackModels,
|
||||||
|
ShouldProxy: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// applyMappings applies model-mappings configuration.
|
// applyMappings applies model-mappings configuration.
|
||||||
func (r *Router) applyMappings(model string) string {
|
func (r *Router) applyMappings(model string) string {
|
||||||
key := strings.ToLower(strings.TrimSpace(model))
|
key := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
|||||||
245
internal/routing/router_v2_test.go
Normal file
245
internal/routing/router_v2_test.go
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRouter_DefaultMode_PrefersLocal(t *testing.T) {
|
||||||
|
// Setup: Create a router with a mock provider that supports "gpt-4"
|
||||||
|
registry := NewRegistry()
|
||||||
|
mockProvider := &MockProvider{
|
||||||
|
name: "openai",
|
||||||
|
supportedModels: []string{"gpt-4"},
|
||||||
|
available: true,
|
||||||
|
priority: 1,
|
||||||
|
}
|
||||||
|
registry.Register(mockProvider)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "gpt-4", To: "claude-local"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request gpt-4 when local provider exists
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "gpt-4",
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Should return LOCAL_PROVIDER, not MODEL_MAPPING
|
||||||
|
assert.Equal(t, RouteTypeLocalProvider, decision.RouteType)
|
||||||
|
assert.Equal(t, "gpt-4", decision.ResolvedModel)
|
||||||
|
assert.Equal(t, "openai", decision.ProviderName)
|
||||||
|
assert.False(t, decision.ShouldProxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_DefaultMode_MapsWhenNoLocal(t *testing.T) {
|
||||||
|
// Setup: Create a router with NO provider for "gpt-4" but a mapping to "claude-local"
|
||||||
|
// which has a provider
|
||||||
|
registry := NewRegistry()
|
||||||
|
mockProvider := &MockProvider{
|
||||||
|
name: "anthropic",
|
||||||
|
supportedModels: []string{"claude-local"},
|
||||||
|
available: true,
|
||||||
|
priority: 1,
|
||||||
|
}
|
||||||
|
registry.Register(mockProvider)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "gpt-4", To: "claude-local"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request gpt-4 when no local provider exists, but mapping exists
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "gpt-4",
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Should return MODEL_MAPPING
|
||||||
|
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||||
|
assert.Equal(t, "claude-local", decision.ResolvedModel)
|
||||||
|
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||||
|
assert.False(t, decision.ShouldProxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_DefaultMode_AmpCreditsWhenNoLocalOrMapping(t *testing.T) {
|
||||||
|
// Setup: Create a router with no providers and no mappings
|
||||||
|
registry := NewRegistry()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request a model with no local provider and no mapping
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "unknown-model",
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Should return AMP_CREDITS with ShouldProxy=true
|
||||||
|
assert.Equal(t, RouteTypeAmpCredits, decision.RouteType)
|
||||||
|
assert.Equal(t, "unknown-model", decision.ResolvedModel)
|
||||||
|
assert.True(t, decision.ShouldProxy)
|
||||||
|
assert.Empty(t, decision.ProviderName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_ForceMode_MapsEvenWithLocal(t *testing.T) {
|
||||||
|
// Setup: Create a router with BOTH a local provider for "gpt-4" AND a mapping from "gpt-4" to "claude-local"
|
||||||
|
// The mapping target "claude-local" also has a provider
|
||||||
|
registry := NewRegistry()
|
||||||
|
|
||||||
|
// Local provider for gpt-4
|
||||||
|
openaiProvider := &MockProvider{
|
||||||
|
name: "openai",
|
||||||
|
supportedModels: []string{"gpt-4"},
|
||||||
|
available: true,
|
||||||
|
priority: 1,
|
||||||
|
}
|
||||||
|
registry.Register(openaiProvider)
|
||||||
|
|
||||||
|
// Local provider for the mapped model
|
||||||
|
anthropicProvider := &MockProvider{
|
||||||
|
name: "anthropic",
|
||||||
|
supportedModels: []string{"claude-local"},
|
||||||
|
available: true,
|
||||||
|
priority: 2,
|
||||||
|
}
|
||||||
|
registry.Register(anthropicProvider)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "gpt-4", To: "claude-local"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request gpt-4 with ForceModelMapping=true
|
||||||
|
// Even though gpt-4 has a local provider, mapping should take precedence
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "gpt-4",
|
||||||
|
PreferLocalProvider: false,
|
||||||
|
ForceModelMapping: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Should return MODEL_MAPPING, not LOCAL_PROVIDER
|
||||||
|
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||||
|
assert.Equal(t, "claude-local", decision.ResolvedModel)
|
||||||
|
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||||
|
assert.False(t, decision.ShouldProxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_ThinkingSuffix_Preserved(t *testing.T) {
|
||||||
|
// Setup: Create a router with mapping and provider for mapped model
|
||||||
|
registry := NewRegistry()
|
||||||
|
|
||||||
|
mockProvider := &MockProvider{
|
||||||
|
name: "anthropic",
|
||||||
|
supportedModels: []string{"claude-local"},
|
||||||
|
available: true,
|
||||||
|
priority: 1,
|
||||||
|
}
|
||||||
|
registry.Register(mockProvider)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "claude-3-5-sonnet", To: "claude-local"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request claude-3-5-sonnet with thinking suffix
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "claude-3-5-sonnet(thinking:foo)",
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Thinking suffix should be preserved in resolved model
|
||||||
|
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||||
|
assert.Equal(t, "claude-local(thinking:foo)", decision.ResolvedModel)
|
||||||
|
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockProvider is a mock implementation of Provider for testing
|
||||||
|
type MockProvider struct {
|
||||||
|
name string
|
||||||
|
providerType ProviderType
|
||||||
|
supportedModels []string
|
||||||
|
available bool
|
||||||
|
priority int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Name() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Type() ProviderType {
|
||||||
|
if m.providerType == "" {
|
||||||
|
return ProviderTypeOAuth
|
||||||
|
}
|
||||||
|
return m.providerType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) SupportsModel(model string) bool {
|
||||||
|
for _, supported := range m.supportedModels {
|
||||||
|
if supported == model {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Available(model string) bool {
|
||||||
|
return m.available
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Priority() int {
|
||||||
|
return m.priority
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||||
|
return executor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
113
internal/routing/testutil/fake_handler.go
Normal file
113
internal/routing/testutil/fake_handler.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FakeHandlerRecorder records handler invocations for testing.
|
||||||
|
type FakeHandlerRecorder struct {
|
||||||
|
Called bool
|
||||||
|
CallCount int
|
||||||
|
RequestBody []byte
|
||||||
|
RequestHeader http.Header
|
||||||
|
ContextKeys map[string]interface{}
|
||||||
|
ResponseStatus int
|
||||||
|
ResponseBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFakeHandlerRecorder creates a new fake handler recorder.
|
||||||
|
func NewFakeHandlerRecorder() *FakeHandlerRecorder {
|
||||||
|
return &FakeHandlerRecorder{
|
||||||
|
ContextKeys: make(map[string]interface{}),
|
||||||
|
ResponseStatus: http.StatusOK,
|
||||||
|
ResponseBody: []byte(`{"status":"handled"}`),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GinHandler returns a gin.HandlerFunc that records the invocation.
|
||||||
|
func (f *FakeHandlerRecorder) GinHandler() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
f.record(c)
|
||||||
|
c.Data(f.ResponseStatus, "application/json", f.ResponseBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GinHandlerWithModel returns a gin.HandlerFunc that records the invocation and returns the model from context.
|
||||||
|
// Useful for testing response rewriting in model mapping scenarios.
|
||||||
|
func (f *FakeHandlerRecorder) GinHandlerWithModel() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
f.record(c)
|
||||||
|
// Return a response with the model field that would be in the actual API response
|
||||||
|
// If ResponseBody was explicitly set (not default), use that; otherwise generate from context
|
||||||
|
var body []byte
|
||||||
|
if mappedModel, exists := c.Get("mapped_model"); exists {
|
||||||
|
body = []byte(`{"model":"` + mappedModel.(string) + `","status":"handled"}`)
|
||||||
|
} else {
|
||||||
|
body = f.ResponseBody
|
||||||
|
}
|
||||||
|
c.Data(f.ResponseStatus, "application/json", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPHandler returns an http.HandlerFunc that records the invocation.
|
||||||
|
func (f *FakeHandlerRecorder) HTTPHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
f.Called = true
|
||||||
|
f.CallCount++
|
||||||
|
f.RequestBody = body
|
||||||
|
f.RequestHeader = r.Header.Clone()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(f.ResponseStatus)
|
||||||
|
w.Write(f.ResponseBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// record captures the request details from gin context.
|
||||||
|
func (f *FakeHandlerRecorder) record(c *gin.Context) {
|
||||||
|
f.Called = true
|
||||||
|
f.CallCount++
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(c.Request.Body)
|
||||||
|
f.RequestBody = body
|
||||||
|
f.RequestHeader = c.Request.Header.Clone()
|
||||||
|
|
||||||
|
// Capture common context keys used by routing
|
||||||
|
if val, exists := c.Get("mapped_model"); exists {
|
||||||
|
f.ContextKeys["mapped_model"] = val
|
||||||
|
}
|
||||||
|
if val, exists := c.Get("fallback_models"); exists {
|
||||||
|
f.ContextKeys["fallback_models"] = val
|
||||||
|
}
|
||||||
|
if val, exists := c.Get("route_type"); exists {
|
||||||
|
f.ContextKeys["route_type"] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears the recorder state.
|
||||||
|
func (f *FakeHandlerRecorder) Reset() {
|
||||||
|
f.Called = false
|
||||||
|
f.CallCount = 0
|
||||||
|
f.RequestBody = nil
|
||||||
|
f.RequestHeader = nil
|
||||||
|
f.ContextKeys = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetContextKey returns a captured context key value.
|
||||||
|
func (f *FakeHandlerRecorder) GetContextKey(key string) (interface{}, bool) {
|
||||||
|
val, ok := f.ContextKeys[key]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// WasCalled returns true if the handler was called.
|
||||||
|
func (f *FakeHandlerRecorder) WasCalled() bool {
|
||||||
|
return f.Called
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCallCount returns the number of times the handler was called.
|
||||||
|
func (f *FakeHandlerRecorder) GetCallCount() int {
|
||||||
|
return f.CallCount
|
||||||
|
}
|
||||||
83
internal/routing/testutil/fake_proxy.go
Normal file
83
internal/routing/testutil/fake_proxy.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CloseNotifierRecorder wraps httptest.ResponseRecorder with CloseNotify support.
|
||||||
|
// This is needed because ReverseProxy requires http.CloseNotifier.
|
||||||
|
type CloseNotifierRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
closeChan chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCloseNotifierRecorder creates a ResponseRecorder that implements CloseNotifier.
|
||||||
|
func NewCloseNotifierRecorder() *CloseNotifierRecorder {
|
||||||
|
return &CloseNotifierRecorder{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
closeChan: make(chan bool, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseNotify implements http.CloseNotifier.
|
||||||
|
func (c *CloseNotifierRecorder) CloseNotify() <-chan bool {
|
||||||
|
return c.closeChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// FakeProxyRecorder records proxy invocations for testing.
|
||||||
|
type FakeProxyRecorder struct {
|
||||||
|
Called bool
|
||||||
|
CallCount int
|
||||||
|
RequestBody []byte
|
||||||
|
RequestHeaders http.Header
|
||||||
|
ResponseStatus int
|
||||||
|
ResponseBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFakeProxyRecorder creates a new fake proxy recorder.
|
||||||
|
func NewFakeProxyRecorder() *FakeProxyRecorder {
|
||||||
|
return &FakeProxyRecorder{
|
||||||
|
ResponseStatus: http.StatusOK,
|
||||||
|
ResponseBody: []byte(`{"status":"proxied"}`),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements http.Handler to act as a reverse proxy.
|
||||||
|
func (f *FakeProxyRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
f.Called = true
|
||||||
|
f.CallCount++
|
||||||
|
f.RequestHeaders = r.Header.Clone()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err == nil {
|
||||||
|
f.RequestBody = body
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(f.ResponseStatus)
|
||||||
|
w.Write(f.ResponseBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCallCount returns the number of times the proxy was called.
|
||||||
|
func (f *FakeProxyRecorder) GetCallCount() int {
|
||||||
|
return f.CallCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears the recorder state.
|
||||||
|
func (f *FakeProxyRecorder) Reset() {
|
||||||
|
f.Called = false
|
||||||
|
f.CallCount = 0
|
||||||
|
f.RequestBody = nil
|
||||||
|
f.RequestHeaders = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToHandler returns the recorder as an http.Handler for use with httptest.
|
||||||
|
func (f *FakeProxyRecorder) ToHandler() http.Handler {
|
||||||
|
return http.HandlerFunc(f.ServeHTTP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTestServer creates an httptest server with this fake proxy.
|
||||||
|
func (f *FakeProxyRecorder) CreateTestServer() *httptest.Server {
|
||||||
|
return httptest.NewServer(f.ToHandler())
|
||||||
|
}
|
||||||
62
internal/routing/types.go
Normal file
62
internal/routing/types.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
// RouteType represents the type of routing decision made for a request.
|
||||||
|
type RouteType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free).
|
||||||
|
RouteTypeLocalProvider RouteType = "LOCAL_PROVIDER"
|
||||||
|
// RouteTypeModelMapping indicates the request was remapped to another available model (free).
|
||||||
|
RouteTypeModelMapping RouteType = "MODEL_MAPPING"
|
||||||
|
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits).
|
||||||
|
RouteTypeAmpCredits RouteType = "AMP_CREDITS"
|
||||||
|
// RouteTypeNoProvider indicates no provider or fallback available.
|
||||||
|
RouteTypeNoProvider RouteType = "NO_PROVIDER"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RoutingRequest contains the information needed to make a routing decision.
|
||||||
|
type RoutingRequest struct {
|
||||||
|
// RequestedModel is the model name from the incoming request.
|
||||||
|
RequestedModel string
|
||||||
|
// PreferLocalProvider indicates whether to prefer local providers over mappings.
|
||||||
|
// When true, check local providers first before applying model mappings.
|
||||||
|
PreferLocalProvider bool
|
||||||
|
// ForceModelMapping indicates whether to force model mapping even if local provider exists.
|
||||||
|
// When true, apply model mappings first and skip local provider checks.
|
||||||
|
ForceModelMapping bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoutingDecision contains the result of a routing decision.
|
||||||
|
type RoutingDecision struct {
|
||||||
|
// RouteType indicates the type of routing decision.
|
||||||
|
RouteType RouteType
|
||||||
|
// ResolvedModel is the final model name after any mappings.
|
||||||
|
ResolvedModel string
|
||||||
|
// ProviderName is the name of the selected provider (if any).
|
||||||
|
ProviderName string
|
||||||
|
// FallbackModels is a list of alternative models to try if the primary fails.
|
||||||
|
FallbackModels []string
|
||||||
|
// ShouldProxy indicates whether the request should be proxied to ampcode.com.
|
||||||
|
ShouldProxy bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoutingDecision creates a new RoutingDecision with the given parameters.
|
||||||
|
func NewRoutingDecision(routeType RouteType, resolvedModel, providerName string, fallbackModels []string, shouldProxy bool) *RoutingDecision {
|
||||||
|
return &RoutingDecision{
|
||||||
|
RouteType: routeType,
|
||||||
|
ResolvedModel: resolvedModel,
|
||||||
|
ProviderName: providerName,
|
||||||
|
FallbackModels: fallbackModels,
|
||||||
|
ShouldProxy: shouldProxy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLocal returns true if the decision routes to a local provider.
|
||||||
|
func (d *RoutingDecision) IsLocal() bool {
|
||||||
|
return d.RouteType == RouteTypeLocalProvider || d.RouteType == RouteTypeModelMapping
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFallbacks returns true if there are fallback models available.
|
||||||
|
func (d *RoutingDecision) HasFallbacks() bool {
|
||||||
|
return len(d.FallbackModels) > 0
|
||||||
|
}
|
||||||
270
internal/routing/wrapper.go
Normal file
270
internal/routing/wrapper.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyFunc is the function type for proxying requests.
|
||||||
|
type ProxyFunc func(c *gin.Context)
|
||||||
|
|
||||||
|
// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic.
|
||||||
|
// It replaces the FallbackHandler logic with a Router-based approach.
|
||||||
|
type ModelRoutingWrapper struct {
|
||||||
|
router *Router
|
||||||
|
extractor ModelExtractor
|
||||||
|
rewriter ModelRewriter
|
||||||
|
proxyFunc ProxyFunc
|
||||||
|
logger *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies.
|
||||||
|
// If extractor is nil, a DefaultModelExtractor is used.
|
||||||
|
// If rewriter is nil, a DefaultModelRewriter is used.
|
||||||
|
// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead.
|
||||||
|
func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper {
|
||||||
|
if extractor == nil {
|
||||||
|
extractor = NewModelExtractor()
|
||||||
|
}
|
||||||
|
if rewriter == nil {
|
||||||
|
rewriter = NewModelRewriter()
|
||||||
|
}
|
||||||
|
return &ModelRoutingWrapper{
|
||||||
|
router: router,
|
||||||
|
extractor: extractor,
|
||||||
|
rewriter: rewriter,
|
||||||
|
proxyFunc: proxyFunc,
|
||||||
|
logger: logrus.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogger sets the logger for the wrapper.
|
||||||
|
func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) {
|
||||||
|
w.logger = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap wraps a gin.HandlerFunc with model routing logic.
|
||||||
|
// The returned handler will:
|
||||||
|
// 1. Extract the model from the request
|
||||||
|
// 2. Get a routing decision from the Router
|
||||||
|
// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS)
|
||||||
|
func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Read request body
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
w.logger.Errorf("routing wrapper: failed to read request body: %v", err)
|
||||||
|
handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract model from request
|
||||||
|
ginParams := map[string]string{
|
||||||
|
"action": c.Param("action"),
|
||||||
|
"path": c.Param("path"),
|
||||||
|
}
|
||||||
|
modelName, err := w.extractor.Extract(bodyBytes, ginParams)
|
||||||
|
if err != nil {
|
||||||
|
w.logger.Warnf("routing wrapper: failed to extract model: %v", err)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelName == "" {
|
||||||
|
// No model found, proceed with original handler
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get routing decision
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: modelName,
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false, // TODO: Get from config
|
||||||
|
}
|
||||||
|
decision := w.router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Store decision in context for downstream handlers
|
||||||
|
c.Set(string(ctxkeys.RoutingDecision), decision)
|
||||||
|
|
||||||
|
// Handle based on route type
|
||||||
|
switch decision.RouteType {
|
||||||
|
case RouteTypeLocalProvider:
|
||||||
|
w.handleLocalProvider(c, handler, bodyBytes, decision)
|
||||||
|
case RouteTypeModelMapping:
|
||||||
|
w.handleModelMapping(c, handler, bodyBytes, decision)
|
||||||
|
case RouteTypeAmpCredits:
|
||||||
|
w.handleAmpCredits(c, handler, bodyBytes)
|
||||||
|
default:
|
||||||
|
// No provider available
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLocalProvider handles the LOCAL_PROVIDER route type.
|
||||||
|
func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
||||||
|
// Filter Anthropic-Beta header for local provider
|
||||||
|
filterAnthropicBetaHeader(c)
|
||||||
|
|
||||||
|
// Restore body with original content
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
handler(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleModelMapping handles the MODEL_MAPPING route type.
|
||||||
|
func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
||||||
|
// Rewrite request body with mapped model
|
||||||
|
rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel)
|
||||||
|
if err != nil {
|
||||||
|
w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err)
|
||||||
|
rewrittenBody = bodyBytes
|
||||||
|
}
|
||||||
|
_ = rewrittenBody
|
||||||
|
|
||||||
|
// Store mapped model in context
|
||||||
|
c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel)
|
||||||
|
|
||||||
|
// Store fallback models in context if present
|
||||||
|
if len(decision.FallbackModels) > 0 {
|
||||||
|
c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter Anthropic-Beta header for local provider
|
||||||
|
filterAnthropicBetaHeader(c)
|
||||||
|
|
||||||
|
// Restore body with rewritten content
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody))
|
||||||
|
|
||||||
|
// Wrap response writer to rewrite model back
|
||||||
|
wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel)
|
||||||
|
c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer}
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
handler(c)
|
||||||
|
|
||||||
|
// Cleanup (flush response rewriting)
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAmpCredits handles the AMP_CREDITS route type.
|
||||||
|
// It calls the proxy function directly if available, otherwise passes to handler.
|
||||||
|
// Does NOT filter headers or rewrite body - proxy handles everything.
|
||||||
|
func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) {
|
||||||
|
// Restore body with original content (no rewriting for proxy)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|
||||||
|
// Call proxy function if available, otherwise fall back to handler
|
||||||
|
if w.proxyFunc != nil {
|
||||||
|
w.proxyFunc(c)
|
||||||
|
} else {
|
||||||
|
handler(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers.
|
||||||
|
func filterAnthropicBetaHeader(c *gin.Context) {
|
||||||
|
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||||
|
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
||||||
|
if filtered != "" {
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||||
|
} else {
|
||||||
|
c.Request.Header.Del("Anthropic-Beta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterBetaFeatures removes specified beta features from the header.
|
||||||
|
func filterBetaFeatures(betaHeader, featureToRemove string) string {
|
||||||
|
// Simple implementation - can be enhanced
|
||||||
|
if betaHeader == featureToRemove {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return betaHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter.
|
||||||
|
type ginResponseWriterAdapter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
original gin.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ginResponseWriterAdapter) WriteHeader(code int) {
|
||||||
|
a.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) {
|
||||||
|
return a.ResponseWriter.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ginResponseWriterAdapter) Header() http.Header {
|
||||||
|
return a.ResponseWriter.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseNotify implements http.CloseNotifier.
|
||||||
|
func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool {
|
||||||
|
if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok {
|
||||||
|
return notifier.CloseNotify()
|
||||||
|
}
|
||||||
|
return a.original.CloseNotify()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush implements http.Flusher.
|
||||||
|
func (a *ginResponseWriterAdapter) Flush() {
|
||||||
|
if flusher, ok := a.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack implements http.Hijacker.
|
||||||
|
func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok {
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
return a.original.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status returns the HTTP status code.
|
||||||
|
func (a *ginResponseWriterAdapter) Status() int {
|
||||||
|
return a.original.Status()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the number of bytes already written into the response http body.
|
||||||
|
func (a *ginResponseWriterAdapter) Size() int {
|
||||||
|
return a.original.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Written returns whether or not the response for this context has been written.
|
||||||
|
func (a *ginResponseWriterAdapter) Written() bool {
|
||||||
|
return a.original.Written()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeaderNow forces WriteHeader to be called.
|
||||||
|
func (a *ginResponseWriterAdapter) WriteHeaderNow() {
|
||||||
|
a.original.WriteHeaderNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteString writes the given string into the response body.
|
||||||
|
func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) {
|
||||||
|
return a.Write([]byte(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pusher returns the http.Pusher for server push.
|
||||||
|
func (a *ginResponseWriterAdapter) Pusher() http.Pusher {
|
||||||
|
if pusher, ok := a.ResponseWriter.(http.Pusher); ok {
|
||||||
|
return pusher
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user