mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
Compare commits
7 Commits
v6.7.41
...
9299897e04
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9299897e04 | ||
|
|
527a269799 | ||
|
|
2fe0b6cd2d | ||
|
|
eeb1812d60 | ||
|
|
adedb16d35 | ||
|
|
89907231c1 | ||
|
|
09044e8ccc |
@@ -125,6 +125,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.registerOnce.Do(func() {
|
||||
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
// Load oauth-model-alias for provider lookup via aliases
|
||||
m.modelMapper.UpdateOAuthModelAlias(ctx.Config.OAuthModelAlias)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
@@ -212,6 +214,11 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Always update oauth-model-alias for model mapper (used for provider lookup)
|
||||
if m.modelMapper != nil {
|
||||
m.modelMapper.UpdateOAuthModelAlias(cfg.OAuthModelAlias)
|
||||
}
|
||||
|
||||
if m.enabled {
|
||||
// Check upstream URL change - now supports hot-reload
|
||||
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
||||
|
||||
@@ -2,12 +2,15 @@ package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -30,7 +33,13 @@ const (
|
||||
)
|
||||
|
||||
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||
const MappedModelContextKey = "mapped_model"
|
||||
// Deprecated: Use ctxkeys.MappedModel instead.
|
||||
const MappedModelContextKey = string(ctxkeys.MappedModel)
|
||||
|
||||
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
|
||||
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
|
||||
// Deprecated: Use ctxkeys.FallbackModels instead.
|
||||
const FallbackModelsContextKey = string(ctxkeys.FallbackModels)
|
||||
|
||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||
@@ -77,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
|
||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||
// when the model's provider is not available in CLIProxyAPI
|
||||
//
|
||||
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
|
||||
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
|
||||
// This type is kept for backward compatibility and test purposes.
|
||||
type FallbackHandler struct {
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
modelMapper ModelMapper
|
||||
@@ -85,6 +98,8 @@ type FallbackHandler struct {
|
||||
|
||||
// NewFallbackHandler creates a new fallback handler wrapper
|
||||
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
||||
//
|
||||
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
||||
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
||||
return &FallbackHandler{
|
||||
getProxy: getProxy,
|
||||
@@ -93,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
|
||||
}
|
||||
|
||||
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||
//
|
||||
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
||||
if forceModelMappings == nil {
|
||||
forceModelMappings = func() bool { return false }
|
||||
@@ -113,6 +130,20 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
|
||||
// ReverseProxy raises this panic when the client connection is closed prematurely
|
||||
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
|
||||
// with a ResponseWriter that doesn't implement http.CloseNotifier.
|
||||
// This is an expected error condition, not a bug, so we handle it gracefully.
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||
return
|
||||
}
|
||||
panic(rec)
|
||||
}
|
||||
}()
|
||||
|
||||
requestPath := c.Request.URL.Path
|
||||
|
||||
// Read the request body to extract the model name
|
||||
@@ -142,36 +173,57 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
resolveMappedModel := func() (string, []string) {
|
||||
// resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
|
||||
resolveMappedModels := func() ([]string, []string) {
|
||||
if fh.modelMapper == nil {
|
||||
return "", nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
mappedModel = strings.TrimSpace(mappedModel)
|
||||
if mappedModel == "" {
|
||||
return "", nil
|
||||
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
|
||||
if !ok {
|
||||
// Fallback to single model for non-DefaultModelMapper
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
if mappedModel == "" {
|
||||
return nil, nil
|
||||
}
|
||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{mappedModel}, mappedProviders
|
||||
}
|
||||
|
||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||
// already specifies its own thinking suffix.
|
||||
if thinkingSuffix != "" {
|
||||
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
|
||||
if !mappedSuffixResult.HasSuffix {
|
||||
mappedModel += thinkingSuffix
|
||||
// Use MapModelWithFallbacks for DefaultModelMapper
|
||||
mappedModels := mapper.MapModelWithFallbacks(modelName)
|
||||
if len(mappedModels) == 0 {
|
||||
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
|
||||
}
|
||||
if len(mappedModels) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Apply thinking suffix if needed
|
||||
for i, model := range mappedModels {
|
||||
if thinkingSuffix != "" {
|
||||
suffixResult := thinking.ParseSuffix(model)
|
||||
if !suffixResult.HasSuffix {
|
||||
mappedModels[i] = model + thinkingSuffix
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return "", nil
|
||||
// Get providers for the first model
|
||||
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
|
||||
providers := util.GetProviderName(firstBaseModel)
|
||||
if len(providers) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return mappedModel, mappedProviders
|
||||
return mappedModels, providers
|
||||
}
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
@@ -179,21 +231,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
usedMapping := false
|
||||
var providers []string
|
||||
|
||||
// Helper to apply model mapping and update state
|
||||
applyMapping := func(mappedModels []string, mappedProviders []string) {
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
|
||||
if len(mappedModels) > 1 {
|
||||
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
|
||||
}
|
||||
resolvedModel = mappedModels[0]
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
|
||||
// Check if model mappings should be forced ahead of local API keys
|
||||
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
||||
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||
applyMapping(mappedModels, mappedProviders)
|
||||
}
|
||||
|
||||
// If no mapping applied, check for local providers
|
||||
@@ -206,15 +264,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||
applyMapping(mappedModels, mappedProviders)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Characterization tests for fallback_handlers.go using testutil recorders
|
||||
// These tests capture existing behavior before refactoring to routing layer
|
||||
|
||||
func TestCharacterization_LocalProvider(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the test model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{
|
||||
{ID: "test-model-local"},
|
||||
})
|
||||
defer reg.UnregisterClient("char-test-local")
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create gin context
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Create fallback handler with proxy recorder
|
||||
// Create a test server to act as the proxy target
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
// Create a reverse proxy that forwards to our test server
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||
wrapped(c)
|
||||
|
||||
// Assert: proxy NOT called
|
||||
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
|
||||
|
||||
// Assert: local handler called once
|
||||
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
|
||||
|
||||
// Assert: request body model unchanged
|
||||
assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged")
|
||||
}
|
||||
|
||||
func TestCharacterization_ModelMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the TARGET model (the mapped-to model)
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{
|
||||
{ID: "gpt-4-local"},
|
||||
})
|
||||
defer reg.UnregisterClient("char-test-mapped")
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create model mapper with a mapping
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-4-turbo", To: "gpt-4-local"},
|
||||
})
|
||||
|
||||
// Create gin context
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Request with original model that gets mapped
|
||||
body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Create fallback handler with mapper
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
}, mapper, func() bool { return false })
|
||||
|
||||
// Execute - use handler that returns model in response for rewriter to work
|
||||
wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel())
|
||||
wrapped(c)
|
||||
|
||||
// Assert: proxy NOT called
|
||||
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping")
|
||||
|
||||
// Assert: local handler called once
|
||||
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
|
||||
|
||||
// Assert: request body model was rewritten to mapped model
|
||||
assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model")
|
||||
assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model")
|
||||
|
||||
// Assert: context has mapped_model key set
|
||||
mappedModel, exists := handlerRecorder.GetContextKey("mapped_model")
|
||||
assert.True(t, exists, "context should have mapped_model key")
|
||||
assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model")
|
||||
|
||||
// Assert: response body model rewritten back to original
|
||||
// The response writer should rewrite model names in the response
|
||||
responseBody := w.Body.String()
|
||||
assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name")
|
||||
}
|
||||
|
||||
func TestCharacterization_AmpCreditsProxy(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Setup recorders - NO local provider registered, NO mapping configured
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create gin context with CloseNotifier support (required for ReverseProxy)
|
||||
w := testutil.NewCloseNotifierRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Request with a model that has no local provider and no mapping
|
||||
body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Create fallback handler
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||
wrapped(c)
|
||||
|
||||
// Assert: proxy called once
|
||||
assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping")
|
||||
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
|
||||
|
||||
// Assert: local handler NOT called
|
||||
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy")
|
||||
|
||||
// Assert: body forwarded to proxy is original (no rewrite)
|
||||
assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying")
|
||||
}
|
||||
|
||||
func TestCharacterization_BodyRestore(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the test model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{
|
||||
{ID: "test-model-body"},
|
||||
})
|
||||
defer reg.UnregisterClient("char-test-body")
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create gin context
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create a complex request body that will be read by the wrapper for model extraction
|
||||
originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Create fallback handler with proxy recorder
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||
wrapped(c)
|
||||
|
||||
// Assert: local handler called (not proxy, since we have a local provider)
|
||||
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
|
||||
|
||||
// Assert: handler receives complete original body
|
||||
// This verifies that the body was properly restored after the wrapper read it for model extraction
|
||||
assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction")
|
||||
}
|
||||
|
||||
// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler
|
||||
// This is a characterization test for the route gating logic in routes.go
|
||||
func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the test model (Gemini format uses path-based model extraction)
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{
|
||||
{ID: "gemini-pro"},
|
||||
})
|
||||
defer reg.UnregisterClient("char-test-gemini")
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create a test server for the proxy
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
// Create fallback handler
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Create the Gemini bridge handler (simulating what routes.go does)
|
||||
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
|
||||
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
|
||||
|
||||
// Create router with the same gating logic as routes.go
|
||||
r := gin.New()
|
||||
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||
if c.Request.Method == "POST" {
|
||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
// Non-POST or no /models/ in path -> proxy upstream
|
||||
proxyRecorder.ServeHTTP(c.Writer, c.Request)
|
||||
})
|
||||
|
||||
// Execute: POST request with /models/ in path
|
||||
body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Assert: local Gemini handler called
|
||||
assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/")
|
||||
|
||||
// Assert: proxy NOT called
|
||||
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path")
|
||||
}
|
||||
|
||||
// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy
|
||||
// This is a characterization test for the route gating logic in routes.go
|
||||
func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create a test server for the proxy
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
// Create fallback handler
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Create the Gemini bridge handler
|
||||
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
|
||||
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
|
||||
|
||||
// Create router with the same gating logic as routes.go
|
||||
r := gin.New()
|
||||
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||
if c.Request.Method == "POST" {
|
||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
proxyRecorder.ServeHTTP(c.Writer, c.Request)
|
||||
})
|
||||
|
||||
// Execute: GET request (even with /models/ in path)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Assert: proxy called
|
||||
assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests")
|
||||
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
|
||||
|
||||
// Assert: local handler NOT called
|
||||
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests")
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
@@ -11,63 +11,138 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||
// Characterization tests for fallback_handlers.go
|
||||
// These tests capture existing behavior before refactoring to routing layer
|
||||
|
||||
func TestFallbackHandler_WrapHandler_LocalProvider_NoMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
||||
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
// Setup: model that has local providers (gemini-2.5-pro is registered)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
body := `{"model": "gemini-2.5-pro", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Handler that should be called (not proxy)
|
||||
handlerCalled := false
|
||||
handler := func(c *gin.Context) {
|
||||
handlerCalled = true
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
// Create fallback handler
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
return nil // no proxy
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-amp-fallback")
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handler)
|
||||
wrapped(c)
|
||||
|
||||
// Assert: handler should be called directly (no mapping needed)
|
||||
assert.True(t, handlerCalled, "handler should be called for local provider")
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestFallbackHandler_WrapHandler_MappingApplied(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the target model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client", "anthropic", []*registry.ModelInfo{
|
||||
{ID: "claude-opus-4-5-thinking"},
|
||||
})
|
||||
|
||||
// Setup: model that needs mapping
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
body := `{"model": "claude-opus-4-5-20251101", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Handler to capture rewritten body
|
||||
var capturedBody []byte
|
||||
handler := func(c *gin.Context) {
|
||||
capturedBody, _ = io.ReadAll(c.Request.Body)
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
// Create fallback handler with mapper
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"},
|
||||
})
|
||||
|
||||
fh := NewFallbackHandlerWithMapper(
|
||||
func() *httputil.ReverseProxy { return nil },
|
||||
mapper,
|
||||
func() bool { return false },
|
||||
)
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handler)
|
||||
wrapped(c)
|
||||
|
||||
// Assert: body should be rewritten
|
||||
assert.Contains(t, string(capturedBody), "claude-opus-4-5-thinking")
|
||||
|
||||
// Assert: context should have mapped model
|
||||
mappedModel, exists := c.Get(MappedModelContextKey)
|
||||
assert.True(t, exists, "MappedModelContextKey should be set")
|
||||
assert.NotEmpty(t, mappedModel)
|
||||
}
|
||||
|
||||
func TestFallbackHandler_WrapHandler_ThinkingSuffixPreserved(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the target model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-2", "anthropic", []*registry.ModelInfo{
|
||||
{ID: "claude-opus-4-5-thinking"},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Model with thinking suffix
|
||||
body := `{"model": "claude-opus-4-5-20251101(xhigh)", "messages": []}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
var capturedBody []byte
|
||||
handler := func(c *gin.Context) {
|
||||
capturedBody, _ = io.ReadAll(c.Request.Body)
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
||||
{From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"},
|
||||
})
|
||||
|
||||
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
||||
fh := NewFallbackHandlerWithMapper(
|
||||
func() *httputil.ReverseProxy { return nil },
|
||||
mapper,
|
||||
func() bool { return false },
|
||||
)
|
||||
|
||||
handler := func(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
wrapped := fh.WrapHandler(handler)
|
||||
wrapped(c)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": req.Model,
|
||||
"seen_model": req.Model,
|
||||
})
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
||||
|
||||
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Model string `json:"model"`
|
||||
SeenModel string `json:"seen_model"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response JSON: %v", err)
|
||||
}
|
||||
|
||||
if resp.Model != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
||||
}
|
||||
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
||||
}
|
||||
// Assert: thinking suffix should be preserved
|
||||
assert.Contains(t, string(capturedBody), "(xhigh)")
|
||||
}
|
||||
|
||||
// TestFallbackHandler_WrapHandler_NoProvider_NoMapping_ProxyEnabled is a
// placeholder for the "no provider, no mapping, proxy enabled" path. It is
// always skipped.
func TestFallbackHandler_WrapHandler_NoProvider_NoMapping_ProxyEnabled(t *testing.T) {
	// httptest.ResponseRecorder doesn't implement http.CloseNotifier, which
	// httputil.ReverseProxy requires, so exercising the proxy path needs a
	// real HTTP server and client.
	t.Skip("requires real HTTP server for proxy testing")
}
|
||||
|
||||
@@ -30,18 +30,98 @@ type DefaultModelMapper struct {
|
||||
mu sync.RWMutex
|
||||
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||
regexps []regexMapping // regex rules evaluated in order
|
||||
|
||||
// oauthAliasForward maps channel -> name (lower) -> []alias for oauth-model-alias lookup.
|
||||
// This allows model-mappings targets to find providers via their aliases.
|
||||
oauthAliasForward map[string]map[string][]string
|
||||
}
|
||||
|
||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
m := &DefaultModelMapper{
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
oauthAliasForward: nil,
|
||||
}
|
||||
m.UpdateMappings(mappings)
|
||||
return m
|
||||
}
|
||||
|
||||
// UpdateOAuthModelAlias updates the oauth-model-alias lookup table.
|
||||
// This is called during initialization and on config hot-reload.
|
||||
func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.OAuthModelAlias) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if len(aliases) == 0 {
|
||||
m.oauthAliasForward = nil
|
||||
return
|
||||
}
|
||||
|
||||
forward := make(map[string]map[string][]string, len(aliases))
|
||||
for rawChannel, entries := range aliases {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
channelMap := make(map[string][]string)
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
nameKey := strings.ToLower(name)
|
||||
channelMap[nameKey] = append(channelMap[nameKey], alias)
|
||||
}
|
||||
if len(channelMap) > 0 {
|
||||
forward[channel] = channelMap
|
||||
}
|
||||
}
|
||||
if len(forward) == 0 {
|
||||
m.oauthAliasForward = nil
|
||||
return
|
||||
}
|
||||
m.oauthAliasForward = forward
|
||||
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
|
||||
}
|
||||
|
||||
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
|
||||
// that have available providers. Useful for fallback when one alias is quota-exceeded.
|
||||
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
|
||||
if m.oauthAliasForward == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetKey := strings.ToLower(strings.TrimSpace(targetModel))
|
||||
if targetKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
// Check all channels for this model name
|
||||
for _, channelMap := range m.oauthAliasForward {
|
||||
aliases := channelMap[targetKey]
|
||||
for _, alias := range aliases {
|
||||
aliasLower := strings.ToLower(alias)
|
||||
if _, exists := seen[aliasLower]; exists {
|
||||
continue
|
||||
}
|
||||
providers := util.GetProviderName(alias)
|
||||
if len(providers) > 0 {
|
||||
result = append(result, alias)
|
||||
seen[aliasLower] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MapModel checks if a mapping exists for the requested model and if the
|
||||
// target model has available local providers. Returns the mapped model name
|
||||
// or empty string if no valid mapping exists.
|
||||
@@ -51,9 +131,20 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
// However, if the mapping target already contains a suffix, the config suffix
|
||||
// takes priority over the user's suffix.
|
||||
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
if requestedModel == "" {
|
||||
models := m.MapModelWithFallbacks(requestedModel)
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return models[0]
|
||||
}
|
||||
|
||||
// MapModelWithFallbacks returns all possible target models for the requested model,
|
||||
// including fallback aliases from oauth-model-alias. The first model is the primary target,
|
||||
// and subsequent models are fallbacks to try if the primary is unavailable (e.g., quota exceeded).
|
||||
func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []string {
|
||||
if requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
@@ -78,34 +169,54 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if target model already has a thinking suffix (config priority)
|
||||
targetResult := thinking.ParseSuffix(targetModel)
|
||||
targetBase := targetResult.ModelName
|
||||
|
||||
// Helper to apply suffix to a model
|
||||
applySuffix := func(model string) string {
|
||||
modelResult := thinking.ParseSuffix(model)
|
||||
if modelResult.HasSuffix {
|
||||
return model
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return model + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// Verify target model has available providers (use base model for lookup)
|
||||
providers := util.GetProviderName(targetResult.ModelName)
|
||||
if len(providers) == 0 {
|
||||
providers := util.GetProviderName(targetBase)
|
||||
|
||||
// If direct provider available, return it as primary
|
||||
if len(providers) > 0 {
|
||||
return []string{applySuffix(targetModel)}
|
||||
}
|
||||
|
||||
// No direct providers - check oauth-model-alias for all aliases that have providers
|
||||
allAliases := m.findAllAliasesWithProviders(targetBase)
|
||||
if len(allAliases) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
|
||||
if targetResult.HasSuffix {
|
||||
// Config's "to" already contains a suffix - use it as-is (config priority)
|
||||
return targetModel
|
||||
// Log resolution
|
||||
if len(allAliases) == 1 {
|
||||
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
|
||||
} else {
|
||||
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1)
|
||||
}
|
||||
|
||||
// Preserve user's thinking suffix on the mapped model
|
||||
// (skip empty suffixes to avoid returning "model()")
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return targetModel + "(" + requestResult.RawSuffix + ")"
|
||||
// Apply suffix to all aliases
|
||||
result := make([]string, len(allAliases))
|
||||
for i, alias := range allAliases {
|
||||
result[i] = applySuffix(alias)
|
||||
}
|
||||
|
||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||
return targetModel
|
||||
return result
|
||||
}
|
||||
|
||||
// UpdateMappings refreshes the mapping configuration from config.
|
||||
@@ -165,6 +276,22 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice.
|
||||
// Safe for concurrent use.
|
||||
func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make([]config.AmpModelMapping, 0, len(m.mappings))
|
||||
for from, to := range m.mappings {
|
||||
result = append(result, config.AmpModelMapping{
|
||||
From: from,
|
||||
To: to,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type regexMapping struct {
|
||||
re *regexp.Regexp
|
||||
to string
|
||||
|
||||
@@ -5,11 +5,12 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
@@ -234,19 +235,20 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
||||
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.getProxy()
|
||||
}, m.modelMapper, m.forceModelMappings)
|
||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||
|
||||
// Route POST model calls through Gemini bridge with FallbackHandler.
|
||||
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
|
||||
// T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper
|
||||
// Create a dedicated routing wrapper for the Gemini bridge
|
||||
geminiBridgeWrapper := m.createModelRoutingWrapper()
|
||||
geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge)
|
||||
|
||||
// Route POST model calls through Gemini bridge with ModelRoutingWrapper.
|
||||
// ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically.
|
||||
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
||||
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||
if c.Request.Method == "POST" {
|
||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||
// FallbackHandler will check provider/mapping and proxy if needed
|
||||
// POST with /models/ path -> use Gemini bridge with unified routing wrapper
|
||||
// ModelRoutingWrapper will check provider/mapping and proxy if needed
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
@@ -256,6 +258,41 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
})
|
||||
}
|
||||
|
||||
// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing.
|
||||
// This is used for testing the new routing implementation (T-021 onwards).
|
||||
func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper {
|
||||
// Create a registry - in production this would be populated with actual providers
|
||||
registry := routing.NewRegistry()
|
||||
|
||||
// Create a minimal config with just AmpCode settings
|
||||
// The Router only needs AmpCode.ModelMappings and OAuthModelAlias
|
||||
cfg := &config.Config{
|
||||
AmpCode: func() config.AmpCode {
|
||||
if m.modelMapper != nil {
|
||||
return config.AmpCode{
|
||||
ModelMappings: m.modelMapper.GetMappingsAsConfig(),
|
||||
}
|
||||
}
|
||||
return config.AmpCode{}
|
||||
}(),
|
||||
}
|
||||
|
||||
// Create router with registry and config
|
||||
router := routing.NewRouter(registry, cfg)
|
||||
|
||||
// Create wrapper with proxy function
|
||||
proxyFunc := func(c *gin.Context) {
|
||||
proxy := m.getProxy()
|
||||
if proxy != nil {
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
} else {
|
||||
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
||||
}
|
||||
}
|
||||
|
||||
return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc)
|
||||
}
|
||||
|
||||
// registerProviderAliases registers /api/provider/{provider}/... routes
|
||||
// These allow Amp CLI to route requests like:
|
||||
//
|
||||
@@ -269,12 +306,9 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||
|
||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
||||
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
|
||||
// Also includes model mapping support for routing unavailable models to alternatives
|
||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.getProxy()
|
||||
}, m.modelMapper, m.forceModelMappings)
|
||||
// Create unified routing wrapper (T-021 onwards)
|
||||
// Replaces FallbackHandler with Router-based unified routing
|
||||
routingWrapper := m.createModelRoutingWrapper()
|
||||
|
||||
// Provider-specific routes under /api/provider/:provider
|
||||
ampProviders := engine.Group("/api/provider")
|
||||
@@ -302,33 +336,36 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
}
|
||||
|
||||
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
||||
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
|
||||
// T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing
|
||||
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
|
||||
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
||||
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
||||
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
||||
provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
|
||||
provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
|
||||
provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
|
||||
|
||||
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
||||
v1Amp := provider.Group("/v1")
|
||||
{
|
||||
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
|
||||
|
||||
// OpenAI-compatible endpoints with fallback
|
||||
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
||||
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
||||
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
||||
// OpenAI-compatible endpoints with ModelRoutingWrapper
|
||||
// T-021, T-022: Migrated to unified routing wrapper
|
||||
v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
|
||||
v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
|
||||
v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
|
||||
|
||||
// Claude/Anthropic-compatible endpoints with fallback
|
||||
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
|
||||
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
|
||||
// Claude/Anthropic-compatible endpoints with ModelRoutingWrapper
|
||||
// T-023: Migrated Claude routes to unified routing wrapper
|
||||
v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages))
|
||||
v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens))
|
||||
}
|
||||
|
||||
// /v1beta routes (Gemini native API)
|
||||
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
|
||||
// T-024: Migrated Gemini v1beta routes to unified routing wrapper
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1005,8 +1005,8 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
}
|
||||
|
||||
// Notify Amp module only when Amp config has changed.
|
||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
|
||||
// Notify Amp module when Amp config or OAuth model aliases have changed.
|
||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) || !reflect.DeepEqual(oldCfg.OAuthModelAlias, cfg.OAuthModelAlias)
|
||||
if ampConfigChanged {
|
||||
if s.ampModule != nil {
|
||||
log.Debugf("triggering amp module config update")
|
||||
|
||||
39
internal/routing/adapter.go
Normal file
39
internal/routing/adapter.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Package routing provides an adapter that integrates the routing layer with the existing codebase.
|
||||
package routing
|
||||
|
||||
import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// Adapter bridges the new routing layer with existing auth manager.
|
||||
type Adapter struct {
|
||||
router *Router
|
||||
exec *Executor
|
||||
}
|
||||
|
||||
// NewAdapter creates a new adapter with the given configuration and auth manager.
|
||||
func NewAdapter(cfg *config.Config, authManager *coreauth.Manager) *Adapter {
|
||||
registry := NewRegistry()
|
||||
|
||||
// TODO: Register OAuth providers from authManager
|
||||
// TODO: Register API key providers from cfg
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
exec := NewExecutor(router)
|
||||
|
||||
return &Adapter{
|
||||
router: router,
|
||||
exec: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Router returns the underlying router.
|
||||
func (a *Adapter) Router() *Router {
|
||||
return a.router
|
||||
}
|
||||
|
||||
// Executor returns the underlying executor.
|
||||
func (a *Adapter) Executor() *Executor {
|
||||
return a.exec
|
||||
}
|
||||
11
internal/routing/ctxkeys/keys.go
Normal file
11
internal/routing/ctxkeys/keys.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package ctxkeys
|
||||
|
||||
// key is the string-based type of the routing context keys declared below.
type key string

// Context keys used by the routing layer.
const (
	// MappedModel is the key for passing the mapped model name.
	MappedModel key = "mapped_model"
	// FallbackModels has the value "fallback_models".
	FallbackModels key = "fallback_models"
	// RouteCandidates has the value "route_candidates".
	RouteCandidates key = "route_candidates"
	// RoutingDecision has the value "routing_decision".
	RoutingDecision key = "routing_decision"
	// MappingApplied has the value "mapping_applied".
	MappingApplied key = "mapping_applied"
)
|
||||
111
internal/routing/executor.go
Normal file
111
internal/routing/executor.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Executor handles request execution with fallback support.
|
||||
type Executor struct {
|
||||
router *Router
|
||||
}
|
||||
|
||||
// NewExecutor creates a new executor with the given router.
|
||||
func NewExecutor(router *Router) *Executor {
|
||||
return &Executor{router: router}
|
||||
}
|
||||
|
||||
// Execute sends the request through the routing decision.
|
||||
func (e *Executor) Execute(ctx context.Context, req executor.Request) (executor.Response, error) {
|
||||
decision := e.router.Resolve(req.Model)
|
||||
|
||||
log.Debugf("routing: %s -> %s (%d candidates)",
|
||||
decision.RequestedModel,
|
||||
decision.ResolvedModel,
|
||||
len(decision.Candidates))
|
||||
|
||||
var lastErr error
|
||||
tried := make(map[string]struct{})
|
||||
|
||||
for i, candidate := range decision.Candidates {
|
||||
key := candidate.Provider.Name() + "/" + candidate.Model
|
||||
if _, ok := tried[key]; ok {
|
||||
continue
|
||||
}
|
||||
tried[key] = struct{}{}
|
||||
|
||||
log.Debugf("routing: trying candidate %d/%d: %s with model %s",
|
||||
i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model)
|
||||
|
||||
req.Model = candidate.Model
|
||||
resp, err := candidate.Provider.Execute(ctx, candidate.Model, req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Debugf("routing: candidate failed: %v", err)
|
||||
|
||||
// Check if it's a fatal error (not retryable)
|
||||
if isFatalError(err) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return executor.Response{}, lastErr
|
||||
}
|
||||
return executor.Response{}, errors.New("no available providers")
|
||||
}
|
||||
|
||||
// ExecuteStream sends a streaming request through the routing decision.
|
||||
func (e *Executor) ExecuteStream(ctx context.Context, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
decision := e.router.Resolve(req.Model)
|
||||
|
||||
log.Debugf("routing stream: %s -> %s (%d candidates)",
|
||||
decision.RequestedModel,
|
||||
decision.ResolvedModel,
|
||||
len(decision.Candidates))
|
||||
|
||||
var lastErr error
|
||||
tried := make(map[string]struct{})
|
||||
|
||||
for i, candidate := range decision.Candidates {
|
||||
key := candidate.Provider.Name() + "/" + candidate.Model
|
||||
if _, ok := tried[key]; ok {
|
||||
continue
|
||||
}
|
||||
tried[key] = struct{}{}
|
||||
|
||||
log.Debugf("routing stream: trying candidate %d/%d: %s with model %s",
|
||||
i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model)
|
||||
|
||||
req.Model = candidate.Model
|
||||
chunks, err := candidate.Provider.ExecuteStream(ctx, candidate.Model, req)
|
||||
if err == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Debugf("routing stream: candidate failed: %v", err)
|
||||
|
||||
if isFatalError(err) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errors.New("no available providers")
|
||||
}
|
||||
|
||||
// isFatalError returns true if the error is not retryable.
//
// A cancellation (context.Canceled, possibly wrapped) is treated as fatal.
// context.DeadlineExceeded is deliberately NOT treated as fatal: a timeout
// may be local to one provider, and a different candidate can still succeed.
// Every other error, and nil, is considered retryable.
func isFatalError(err error) bool {
	// TODO: classify provider-specific error types as they are introduced.
	return errors.Is(err, context.Canceled)
}
|
||||
59
internal/routing/extractor.go
Normal file
59
internal/routing/extractor.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ModelExtractor extracts model names from request data.
type ModelExtractor interface {
	// Extract returns the model name from the request body and gin parameters.
	// The ginParams map contains route parameters like "action" and "path".
	Extract(body []byte, ginParams map[string]string) (string, error)
}
|
||||
|
||||
// DefaultModelExtractor is the standard implementation of ModelExtractor.
// It carries no state.
type DefaultModelExtractor struct{}

// NewModelExtractor creates a new DefaultModelExtractor.
func NewModelExtractor() *DefaultModelExtractor {
	return &DefaultModelExtractor{}
}
|
||||
|
||||
// Extract extracts the model name from the request.
|
||||
// It checks in order:
|
||||
// 1. JSON body "model" field (OpenAI, Claude format)
|
||||
// 2. "action" parameter for Gemini standard format (e.g., "gemini-pro:generateContent")
|
||||
// 3. "path" parameter for AMP CLI Gemini format (e.g., "/publishers/google/models/gemini-3-pro:streamGenerateContent")
|
||||
func (e *DefaultModelExtractor) Extract(body []byte, ginParams map[string]string) (string, error) {
|
||||
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
// For Gemini requests, model is in the URL path
|
||||
// Standard format: /models/{model}:generateContent -> :action parameter
|
||||
if action, ok := ginParams["action"]; ok && action != "" {
|
||||
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
|
||||
parts := strings.Split(action, ":")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
|
||||
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||
if path, ok := ginParams["path"]; ok && path != "" {
|
||||
// Look for /models/{model}:method pattern
|
||||
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
||||
modelPart := path[idx+8:] // Skip "/models/"
|
||||
// Split by colon to get model name
|
||||
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
||||
return modelPart[:colonIdx], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
214
internal/routing/extractor_test.go
Normal file
214
internal/routing/extractor_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestModelExtractor_ExtractFromJSONBody(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "extract from JSON body with model field",
|
||||
body: []byte(`{"model":"gpt-4.1"}`),
|
||||
want: "gpt-4.1",
|
||||
},
|
||||
{
|
||||
name: "extract claude model from JSON body",
|
||||
body: []byte(`{"model":"claude-3-5-sonnet-20241022"}`),
|
||||
want: "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
{
|
||||
name: "extract with additional fields",
|
||||
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`),
|
||||
want: "gpt-4",
|
||||
},
|
||||
{
|
||||
name: "empty body returns empty",
|
||||
body: []byte{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "no model field returns empty",
|
||||
body: []byte(`{"messages":[]}`),
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "model is not string returns empty",
|
||||
body: []byte(`{"model":123}`),
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := extractor.Extract(tt.body, nil)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelExtractor_ExtractFromGeminiActionParam(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
ginParams map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "extract from action parameter - gemini-pro",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"action": "gemini-pro:generateContent"},
|
||||
want: "gemini-pro",
|
||||
},
|
||||
{
|
||||
name: "extract from action parameter - gemini-ultra",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"action": "gemini-ultra:chat"},
|
||||
want: "gemini-ultra",
|
||||
},
|
||||
{
|
||||
name: "empty action returns empty",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"action": ""},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "action without colon returns full value",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"action": "gemini-model"},
|
||||
want: "gemini-model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelExtractor_ExtractFromGeminiV1Beta1Path(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
ginParams map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "extract from v1beta1 path - gemini-3-pro",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro:streamGenerateContent"},
|
||||
want: "gemini-3-pro",
|
||||
},
|
||||
{
|
||||
name: "extract from v1beta1 path with preview",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro-preview:generateContent"},
|
||||
want: "gemini-3-pro-preview",
|
||||
},
|
||||
{
|
||||
name: "path without models segment returns empty",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": "/publishers/google/gemini-3-pro:streamGenerateContent"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty path returns empty",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": ""},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "path with /models/ but no colon returns empty",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro"},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelExtractor_ExtractPriority(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
// JSON body takes priority over gin params
|
||||
t.Run("JSON body takes priority over action param", func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4"}`)
|
||||
params := map[string]string{"action": "gemini-pro:generateContent"}
|
||||
got, err := extractor.Extract(body, params)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", got)
|
||||
})
|
||||
|
||||
// Action param takes priority over path param
|
||||
t.Run("action param takes priority over path param", func(t *testing.T) {
|
||||
body := []byte(`{}`)
|
||||
params := map[string]string{
|
||||
"action": "gemini-action:generate",
|
||||
"path": "/publishers/google/models/gemini-path:streamGenerateContent",
|
||||
}
|
||||
got, err := extractor.Extract(body, params)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "gemini-action", got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelExtractor_NoModelFound(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
ginParams map[string]string
|
||||
}{
|
||||
{
|
||||
name: "empty body and no params",
|
||||
body: []byte{},
|
||||
ginParams: nil,
|
||||
},
|
||||
{
|
||||
name: "body without model and no params",
|
||||
body: []byte(`{"messages":[]}`),
|
||||
ginParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "irrelevant params only",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"other": "value"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
80
internal/routing/provider.go
Normal file
80
internal/routing/provider.go
Normal file
@@ -0,0 +1,80 @@
|
||||
// Package routing provides unified model routing for all provider types.
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// ProviderType indicates the type of provider.
type ProviderType string

// Known provider types.
const (
	// ProviderTypeOAuth marks providers backed by OAuth credentials.
	ProviderTypeOAuth ProviderType = "oauth"
	// ProviderTypeAPIKey marks providers backed by configured API keys.
	ProviderTypeAPIKey ProviderType = "api_key"
	// ProviderTypeVertex marks Vertex providers.
	ProviderTypeVertex ProviderType = "vertex"
)
|
||||
|
||||
// Provider is the unified interface for all provider types (OAuth, API key, etc.).
type Provider interface {
	// Name returns the unique provider identifier.
	Name() string

	// Type returns the provider type.
	Type() ProviderType

	// SupportsModel returns true if this provider can handle the given model.
	SupportsModel(model string) bool

	// Available returns true if the provider is available for the model (not quota exceeded).
	Available(model string) bool

	// Priority returns the priority for this provider (lower = tried first).
	Priority() int

	// Execute sends the request to the provider and returns its response.
	Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error)

	// ExecuteStream sends a streaming request to the provider and returns
	// a receive-only channel of stream chunks.
	ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error)
}
|
||||
|
||||
// ProviderCandidate represents a provider + model combination to try.
type ProviderCandidate struct {
	// Provider is the provider to send the request to.
	Provider Provider
	// Model is the actual model name to use (may be different from requested due to aliasing).
	Model string
}
|
||||
|
||||
// Registry manages all available providers.
// It holds them in registration order and performs no locking of its own.
type Registry struct {
	providers []Provider
}
|
||||
|
||||
// NewRegistry creates a new provider registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
providers: make([]Provider, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a provider to the registry.
// Providers are appended, so lookups see them in registration order.
func (r *Registry) Register(p Provider) {
	r.providers = append(r.providers, p)
}
|
||||
|
||||
// FindProviders returns all providers that support the given model and are available.
|
||||
func (r *Registry) FindProviders(model string) []Provider {
|
||||
var result []Provider
|
||||
for _, p := range r.providers {
|
||||
if p.SupportsModel(model) && p.Available(model) {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// All returns all registered providers.
|
||||
func (r *Registry) All() []Provider {
|
||||
return r.providers
|
||||
}
|
||||
156
internal/routing/providers/apikey.go
Normal file
156
internal/routing/providers/apikey.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// APIKeyProvider wraps API key configs as routing.Provider.
type APIKeyProvider struct {
	name     string        // identifier returned by Name
	provider string        // claude, gemini, codex, vertex
	keys     []APIKeyEntry // registered key entries; guarded by mu
	mu       sync.RWMutex  // protects keys
	client   HTTPClient    // HTTP client supplied at construction
}
|
||||
|
||||
// APIKeyEntry represents a single API key configuration.
type APIKeyEntry struct {
	// APIKey is the configured key value.
	APIKey string
	// BaseURL is the base URL configured alongside the key.
	BaseURL string
	// Models lists the models (name + alias) this key serves.
	Models []config.ClaudeModel // Using ClaudeModel as generic model alias
}
|
||||
|
||||
// HTTPClient is the minimal interface for making HTTP requests.
// Its single method has the same signature as (*http.Client).Do.
type HTTPClient interface {
	Do(req *http.Request) (*http.Response, error)
}
|
||||
|
||||
// NewAPIKeyProvider creates a new API key provider.
|
||||
func NewAPIKeyProvider(name, provider string, client HTTPClient) *APIKeyProvider {
|
||||
return &APIKeyProvider{
|
||||
name: name,
|
||||
provider: provider,
|
||||
keys: make([]APIKeyEntry, 0),
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name given to NewAPIKeyProvider.
func (p *APIKeyProvider) Name() string {
	return p.name
}
|
||||
|
||||
// Type returns routing.ProviderTypeAPIKey.
func (p *APIKeyProvider) Type() routing.ProviderType {
	return routing.ProviderTypeAPIKey
}
|
||||
|
||||
// SupportsModel checks if the model is supported by this provider.
|
||||
func (p *APIKeyProvider) SupportsModel(model string) bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, key := range p.keys {
|
||||
for _, m := range key.Models {
|
||||
if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Available reports whether this provider can serve the model.
// It delegates to SupportsModel; no quota or disabled-state check is made here.
func (p *APIKeyProvider) Available(model string) bool {
	return p.SupportsModel(model)
}
|
||||
|
||||
// Priority returns the priority (API key is lower priority than OAuth).
// The fixed value 20 is larger than the OAuth provider's 10.
func (p *APIKeyProvider) Priority() int {
	return 20
}
|
||||
|
||||
// Execute sends the request using the API key.
|
||||
func (p *APIKeyProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
key := p.selectKey(model)
|
||||
if key == nil {
|
||||
return executor.Response{}, ErrNoMatchingAPIKey
|
||||
}
|
||||
|
||||
// Resolve the actual model name from alias
|
||||
actualModel := p.resolveModel(key, model)
|
||||
|
||||
// Execute via HTTP client
|
||||
return p.executeHTTP(ctx, key, actualModel, req)
|
||||
}
|
||||
|
||||
// ExecuteStream sends a streaming request.
|
||||
func (p *APIKeyProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (
|
||||
<-chan executor.StreamChunk, error) {
|
||||
key := p.selectKey(model)
|
||||
if key == nil {
|
||||
return nil, ErrNoMatchingAPIKey
|
||||
}
|
||||
|
||||
actualModel := p.resolveModel(key, model)
|
||||
return p.executeHTTPStream(ctx, key, actualModel, req)
|
||||
}
|
||||
|
||||
// AddKey adds an API key entry.
// The entry is appended under the write lock, so it is safe to call
// concurrently with the read paths that take the same mutex.
func (p *APIKeyProvider) AddKey(entry APIKeyEntry) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.keys = append(p.keys, entry)
}
|
||||
|
||||
// selectKey selects a key that supports the model.
|
||||
func (p *APIKeyProvider) selectKey(model string) *APIKeyEntry {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, key := range p.keys {
|
||||
for _, m := range key.Models {
|
||||
if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) {
|
||||
return &key
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveModel resolves alias to actual model name.
|
||||
func (p *APIKeyProvider) resolveModel(key *APIKeyEntry, requested string) string {
|
||||
for _, m := range key.Models {
|
||||
if strings.EqualFold(m.Alias, requested) {
|
||||
return m.Name
|
||||
}
|
||||
}
|
||||
return requested
|
||||
}
|
||||
|
||||
// executeHTTP is meant to make the HTTP request for the given key and model.
// It is currently a stub: it ignores all arguments and always returns a
// "not yet implemented" error with a zero Response.
func (p *APIKeyProvider) executeHTTP(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) (executor.Response, error) {
	// TODO: implement actual HTTP execution
	// This is a placeholder - actual implementation would build HTTP request
	return executor.Response{}, errors.New("not yet implemented")
}
|
||||
|
||||
// executeHTTPStream is meant to make a streaming HTTP request.
// It is currently a stub: it ignores all arguments and always returns a nil
// channel together with a "not yet implemented" error.
func (p *APIKeyProvider) executeHTTPStream(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) (
	<-chan executor.StreamChunk, error) {
	// TODO: implement actual HTTP streaming
	return nil, errors.New("not yet implemented")
}
|
||||
|
||||
// Errors
var (
	// ErrNoMatchingAPIKey is returned by Execute and ExecuteStream when no
	// registered key lists the requested model.
	ErrNoMatchingAPIKey = errors.New("no API key supports the requested model")
)
|
||||
132
internal/routing/providers/oauth.go
Normal file
132
internal/routing/providers/oauth.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// OAuthProvider wraps OAuth-based auths as routing.Provider.
type OAuthProvider struct {
	name     string                    // identifier returned by Name
	auths    []*coreauth.Auth          // registered auths; guarded by mu
	mu       sync.RWMutex              // protects auths
	executor coreauth.ProviderExecutor // performs the actual requests
}
|
||||
|
||||
// NewOAuthProvider creates a new OAuth provider.
|
||||
func NewOAuthProvider(name string, exec coreauth.ProviderExecutor) *OAuthProvider {
|
||||
return &OAuthProvider{
|
||||
name: name,
|
||||
auths: make([]*coreauth.Auth, 0),
|
||||
executor: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name given to NewOAuthProvider.
func (p *OAuthProvider) Name() string {
	return p.name
}
|
||||
|
||||
// Type returns routing.ProviderTypeOAuth.
func (p *OAuthProvider) Type() routing.ProviderType {
	return routing.ProviderTypeOAuth
}
|
||||
|
||||
// SupportsModel currently reports true for every model; the argument is not
// inspected.
// NOTE(review): the read lock is taken although no guarded state is read —
// presumably kept for a future per-auth check; confirm before removing.
func (p *OAuthProvider) SupportsModel(model string) bool {
	p.mu.RLock()
	defer p.mu.RUnlock()

	// OAuth providers typically support models via oauth-model-alias
	// The actual model support is determined at execution time
	return true
}
|
||||
|
||||
// Available checks if there's an available auth for the model.
|
||||
func (p *OAuthProvider) Available(model string) bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, auth := range p.auths {
|
||||
if p.isAuthAvailable(auth, model) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Priority returns the priority (OAuth is preferred over API key).
// The fixed value 10 is smaller than the API key provider's 20.
func (p *OAuthProvider) Priority() int {
	return 10
}
|
||||
|
||||
// Execute sends the request using an available OAuth auth.
|
||||
func (p *OAuthProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
auth := p.selectAuth(model)
|
||||
if auth == nil {
|
||||
return executor.Response{}, ErrNoAvailableAuth
|
||||
}
|
||||
|
||||
return p.executor.Execute(ctx, auth, req, executor.Options{})
|
||||
}
|
||||
|
||||
// ExecuteStream sends a streaming request.
|
||||
func (p *OAuthProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
auth := p.selectAuth(model)
|
||||
if auth == nil {
|
||||
return nil, ErrNoAvailableAuth
|
||||
}
|
||||
|
||||
return p.executor.ExecuteStream(ctx, auth, req, executor.Options{})
|
||||
}
|
||||
|
||||
// AddAuth adds an auth to this provider.
// The pointer is appended as-is under the write lock; nil is not rejected.
func (p *OAuthProvider) AddAuth(auth *coreauth.Auth) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.auths = append(p.auths, auth)
}
|
||||
|
||||
// RemoveAuth removes an auth from this provider.
|
||||
func (p *OAuthProvider) RemoveAuth(authID string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
filtered := make([]*coreauth.Auth, 0, len(p.auths))
|
||||
for _, auth := range p.auths {
|
||||
if auth.ID != authID {
|
||||
filtered = append(filtered, auth)
|
||||
}
|
||||
}
|
||||
p.auths = filtered
|
||||
}
|
||||
|
||||
// isAuthAvailable checks if an auth is available for the model.
// At present it only verifies that the auth is non-nil; the model argument
// is not used.
func (p *OAuthProvider) isAuthAvailable(auth *coreauth.Auth, model string) bool {
	// TODO: integrate with model_registry for quota checking
	// For now, just check if auth exists
	return auth != nil
}
|
||||
|
||||
// selectAuth selects an available auth for the model.
|
||||
func (p *OAuthProvider) selectAuth(model string) *coreauth.Auth {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, auth := range p.auths {
|
||||
if p.isAuthAvailable(auth, model) {
|
||||
return auth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Errors
var (
	// ErrNoAvailableAuth is returned by Execute and ExecuteStream when no
	// registered auth is available for the requested model.
	ErrNoAvailableAuth = errors.New("no available OAuth auth for model")
)
|
||||
159
internal/routing/rewriter.go
Normal file
159
internal/routing/rewriter.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ModelRewriter handles model name rewriting in requests and responses.
type ModelRewriter interface {
	// RewriteRequestBody rewrites the model field in a JSON request body.
	// Returns the modified body or the original if no rewrite was needed.
	RewriteRequestBody(body []byte, newModel string) ([]byte, error)

	// WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response.
	// Returns the wrapped writer and a cleanup function that must be called after the response is complete.
	WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
}
|
||||
|
||||
// DefaultModelRewriter is the standard implementation of ModelRewriter.
// It carries no state.
type DefaultModelRewriter struct{}

// NewModelRewriter creates a new DefaultModelRewriter.
func NewModelRewriter() *DefaultModelRewriter {
	return &DefaultModelRewriter{}
}
|
||||
|
||||
// RewriteRequestBody replaces the model name in a JSON request body.
|
||||
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
|
||||
if !gjson.GetBytes(body, "model").Exists() {
|
||||
return body, nil
|
||||
}
|
||||
result, err := sjson.SetBytes(body, "model", newModel)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// WrapResponseWriter wraps a response writer to rewrite model names.
|
||||
// The cleanup function must be called after the handler completes to flush any buffered data.
|
||||
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
|
||||
rw := &responseRewriter{
|
||||
ResponseWriter: w,
|
||||
body: &bytes.Buffer{},
|
||||
requestedModel: requestedModel,
|
||||
resolvedModel: resolvedModel,
|
||||
}
|
||||
return rw, func() { rw.flush() }
|
||||
}
|
||||
|
||||
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
type responseRewriter struct {
	http.ResponseWriter
	body           *bytes.Buffer // buffered non-streaming body, emitted by flush
	requestedModel string        // model name written into the response
	resolvedModel  string        // model name the request was routed to
	isStreaming    bool          // set on first Write from the Content-Type header
	wroteHeader    bool          // true once the status line was forwarded
	flushed        bool          // true once flush ran; makes flush idempotent
}
|
||||
|
||||
// Write intercepts response writes and buffers them for model name replacement.
|
||||
func (rw *responseRewriter) Write(data []byte) (int, error) {
|
||||
// Ensure header is written
|
||||
if !rw.wroteHeader {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// Detect streaming on first write
|
||||
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||
contentType := rw.Header().Get("Content-Type")
|
||||
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||
strings.Contains(contentType, "stream")
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||
func (rw *responseRewriter) WriteHeader(code int) {
|
||||
if !rw.wroteHeader {
|
||||
rw.wroteHeader = true
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
}
|
||||
|
||||
// flush writes the buffered response with model names rewritten.
|
||||
func (rw *responseRewriter) flush() {
|
||||
if rw.flushed {
|
||||
return
|
||||
}
|
||||
rw.flushed = true
|
||||
|
||||
if rw.isStreaming {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
if rw.body.Len() > 0 {
|
||||
data := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||
if _, err := rw.ResponseWriter.Write(data); err != nil {
|
||||
log.Warnf("response rewriter: failed to write rewritten response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// modelFieldPaths lists the JSON paths that are checked for a model name
// and overwritten when present in a response.
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model.
|
||||
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||
return data
|
||||
}
|
||||
|
||||
for _, path := range modelFieldPaths {
|
||||
if gjson.GetBytes(data, path).Exists() {
|
||||
data, _ = sjson.SetBytes(data, path, rw.requestedModel)
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// rewriteStreamChunk rewrites model names in SSE stream chunks.
|
||||
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||
return chunk
|
||||
}
|
||||
|
||||
// SSE format: "data: {json}\n\n"
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
for i, line := range lines {
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||
// Rewrite JSON in the data line
|
||||
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||
lines[i] = append([]byte("data: "), rewritten...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bytes.Join(lines, []byte("\n"))
|
||||
}
|
||||
342
internal/routing/rewriter_test.go
Normal file
342
internal/routing/rewriter_test.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestModelRewriter_RewriteRequestBody(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
newModel string
|
||||
wantModel string
|
||||
wantChange bool
|
||||
}{
|
||||
{
|
||||
name: "rewrites model field in JSON body",
|
||||
body: []byte(`{"model":"gpt-4.1","messages":[]}`),
|
||||
newModel: "claude-local",
|
||||
wantModel: "claude-local",
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
name: "rewrites with empty body returns empty",
|
||||
body: []byte{},
|
||||
newModel: "gpt-4",
|
||||
wantModel: "",
|
||||
wantChange: false,
|
||||
},
|
||||
{
|
||||
name: "handles missing model field gracefully",
|
||||
body: []byte(`{"messages":[{"role":"user"}]}`),
|
||||
newModel: "gpt-4",
|
||||
wantModel: "",
|
||||
wantChange: false,
|
||||
},
|
||||
{
|
||||
name: "preserves other fields when rewriting",
|
||||
body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`),
|
||||
newModel: "new-model",
|
||||
wantModel: "new-model",
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
name: "handles nested JSON structure",
|
||||
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`),
|
||||
newModel: "claude-3-opus",
|
||||
wantModel: "claude-3-opus",
|
||||
wantChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.wantChange {
|
||||
assert.NotEqual(t, string(tt.body), string(result), "body should have been modified")
|
||||
}
|
||||
|
||||
if tt.wantModel != "" {
|
||||
// Parse result and check model field
|
||||
model, _ := NewModelExtractor().Extract(result, nil)
|
||||
assert.Equal(t, tt.wantModel, model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRewriter_WrapResponseWriter(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
t.Run("response writer wraps without error", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
require.NotNil(t, wrapped)
|
||||
require.NotNil(t, cleanup)
|
||||
defer cleanup()
|
||||
})
|
||||
|
||||
t.Run("rewrites model in non-streaming response", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
// Write a response with the resolved model
|
||||
response := []byte(`{"model":"claude-local","content":"hello"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
_, err := wrapped.Write(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanup triggers the rewrite
|
||||
cleanup()
|
||||
|
||||
// Check the response was rewritten to the requested model
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||
})
|
||||
|
||||
t.Run("no-op when requested equals resolved", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4")
|
||||
|
||||
response := []byte(`{"model":"gpt-4","content":"hello"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
_, err := wrapped.Write(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||
})
|
||||
|
||||
t.Run("rewrites modelVersion field", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"modelVersion":"claude-local","content":"hello"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
_, err := wrapped.Write(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"modelVersion":"gpt-4"`)
|
||||
})
|
||||
|
||||
t.Run("handles streaming responses", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
// Set streaming content type
|
||||
wrapped.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Write SSE chunks with resolved model
|
||||
chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n")
|
||||
_, err := wrapped.Write(chunk1)
|
||||
require.NoError(t, err)
|
||||
|
||||
chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n")
|
||||
_, err = wrapped.Write(chunk2)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup()
|
||||
|
||||
// For streaming, data is written immediately with rewrites
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||
})
|
||||
|
||||
t.Run("empty body handled gracefully", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
// Don't write anything
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Empty(t, body)
|
||||
})
|
||||
|
||||
t.Run("preserves other JSON fields", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
_, err := wrapped.Write(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"temperature":0.7`)
|
||||
assert.Contains(t, string(body), `"prompt_tokens":10`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponseRewriter_ImplementsInterfaces(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
defer cleanup()
|
||||
|
||||
// Should implement http.ResponseWriter
|
||||
assert.Implements(t, (*http.ResponseWriter)(nil), wrapped)
|
||||
|
||||
// Should preserve header access
|
||||
wrapped.Header().Set("X-Custom", "value")
|
||||
assert.Equal(t, "value", recorder.Header().Get("X-Custom"))
|
||||
|
||||
// Should write status
|
||||
wrapped.WriteHeader(http.StatusCreated)
|
||||
assert.Equal(t, http.StatusCreated, recorder.Code)
|
||||
}
|
||||
|
||||
func TestResponseRewriter_Flush(t *testing.T) {
|
||||
t.Run("flush writes buffered content", func(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"model":"claude-local","content":"test"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
wrapped.Write(response)
|
||||
|
||||
// Before cleanup, response should be empty (buffered)
|
||||
assert.Empty(t, recorder.Body.Bytes())
|
||||
|
||||
// After cleanup, response should be written
|
||||
cleanup()
|
||||
assert.NotEmpty(t, recorder.Body.Bytes())
|
||||
})
|
||||
|
||||
t.Run("multiple flush calls are safe", func(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"model":"claude-local"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
wrapped.Write(response)
|
||||
|
||||
// First cleanup
|
||||
cleanup()
|
||||
firstBody := recorder.Body.Bytes()
|
||||
|
||||
// Second cleanup should not write again
|
||||
cleanup()
|
||||
secondBody := recorder.Body.Bytes()
|
||||
|
||||
assert.Equal(t, firstBody, secondBody)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponseRewriter_StreamingWithDataLines(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
wrapped.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// SSE format with multiple data lines
|
||||
chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n")
|
||||
wrapped.Write(chunk)
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
// Both data lines should have model rewritten
|
||||
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||
}
|
||||
|
||||
func TestModelRewriter_RoundTrip(t *testing.T) {
|
||||
// Simulate a full request -> response cycle with model rewriting
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
// Step 1: Rewrite request body
|
||||
originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`)
|
||||
rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify request was rewritten
|
||||
extractor := NewModelExtractor()
|
||||
requestModel, _ := extractor.Extract(rewrittenRequest, nil)
|
||||
assert.Equal(t, "claude-local", requestModel)
|
||||
|
||||
// Step 2: Simulate response with resolved model
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
wrapped.Write(response)
|
||||
cleanup()
|
||||
|
||||
// Verify response was rewritten back
|
||||
body, _ := io.ReadAll(recorder.Result().Body)
|
||||
responseModel, _ := extractor.Extract(body, nil)
|
||||
assert.Equal(t, "gpt-4", responseModel)
|
||||
}
|
||||
|
||||
func TestModelRewriter_NonJSONBody(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
// Binary/non-JSON body should be returned unchanged
|
||||
body := []byte{0x00, 0x01, 0x02, 0x03}
|
||||
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestModelRewriter_InvalidJSON(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
// Invalid JSON without model field should be returned unchanged
|
||||
body := []byte(`not valid json`)
|
||||
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestResponseRewriter_StatusCodePreserved(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
wrapped.WriteHeader(http.StatusAccepted)
|
||||
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
||||
cleanup()
|
||||
|
||||
assert.Equal(t, http.StatusAccepted, recorder.Code)
|
||||
}
|
||||
|
||||
func TestResponseRewriter_HeaderFlushed(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
wrapped.Header().Set("X-Request-ID", "abc123")
|
||||
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
||||
cleanup()
|
||||
|
||||
result := recorder.Result()
|
||||
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "abc123", result.Header.Get("X-Request-ID"))
|
||||
}
|
||||
267
internal/routing/router.go
Normal file
267
internal/routing/router.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
)
|
||||
|
||||
// Router resolves models to provider candidates.
|
||||
type Router struct {
|
||||
registry *Registry
|
||||
modelMappings map[string]string // normalized from -> to
|
||||
oauthAliases map[string][]string // normalized model -> []alias
|
||||
}
|
||||
|
||||
// NewRouter creates a new router with the given configuration.
|
||||
func NewRouter(registry *Registry, cfg *config.Config) *Router {
|
||||
r := &Router{
|
||||
registry: registry,
|
||||
modelMappings: make(map[string]string),
|
||||
oauthAliases: make(map[string][]string),
|
||||
}
|
||||
|
||||
if cfg != nil {
|
||||
r.loadModelMappings(cfg.AmpCode.ModelMappings)
|
||||
r.loadOAuthAliases(cfg.OAuthModelAlias)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// LegacyRoutingDecision contains the resolved routing information.
|
||||
// Deprecated: Will be replaced by RoutingDecision from types.go in T-013.
|
||||
type LegacyRoutingDecision struct {
|
||||
RequestedModel string // Original model from request
|
||||
ResolvedModel string // After model-mappings
|
||||
Candidates []ProviderCandidate // Ordered list of providers to try
|
||||
}
|
||||
|
||||
// Resolve determines the routing decision for the requested model.
|
||||
// Deprecated: Will be updated to use RoutingRequest and return *RoutingDecision in T-013.
|
||||
func (r *Router) Resolve(requestedModel string) *LegacyRoutingDecision {
|
||||
// 1. Extract thinking suffix
|
||||
suffixResult := thinking.ParseSuffix(requestedModel)
|
||||
baseModel := suffixResult.ModelName
|
||||
|
||||
// 2. Apply model-mappings
|
||||
targetModel := r.applyMappings(baseModel)
|
||||
|
||||
// 3. Find primary providers
|
||||
candidates := r.findCandidates(targetModel, suffixResult)
|
||||
|
||||
// 4. Add fallback aliases
|
||||
for _, alias := range r.oauthAliases[strings.ToLower(targetModel)] {
|
||||
candidates = append(candidates, r.findCandidates(alias, suffixResult)...)
|
||||
}
|
||||
|
||||
// 5. Sort by priority
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
||||
})
|
||||
|
||||
return &LegacyRoutingDecision{
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: targetModel,
|
||||
Candidates: candidates,
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveV2 determines the routing decision for a routing request.
|
||||
// It uses the new RoutingRequest and RoutingDecision types.
|
||||
func (r *Router) ResolveV2(req RoutingRequest) *RoutingDecision {
|
||||
// 1. Extract thinking suffix
|
||||
suffixResult := thinking.ParseSuffix(req.RequestedModel)
|
||||
baseModel := suffixResult.ModelName
|
||||
thinkingSuffix := ""
|
||||
if suffixResult.HasSuffix {
|
||||
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
// 2. Check for local providers
|
||||
localCandidates := r.findLocalCandidates(baseModel, suffixResult)
|
||||
|
||||
// 3. Apply model-mappings if needed
|
||||
mappedModel := r.applyMappings(baseModel)
|
||||
mappingCandidates := r.findLocalCandidates(mappedModel, suffixResult)
|
||||
|
||||
// 4. Determine route type based on preferences and availability
|
||||
var decision *RoutingDecision
|
||||
|
||||
if req.ForceModelMapping && mappedModel != baseModel && len(mappingCandidates) > 0 {
|
||||
// FORCE MODE: Use mapping even if local provider exists
|
||||
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
|
||||
} else if req.PreferLocalProvider && len(localCandidates) > 0 {
|
||||
// DEFAULT MODE with local preference: Use local provider first
|
||||
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
|
||||
} else if len(localCandidates) > 0 {
|
||||
// DEFAULT MODE: Local provider available
|
||||
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
|
||||
} else if mappedModel != baseModel && len(mappingCandidates) > 0 {
|
||||
// DEFAULT MODE: No local provider, but mapping available
|
||||
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
|
||||
} else {
|
||||
// No local provider, no mapping - use amp credits proxy
|
||||
decision = &RoutingDecision{
|
||||
RouteType: RouteTypeAmpCredits,
|
||||
ResolvedModel: req.RequestedModel,
|
||||
ShouldProxy: true,
|
||||
}
|
||||
}
|
||||
|
||||
return decision
|
||||
}
|
||||
|
||||
// findLocalCandidates finds local provider candidates for a model.
|
||||
func (r *Router) findLocalCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
|
||||
var candidates []ProviderCandidate
|
||||
|
||||
for _, p := range r.registry.All() {
|
||||
if !p.SupportsModel(model) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply thinking suffix if needed
|
||||
actualModel := model
|
||||
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
|
||||
actualModel = model + "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
if p.Available(actualModel) {
|
||||
candidates = append(candidates, ProviderCandidate{
|
||||
Provider: p,
|
||||
Model: actualModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by priority
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
||||
})
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// buildLocalProviderDecision creates a decision for local provider routing.
|
||||
func (r *Router) buildLocalProviderDecision(requestedModel string, candidates []ProviderCandidate, thinkingSuffix string) *RoutingDecision {
|
||||
resolvedModel := requestedModel
|
||||
if thinkingSuffix != "" {
|
||||
// Ensure thinking suffix is preserved
|
||||
sr := thinking.ParseSuffix(requestedModel)
|
||||
if !sr.HasSuffix {
|
||||
resolvedModel = requestedModel + thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
var fallbackModels []string
|
||||
if len(candidates) > 1 {
|
||||
for _, c := range candidates[1:] {
|
||||
fallbackModels = append(fallbackModels, c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
return &RoutingDecision{
|
||||
RouteType: RouteTypeLocalProvider,
|
||||
ResolvedModel: resolvedModel,
|
||||
ProviderName: candidates[0].Provider.Name(),
|
||||
FallbackModels: fallbackModels,
|
||||
ShouldProxy: false,
|
||||
}
|
||||
}
|
||||
|
||||
// buildMappingDecision creates a decision for model mapping routing.
|
||||
func (r *Router) buildMappingDecision(requestedModel, mappedModel string, candidates []ProviderCandidate, thinkingSuffix string, fallbackCandidates []ProviderCandidate) *RoutingDecision {
|
||||
// Apply thinking suffix to resolved model if needed
|
||||
resolvedModel := mappedModel
|
||||
if thinkingSuffix != "" {
|
||||
sr := thinking.ParseSuffix(mappedModel)
|
||||
if !sr.HasSuffix {
|
||||
resolvedModel = mappedModel + thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
var fallbackModels []string
|
||||
for _, c := range fallbackCandidates {
|
||||
fallbackModels = append(fallbackModels, c.Model)
|
||||
}
|
||||
|
||||
// Also add oauth aliases as fallbacks
|
||||
baseMapped := thinking.ParseSuffix(mappedModel).ModelName
|
||||
for _, alias := range r.oauthAliases[strings.ToLower(baseMapped)] {
|
||||
// Check if this alias has providers
|
||||
aliasCandidates := r.findLocalCandidates(alias, thinking.SuffixResult{ModelName: alias})
|
||||
for _, c := range aliasCandidates {
|
||||
fallbackModels = append(fallbackModels, c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
return &RoutingDecision{
|
||||
RouteType: RouteTypeModelMapping,
|
||||
ResolvedModel: resolvedModel,
|
||||
ProviderName: candidates[0].Provider.Name(),
|
||||
FallbackModels: fallbackModels,
|
||||
ShouldProxy: false,
|
||||
}
|
||||
}
|
||||
|
||||
// applyMappings applies model-mappings configuration.
|
||||
func (r *Router) applyMappings(model string) string {
|
||||
key := strings.ToLower(strings.TrimSpace(model))
|
||||
if mapped, ok := r.modelMappings[key]; ok {
|
||||
return mapped
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// findCandidates finds all provider candidates for a model.
|
||||
func (r *Router) findCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
|
||||
var candidates []ProviderCandidate
|
||||
|
||||
for _, p := range r.registry.All() {
|
||||
if !p.SupportsModel(model) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply thinking suffix if needed
|
||||
actualModel := model
|
||||
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
|
||||
actualModel = model + "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
if p.Available(actualModel) {
|
||||
candidates = append(candidates, ProviderCandidate{
|
||||
Provider: p,
|
||||
Model: actualModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// loadModelMappings loads model-mappings from config.
|
||||
func (r *Router) loadModelMappings(mappings []config.AmpModelMapping) {
|
||||
for _, m := range mappings {
|
||||
from := strings.ToLower(strings.TrimSpace(m.From))
|
||||
to := strings.TrimSpace(m.To)
|
||||
if from != "" && to != "" {
|
||||
r.modelMappings[from] = to
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadOAuthAliases loads oauth-model-alias from config.
|
||||
func (r *Router) loadOAuthAliases(aliases map[string][]config.OAuthModelAlias) {
|
||||
for _, entries := range aliases {
|
||||
for _, entry := range entries {
|
||||
name := strings.ToLower(strings.TrimSpace(entry.Name))
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name != "" && alias != "" && name != alias {
|
||||
r.oauthAliases[name] = append(r.oauthAliases[name], alias)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
115
internal/routing/router_test.go
Normal file
115
internal/routing/router_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockProvider is a test double for Provider.
|
||||
type mockProvider struct {
|
||||
name string
|
||||
providerType ProviderType
|
||||
supportsModels map[string]bool
|
||||
available bool
|
||||
priority int
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string { return m.name }
|
||||
func (m *mockProvider) Type() ProviderType { return m.providerType }
|
||||
func (m *mockProvider) SupportsModel(model string) bool { return m.supportsModels[model] }
|
||||
func (m *mockProvider) Available(model string) bool { return m.available }
|
||||
func (m *mockProvider) Priority() int { return m.priority }
|
||||
func (m *mockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
return executor.Response{}, nil
|
||||
}
|
||||
func (m *mockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestRouter_Resolve_ModelMappings(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
|
||||
// Add a provider
|
||||
p := &mockProvider{
|
||||
name: "test-provider",
|
||||
providerType: ProviderTypeOAuth,
|
||||
supportsModels: map[string]bool{"target-model": true},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(p)
|
||||
|
||||
// Create router with model mapping
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "user-model", To: "target-model"},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Resolve
|
||||
decision := router.Resolve("user-model")
|
||||
|
||||
assert.Equal(t, "user-model", decision.RequestedModel)
|
||||
assert.Equal(t, "target-model", decision.ResolvedModel)
|
||||
assert.Len(t, decision.Candidates, 1)
|
||||
assert.Equal(t, "target-model", decision.Candidates[0].Model)
|
||||
}
|
||||
|
||||
func TestRouter_Resolve_OAuthAliases(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
|
||||
// Add providers
|
||||
p1 := &mockProvider{
|
||||
name: "oauth-1",
|
||||
providerType: ProviderTypeOAuth,
|
||||
supportsModels: map[string]bool{"primary-model": true},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
p2 := &mockProvider{
|
||||
name: "oauth-2",
|
||||
providerType: ProviderTypeOAuth,
|
||||
supportsModels: map[string]bool{"fallback-model": true},
|
||||
available: true,
|
||||
priority: 2,
|
||||
}
|
||||
registry.Register(p1)
|
||||
registry.Register(p2)
|
||||
|
||||
// Create router with oauth aliases
|
||||
cfg := &config.Config{
|
||||
OAuthModelAlias: map[string][]config.OAuthModelAlias{
|
||||
"test-channel": {
|
||||
{Name: "primary-model", Alias: "fallback-model"},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Resolve
|
||||
decision := router.Resolve("primary-model")
|
||||
|
||||
assert.Equal(t, "primary-model", decision.ResolvedModel)
|
||||
assert.Len(t, decision.Candidates, 2)
|
||||
// Primary should come first (lower priority value)
|
||||
assert.Equal(t, "primary-model", decision.Candidates[0].Model)
|
||||
assert.Equal(t, "fallback-model", decision.Candidates[1].Model)
|
||||
}
|
||||
|
||||
func TestRouter_Resolve_NoProviders(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
cfg := &config.Config{}
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
decision := router.Resolve("unknown-model")
|
||||
|
||||
assert.Equal(t, "unknown-model", decision.ResolvedModel)
|
||||
assert.Empty(t, decision.Candidates)
|
||||
}
|
||||
245
internal/routing/router_v2_test.go
Normal file
245
internal/routing/router_v2_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRouter_DefaultMode_PrefersLocal(t *testing.T) {
|
||||
// Setup: Create a router with a mock provider that supports "gpt-4"
|
||||
registry := NewRegistry()
|
||||
mockProvider := &MockProvider{
|
||||
name: "openai",
|
||||
supportedModels: []string{"gpt-4"},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(mockProvider)
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "gpt-4", To: "claude-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request gpt-4 when local provider exists
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "gpt-4",
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Should return LOCAL_PROVIDER, not MODEL_MAPPING
|
||||
assert.Equal(t, RouteTypeLocalProvider, decision.RouteType)
|
||||
assert.Equal(t, "gpt-4", decision.ResolvedModel)
|
||||
assert.Equal(t, "openai", decision.ProviderName)
|
||||
assert.False(t, decision.ShouldProxy)
|
||||
}
|
||||
|
||||
func TestRouter_DefaultMode_MapsWhenNoLocal(t *testing.T) {
|
||||
// Setup: Create a router with NO provider for "gpt-4" but a mapping to "claude-local"
|
||||
// which has a provider
|
||||
registry := NewRegistry()
|
||||
mockProvider := &MockProvider{
|
||||
name: "anthropic",
|
||||
supportedModels: []string{"claude-local"},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(mockProvider)
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "gpt-4", To: "claude-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request gpt-4 when no local provider exists, but mapping exists
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "gpt-4",
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Should return MODEL_MAPPING
|
||||
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||
assert.Equal(t, "claude-local", decision.ResolvedModel)
|
||||
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||
assert.False(t, decision.ShouldProxy)
|
||||
}
|
||||
|
||||
func TestRouter_DefaultMode_AmpCreditsWhenNoLocalOrMapping(t *testing.T) {
|
||||
// Setup: Create a router with no providers and no mappings
|
||||
registry := NewRegistry()
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request a model with no local provider and no mapping
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "unknown-model",
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Should return AMP_CREDITS with ShouldProxy=true
|
||||
assert.Equal(t, RouteTypeAmpCredits, decision.RouteType)
|
||||
assert.Equal(t, "unknown-model", decision.ResolvedModel)
|
||||
assert.True(t, decision.ShouldProxy)
|
||||
assert.Empty(t, decision.ProviderName)
|
||||
}
|
||||
|
||||
func TestRouter_ForceMode_MapsEvenWithLocal(t *testing.T) {
|
||||
// Setup: Create a router with BOTH a local provider for "gpt-4" AND a mapping from "gpt-4" to "claude-local"
|
||||
// The mapping target "claude-local" also has a provider
|
||||
registry := NewRegistry()
|
||||
|
||||
// Local provider for gpt-4
|
||||
openaiProvider := &MockProvider{
|
||||
name: "openai",
|
||||
supportedModels: []string{"gpt-4"},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(openaiProvider)
|
||||
|
||||
// Local provider for the mapped model
|
||||
anthropicProvider := &MockProvider{
|
||||
name: "anthropic",
|
||||
supportedModels: []string{"claude-local"},
|
||||
available: true,
|
||||
priority: 2,
|
||||
}
|
||||
registry.Register(anthropicProvider)
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "gpt-4", To: "claude-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request gpt-4 with ForceModelMapping=true
|
||||
// Even though gpt-4 has a local provider, mapping should take precedence
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "gpt-4",
|
||||
PreferLocalProvider: false,
|
||||
ForceModelMapping: true,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Should return MODEL_MAPPING, not LOCAL_PROVIDER
|
||||
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||
assert.Equal(t, "claude-local", decision.ResolvedModel)
|
||||
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||
assert.False(t, decision.ShouldProxy)
|
||||
}
|
||||
|
||||
func TestRouter_ThinkingSuffix_Preserved(t *testing.T) {
|
||||
// Setup: Create a router with mapping and provider for mapped model
|
||||
registry := NewRegistry()
|
||||
|
||||
mockProvider := &MockProvider{
|
||||
name: "anthropic",
|
||||
supportedModels: []string{"claude-local"},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(mockProvider)
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "claude-3-5-sonnet", To: "claude-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request claude-3-5-sonnet with thinking suffix
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "claude-3-5-sonnet(thinking:foo)",
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Thinking suffix should be preserved in resolved model
|
||||
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||
assert.Equal(t, "claude-local(thinking:foo)", decision.ResolvedModel)
|
||||
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||
}
|
||||
|
||||
// MockProvider is a mock implementation of Provider for testing
|
||||
type MockProvider struct {
|
||||
name string
|
||||
providerType ProviderType
|
||||
supportedModels []string
|
||||
available bool
|
||||
priority int
|
||||
}
|
||||
|
||||
func (m *MockProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockProvider) Type() ProviderType {
|
||||
if m.providerType == "" {
|
||||
return ProviderTypeOAuth
|
||||
}
|
||||
return m.providerType
|
||||
}
|
||||
|
||||
func (m *MockProvider) SupportsModel(model string) bool {
|
||||
for _, supported := range m.supportedModels {
|
||||
if supported == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockProvider) Available(model string) bool {
|
||||
return m.available
|
||||
}
|
||||
|
||||
func (m *MockProvider) Priority() int {
|
||||
return m.priority
|
||||
}
|
||||
|
||||
func (m *MockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
return executor.Response{}, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
return nil, nil
|
||||
}
|
||||
113
internal/routing/testutil/fake_handler.go
Normal file
113
internal/routing/testutil/fake_handler.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// FakeHandlerRecorder is a test double that remembers how its handlers were
// invoked and replies with a canned response.
type FakeHandlerRecorder struct {
	// Called is set once any handler has run.
	Called bool
	// CallCount is incremented on every handler invocation.
	CallCount int
	// RequestBody holds the body read from the most recent request.
	RequestBody []byte
	// RequestHeader is a clone of the most recent request's headers.
	RequestHeader http.Header
	// ContextKeys holds the routing-related gin context values captured from
	// the most recent gin request.
	ContextKeys map[string]interface{}
	// ResponseStatus is the status code written by the handlers.
	ResponseStatus int
	// ResponseBody is the canned body written by the handlers.
	ResponseBody []byte
}
|
||||
|
||||
// NewFakeHandlerRecorder creates a new fake handler recorder.
|
||||
func NewFakeHandlerRecorder() *FakeHandlerRecorder {
|
||||
return &FakeHandlerRecorder{
|
||||
ContextKeys: make(map[string]interface{}),
|
||||
ResponseStatus: http.StatusOK,
|
||||
ResponseBody: []byte(`{"status":"handled"}`),
|
||||
}
|
||||
}
|
||||
|
||||
// GinHandler returns a gin.HandlerFunc that records the invocation.
|
||||
func (f *FakeHandlerRecorder) GinHandler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
f.record(c)
|
||||
c.Data(f.ResponseStatus, "application/json", f.ResponseBody)
|
||||
}
|
||||
}
|
||||
|
||||
// GinHandlerWithModel returns a gin.HandlerFunc that records the invocation and returns the model from context.
|
||||
// Useful for testing response rewriting in model mapping scenarios.
|
||||
func (f *FakeHandlerRecorder) GinHandlerWithModel() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
f.record(c)
|
||||
// Return a response with the model field that would be in the actual API response
|
||||
// If ResponseBody was explicitly set (not default), use that; otherwise generate from context
|
||||
var body []byte
|
||||
if mappedModel, exists := c.Get("mapped_model"); exists {
|
||||
body = []byte(`{"model":"` + mappedModel.(string) + `","status":"handled"}`)
|
||||
} else {
|
||||
body = f.ResponseBody
|
||||
}
|
||||
c.Data(f.ResponseStatus, "application/json", body)
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPHandler returns an http.HandlerFunc that records the invocation.
|
||||
func (f *FakeHandlerRecorder) HTTPHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
f.Called = true
|
||||
f.CallCount++
|
||||
f.RequestBody = body
|
||||
f.RequestHeader = r.Header.Clone()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(f.ResponseStatus)
|
||||
w.Write(f.ResponseBody)
|
||||
}
|
||||
}
|
||||
|
||||
// record captures the request details from gin context.
|
||||
func (f *FakeHandlerRecorder) record(c *gin.Context) {
|
||||
f.Called = true
|
||||
f.CallCount++
|
||||
|
||||
body, _ := io.ReadAll(c.Request.Body)
|
||||
f.RequestBody = body
|
||||
f.RequestHeader = c.Request.Header.Clone()
|
||||
|
||||
// Capture common context keys used by routing
|
||||
if val, exists := c.Get("mapped_model"); exists {
|
||||
f.ContextKeys["mapped_model"] = val
|
||||
}
|
||||
if val, exists := c.Get("fallback_models"); exists {
|
||||
f.ContextKeys["fallback_models"] = val
|
||||
}
|
||||
if val, exists := c.Get("route_type"); exists {
|
||||
f.ContextKeys["route_type"] = val
|
||||
}
|
||||
}
|
||||
|
||||
// Reset clears the recorder state.
|
||||
func (f *FakeHandlerRecorder) Reset() {
|
||||
f.Called = false
|
||||
f.CallCount = 0
|
||||
f.RequestBody = nil
|
||||
f.RequestHeader = nil
|
||||
f.ContextKeys = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// GetContextKey returns a captured context key value.
|
||||
func (f *FakeHandlerRecorder) GetContextKey(key string) (interface{}, bool) {
|
||||
val, ok := f.ContextKeys[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// WasCalled returns true if the handler was called.
|
||||
func (f *FakeHandlerRecorder) WasCalled() bool {
|
||||
return f.Called
|
||||
}
|
||||
|
||||
// GetCallCount returns the number of times the handler was called.
|
||||
func (f *FakeHandlerRecorder) GetCallCount() int {
|
||||
return f.CallCount
|
||||
}
|
||||
83
internal/routing/testutil/fake_proxy.go
Normal file
83
internal/routing/testutil/fake_proxy.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
)
|
||||
|
||||
// CloseNotifierRecorder is an httptest.ResponseRecorder that additionally
// offers a CloseNotify method, for code paths (such as a reverse proxy) that
// expect the response writer to be an http.CloseNotifier.
type CloseNotifierRecorder struct {
	*httptest.ResponseRecorder
	// closeChan is handed out by CloseNotify; nothing in this type sends on it.
	closeChan chan bool
}

// NewCloseNotifierRecorder returns a recorder wrapping a fresh
// httptest.ResponseRecorder together with a close channel of capacity one.
func NewCloseNotifierRecorder() *CloseNotifierRecorder {
	rec := new(CloseNotifierRecorder)
	rec.ResponseRecorder = httptest.NewRecorder()
	rec.closeChan = make(chan bool, 1)
	return rec
}

// CloseNotify implements http.CloseNotifier by exposing the receive side of
// the recorder's close channel.
func (c *CloseNotifierRecorder) CloseNotify() <-chan bool {
	return c.closeChan
}
|
||||
|
||||
// FakeProxyRecorder is a test double for an upstream proxy: it remembers how
// it was invoked and replies with a canned response.
type FakeProxyRecorder struct {
	// Called is set once the proxy has served a request.
	Called bool
	// CallCount is incremented on every served request.
	CallCount int
	// RequestBody holds the body of the most recent request that was read
	// without error.
	RequestBody []byte
	// RequestHeaders is a clone of the most recent request's headers.
	RequestHeaders http.Header
	// ResponseStatus is the status code written for every request.
	ResponseStatus int
	// ResponseBody is the canned body written for every request.
	ResponseBody []byte
}

// NewFakeProxyRecorder returns a recorder that answers 200 OK with the body
// {"status":"proxied"}.
func NewFakeProxyRecorder() *FakeProxyRecorder {
	rec := new(FakeProxyRecorder)
	rec.ResponseStatus = http.StatusOK
	rec.ResponseBody = []byte(`{"status":"proxied"}`)
	return rec
}

// ServeHTTP implements http.Handler. It records the call and a clone of the
// request headers, stores the request body when it can be read without error,
// and then writes the configured status and body.
func (f *FakeProxyRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	f.Called = true
	f.CallCount++
	f.RequestHeaders = r.Header.Clone()

	// On a read error the previously stored body is left untouched.
	if payload, err := io.ReadAll(r.Body); err == nil {
		f.RequestBody = payload
	}

	w.WriteHeader(f.ResponseStatus)
	w.Write(f.ResponseBody)
}

// GetCallCount reports how many requests have been served since creation or
// the last Reset.
func (f *FakeProxyRecorder) GetCallCount() int {
	return f.CallCount
}

// Reset forgets the recorded calls, request body and request headers. The
// configured response status and body are kept.
func (f *FakeProxyRecorder) Reset() {
	f.Called, f.CallCount = false, 0
	f.RequestBody, f.RequestHeaders = nil, nil
}

// ToHandler exposes the recorder's ServeHTTP method as an http.Handler for use
// with httptest.
func (f *FakeProxyRecorder) ToHandler() http.Handler {
	var handler http.HandlerFunc = f.ServeHTTP
	return handler
}

// CreateTestServer starts an httptest server that serves every request through
// this recorder. The caller is responsible for closing the returned server.
func (f *FakeProxyRecorder) CreateTestServer() *httptest.Server {
	handler := f.ToHandler()
	return httptest.NewServer(handler)
}
|
||||
62
internal/routing/types.go
Normal file
62
internal/routing/types.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package routing
|
||||
|
||||
// RouteType names the kind of routing decision taken for a request.
type RouteType string

// The route types a RoutingDecision can carry.
const (
	// RouteTypeLocalProvider: the request is handled by a local OAuth provider (free).
	RouteTypeLocalProvider = RouteType("LOCAL_PROVIDER")
	// RouteTypeModelMapping: the request was remapped to another available model (free).
	RouteTypeModelMapping = RouteType("MODEL_MAPPING")
	// RouteTypeAmpCredits: the request is forwarded to ampcode.com (uses Amp credits).
	RouteTypeAmpCredits = RouteType("AMP_CREDITS")
	// RouteTypeNoProvider: neither a provider nor a fallback is available.
	RouteTypeNoProvider = RouteType("NO_PROVIDER")
)
|
||||
|
||||
// RoutingRequest carries the inputs a routing decision is based on.
type RoutingRequest struct {
	// RequestedModel is the model name taken from the incoming request.
	RequestedModel string

	// PreferLocalProvider asks for local providers to be checked first,
	// before model mappings are applied.
	PreferLocalProvider bool

	// ForceModelMapping asks for model mappings to be applied even when a
	// local provider exists for the requested model.
	ForceModelMapping bool
}
|
||||
|
||||
// RoutingDecision contains the result of a routing decision.
|
||||
type RoutingDecision struct {
|
||||
// RouteType indicates the type of routing decision.
|
||||
RouteType RouteType
|
||||
// ResolvedModel is the final model name after any mappings.
|
||||
ResolvedModel string
|
||||
// ProviderName is the name of the selected provider (if any).
|
||||
ProviderName string
|
||||
// FallbackModels is a list of alternative models to try if the primary fails.
|
||||
FallbackModels []string
|
||||
// ShouldProxy indicates whether the request should be proxied to ampcode.com.
|
||||
ShouldProxy bool
|
||||
}
|
||||
|
||||
// NewRoutingDecision creates a new RoutingDecision with the given parameters.
|
||||
func NewRoutingDecision(routeType RouteType, resolvedModel, providerName string, fallbackModels []string, shouldProxy bool) *RoutingDecision {
|
||||
return &RoutingDecision{
|
||||
RouteType: routeType,
|
||||
ResolvedModel: resolvedModel,
|
||||
ProviderName: providerName,
|
||||
FallbackModels: fallbackModels,
|
||||
ShouldProxy: shouldProxy,
|
||||
}
|
||||
}
|
||||
|
||||
// IsLocal returns true if the decision routes to a local provider.
|
||||
func (d *RoutingDecision) IsLocal() bool {
|
||||
return d.RouteType == RouteTypeLocalProvider || d.RouteType == RouteTypeModelMapping
|
||||
}
|
||||
|
||||
// HasFallbacks returns true if there are fallback models available.
|
||||
func (d *RoutingDecision) HasFallbacks() bool {
|
||||
return len(d.FallbackModels) > 0
|
||||
}
|
||||
270
internal/routing/wrapper.go
Normal file
270
internal/routing/wrapper.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package routing
|
||||
|
||||
import (
	"bufio"
	"bytes"
	"io"
	"net"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
	"github.com/sirupsen/logrus"
)
|
||||
|
||||
// ProxyFunc is the function type for proxying requests.
|
||||
type ProxyFunc func(c *gin.Context)
|
||||
|
||||
// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic.
|
||||
// It replaces the FallbackHandler logic with a Router-based approach.
|
||||
type ModelRoutingWrapper struct {
|
||||
router *Router
|
||||
extractor ModelExtractor
|
||||
rewriter ModelRewriter
|
||||
proxyFunc ProxyFunc
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies.
|
||||
// If extractor is nil, a DefaultModelExtractor is used.
|
||||
// If rewriter is nil, a DefaultModelRewriter is used.
|
||||
// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead.
|
||||
func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper {
|
||||
if extractor == nil {
|
||||
extractor = NewModelExtractor()
|
||||
}
|
||||
if rewriter == nil {
|
||||
rewriter = NewModelRewriter()
|
||||
}
|
||||
return &ModelRoutingWrapper{
|
||||
router: router,
|
||||
extractor: extractor,
|
||||
rewriter: rewriter,
|
||||
proxyFunc: proxyFunc,
|
||||
logger: logrus.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogger sets the logger for the wrapper.
|
||||
func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) {
|
||||
w.logger = logger
|
||||
}
|
||||
|
||||
// Wrap wraps a gin.HandlerFunc with model routing logic.
|
||||
// The returned handler will:
|
||||
// 1. Extract the model from the request
|
||||
// 2. Get a routing decision from the Router
|
||||
// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS)
|
||||
func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Read request body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
w.logger.Errorf("routing wrapper: failed to read request body: %v", err)
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model from request
|
||||
ginParams := map[string]string{
|
||||
"action": c.Param("action"),
|
||||
"path": c.Param("path"),
|
||||
}
|
||||
modelName, err := w.extractor.Extract(bodyBytes, ginParams)
|
||||
if err != nil {
|
||||
w.logger.Warnf("routing wrapper: failed to extract model: %v", err)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
if modelName == "" {
|
||||
// No model found, proceed with original handler
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Get routing decision
|
||||
req := RoutingRequest{
|
||||
RequestedModel: modelName,
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false, // TODO: Get from config
|
||||
}
|
||||
decision := w.router.ResolveV2(req)
|
||||
|
||||
// Store decision in context for downstream handlers
|
||||
c.Set(string(ctxkeys.RoutingDecision), decision)
|
||||
|
||||
// Handle based on route type
|
||||
switch decision.RouteType {
|
||||
case RouteTypeLocalProvider:
|
||||
w.handleLocalProvider(c, handler, bodyBytes, decision)
|
||||
case RouteTypeModelMapping:
|
||||
w.handleModelMapping(c, handler, bodyBytes, decision)
|
||||
case RouteTypeAmpCredits:
|
||||
w.handleAmpCredits(c, handler, bodyBytes)
|
||||
default:
|
||||
// No provider available
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleLocalProvider handles the LOCAL_PROVIDER route type.
|
||||
func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
||||
// Filter Anthropic-Beta header for local provider
|
||||
filterAnthropicBetaHeader(c)
|
||||
|
||||
// Restore body with original content
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
// Call handler
|
||||
handler(c)
|
||||
}
|
||||
|
||||
// handleModelMapping handles the MODEL_MAPPING route type.
|
||||
func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
||||
// Rewrite request body with mapped model
|
||||
rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel)
|
||||
if err != nil {
|
||||
w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err)
|
||||
rewrittenBody = bodyBytes
|
||||
}
|
||||
_ = rewrittenBody
|
||||
|
||||
// Store mapped model in context
|
||||
c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel)
|
||||
|
||||
// Store fallback models in context if present
|
||||
if len(decision.FallbackModels) > 0 {
|
||||
c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels)
|
||||
}
|
||||
|
||||
// Filter Anthropic-Beta header for local provider
|
||||
filterAnthropicBetaHeader(c)
|
||||
|
||||
// Restore body with rewritten content
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody))
|
||||
|
||||
// Wrap response writer to rewrite model back
|
||||
wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel)
|
||||
c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer}
|
||||
|
||||
// Call handler
|
||||
handler(c)
|
||||
|
||||
// Cleanup (flush response rewriting)
|
||||
cleanup()
|
||||
}
|
||||
|
||||
// handleAmpCredits handles the AMP_CREDITS route type.
|
||||
// It calls the proxy function directly if available, otherwise passes to handler.
|
||||
// Does NOT filter headers or rewrite body - proxy handles everything.
|
||||
func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) {
|
||||
// Restore body with original content (no rewriting for proxy)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
// Call proxy function if available, otherwise fall back to handler
|
||||
if w.proxyFunc != nil {
|
||||
w.proxyFunc(c)
|
||||
} else {
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
|
||||
// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers.
|
||||
func filterAnthropicBetaHeader(c *gin.Context) {
|
||||
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
||||
if filtered != "" {
|
||||
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||
} else {
|
||||
c.Request.Header.Del("Anthropic-Beta")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterBetaFeatures removes featureToRemove from betaHeader, which is treated
// as a comma-separated list of feature names.
//
// Each entry is trimmed of surrounding whitespace; empty entries and entries
// equal to featureToRemove are dropped. The surviving entries are joined with
// "," and returned, so the result is "" when nothing remains.
func filterBetaFeatures(betaHeader, featureToRemove string) string {
	features := strings.Split(betaHeader, ",")
	kept := make([]string, 0, len(features))
	for _, feature := range features {
		name := strings.TrimSpace(feature)
		if name == "" || name == featureToRemove {
			continue
		}
		kept = append(kept, name)
	}
	return strings.Join(kept, ",")
}
|
||||
|
||||
// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter.
|
||||
type ginResponseWriterAdapter struct {
|
||||
http.ResponseWriter
|
||||
original gin.ResponseWriter
|
||||
}
|
||||
|
||||
func (a *ginResponseWriterAdapter) WriteHeader(code int) {
|
||||
a.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) {
|
||||
return a.ResponseWriter.Write(data)
|
||||
}
|
||||
|
||||
func (a *ginResponseWriterAdapter) Header() http.Header {
|
||||
return a.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
// CloseNotify implements http.CloseNotifier.
|
||||
func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool {
|
||||
if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok {
|
||||
return notifier.CloseNotify()
|
||||
}
|
||||
return a.original.CloseNotify()
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher.
|
||||
func (a *ginResponseWriterAdapter) Flush() {
|
||||
if flusher, ok := a.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker.
|
||||
func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok {
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
return a.original.Hijack()
|
||||
}
|
||||
|
||||
// Status returns the HTTP status code.
|
||||
func (a *ginResponseWriterAdapter) Status() int {
|
||||
return a.original.Status()
|
||||
}
|
||||
|
||||
// Size returns the number of bytes already written into the response http body.
|
||||
func (a *ginResponseWriterAdapter) Size() int {
|
||||
return a.original.Size()
|
||||
}
|
||||
|
||||
// Written returns whether or not the response for this context has been written.
|
||||
func (a *ginResponseWriterAdapter) Written() bool {
|
||||
return a.original.Written()
|
||||
}
|
||||
|
||||
// WriteHeaderNow forces WriteHeader to be called.
|
||||
func (a *ginResponseWriterAdapter) WriteHeaderNow() {
|
||||
a.original.WriteHeaderNow()
|
||||
}
|
||||
|
||||
// WriteString writes the given string into the response body.
|
||||
func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) {
|
||||
return a.Write([]byte(s))
|
||||
}
|
||||
|
||||
// Pusher returns the http.Pusher for server push.
|
||||
func (a *ginResponseWriterAdapter) Pusher() http.Pusher {
|
||||
if pusher, ok := a.ResponseWriter.(http.Pusher); ok {
|
||||
return pusher
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -61,10 +61,13 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
|
||||
// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
|
||||
// Also track if thinking is enabled to ensure reasoning_content is added for tool_calls
|
||||
thinkingEnabled := false
|
||||
if thinkingConfig := root.Get("thinking"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingType := thinkingConfig.Get("type"); thinkingType.Exists() {
|
||||
switch thinkingType.String() {
|
||||
case "enabled":
|
||||
thinkingEnabled = true
|
||||
if budgetTokens := thinkingConfig.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(budget); ok && effort != "" {
|
||||
@@ -217,6 +220,10 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Add reasoning_content if present
|
||||
if hasReasoning {
|
||||
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
|
||||
} else if thinkingEnabled && hasToolCalls {
|
||||
// Claude API requires reasoning_content in assistant messages with tool_calls
|
||||
// when thinking mode is enabled, even if empty
|
||||
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", "")
|
||||
}
|
||||
|
||||
// Add tool_calls if present (in same message as content)
|
||||
|
||||
@@ -588,3 +588,124 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
|
||||
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning tests that
|
||||
// when thinking mode is enabled and assistant message has tool_calls but no thinking content,
|
||||
// an empty reasoning_content is added to satisfy Claude API requirements.
|
||||
func TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantHasReasoningContent bool
|
||||
wantReasoningContent string
|
||||
}{
|
||||
{
|
||||
name: "thinking enabled with tool_calls but no thinking content adds empty reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"thinking": {"type": "enabled", "budget_tokens": 4000},
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I will help you."},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: true,
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
{
|
||||
name: "thinking enabled with tool_calls and thinking content uses actual reasoning",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"thinking": {"type": "enabled", "budget_tokens": 4000},
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me analyze this..."},
|
||||
{"type": "text", "text": "I will help you."},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: true,
|
||||
wantReasoningContent: "Let me analyze this...",
|
||||
},
|
||||
{
|
||||
name: "thinking disabled with tool_calls does not add reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"thinking": {"type": "disabled"},
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I will help you."},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: false,
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
{
|
||||
name: "no thinking config with tool_calls does not add reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I will help you."},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: false,
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
{
|
||||
name: "thinking enabled without tool_calls and no thinking content does not add reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"thinking": {"type": "enabled", "budget_tokens": 4000},
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Simple response without tools."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: false,
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
if len(messages) == 0 {
|
||||
t.Fatal("Expected at least one message")
|
||||
}
|
||||
|
||||
assistantMsg := messages[0]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Fatalf("Expected assistant message, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
hasReasoningContent := assistantMsg.Get("reasoning_content").Exists()
|
||||
if hasReasoningContent != tt.wantHasReasoningContent {
|
||||
t.Errorf("reasoning_content existence = %v, want %v", hasReasoningContent, tt.wantHasReasoningContent)
|
||||
}
|
||||
|
||||
if hasReasoningContent {
|
||||
gotReasoningContent := assistantMsg.Get("reasoning_content").String()
|
||||
if gotReasoningContent != tt.wantReasoningContent {
|
||||
t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,16 +255,15 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
parentCtx = logging.WithRequestID(parentCtx, requestID)
|
||||
}
|
||||
}
|
||||
newCtx, cancel := context.WithCancel(parentCtx)
|
||||
if requestCtx != nil && requestCtx != parentCtx {
|
||||
go func() {
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
cancel()
|
||||
case <-newCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
// Use requestCtx as base if available to preserve amp context values (fallback_models, etc.)
|
||||
// Falls back to parentCtx if no request context
|
||||
baseCtx := parentCtx
|
||||
if requestCtx != nil {
|
||||
baseCtx = requestCtx
|
||||
}
|
||||
|
||||
newCtx, cancel := context.WithCancel(baseCtx)
|
||||
newCtx = context.WithValue(newCtx, "gin", c)
|
||||
newCtx = context.WithValue(newCtx, "handler", handler)
|
||||
return newCtx, func(params ...interface{}) {
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -562,192 +563,188 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeWithFallback(
|
||||
ctx context.Context,
|
||||
initialProviders []string,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
exec func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error,
|
||||
) error {
|
||||
routeModel := req.Model
|
||||
providers := initialProviders
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
|
||||
// Track fallback models from context (provided by Amp module fallback_models key)
|
||||
var fallbacks []string
|
||||
if v := ctx.Value(ctxkeys.FallbackModels); v != nil {
|
||||
if fs, ok := v.([]string); ok {
|
||||
fallbacks = fs
|
||||
}
|
||||
}
|
||||
fallbackIdx := -1
|
||||
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
// No more auths for current model. Try next fallback model if available.
|
||||
if fallbackIdx+1 < len(fallbacks) {
|
||||
fallbackIdx++
|
||||
routeModel = fallbacks[fallbackIdx]
|
||||
log.Debugf("no more auths for current model, trying fallback model: %s (fallback %d/%d)", routeModel, fallbackIdx+1, len(fallbacks))
|
||||
|
||||
// Reset tried set for the new model and find its providers
|
||||
tried = make(map[string]struct{})
|
||||
providers = util.GetProviderName(thinking.ParseSuffix(routeModel).ModelName)
|
||||
// Reset opts for the new model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("fallback model %s has no providers, skipping", routeModel)
|
||||
continue // Try next fallback if this one has no providers
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return errPick
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
if err := exec(ctx, executor, auth, provider, routeModel); err != nil {
|
||||
if errCtx := ctx.Err(); errCtx != nil {
|
||||
return errCtx
|
||||
}
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedAttempt(
|
||||
ctx context.Context,
|
||||
auth *Auth,
|
||||
provider, routeModel string,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
exec func(ctx context.Context, execReq cliproxyexecutor.Request) error,
|
||||
) error {
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
|
||||
err := exec(execCtx, execReq)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: err == nil}
|
||||
if err != nil {
|
||||
result.Error = &Error{Message: err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(err, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(err); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
var resp cliproxyexecutor.Response
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
resp, errExec = executor.Execute(execCtx, auth, execReq, opts)
|
||||
return errExec
|
||||
})
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
var resp cliproxyexecutor.Response
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
resp, errExec = executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
return errExec
|
||||
})
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
var chunks <-chan cliproxyexecutor.StreamChunk
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
chunks, errExec = executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errExec != nil {
|
||||
return errExec
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
forward := true
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
forward := true
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
continue
|
||||
}
|
||||
if streamCtx == nil {
|
||||
out <- chunk
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
forward = false
|
||||
case out <- chunk:
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
continue
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
if streamCtx == nil {
|
||||
out <- chunk
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
forward = false
|
||||
case out <- chunk:
|
||||
}
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
chunks = out
|
||||
return nil
|
||||
})
|
||||
})
|
||||
return chunks, err
|
||||
}
|
||||
|
||||
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
||||
|
||||
Reference in New Issue
Block a user