mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(amp): add model mapping support for routing unavailable models to alternatives
- Add AmpModelMapping config to route models like 'claude-opus-4.5' to 'claude-sonnet-4' - Add ModelMapper interface and DefaultModelMapper implementation with hot-reload support - Enhance FallbackHandler to apply model mappings before falling back to ampcode.com - Add structured logging for routing decisions (local provider, mapping, amp credits) - Update config.example.yaml with amp-model-mappings documentation
This commit is contained in:
@@ -55,6 +55,28 @@ quota-exceeded:
|
|||||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
ws-auth: false
|
ws-auth: false
|
||||||
|
|
||||||
|
# Amp CLI Integration
|
||||||
|
# Configure upstream URL for Amp CLI OAuth and management features
|
||||||
|
#amp-upstream-url: "https://ampcode.com"
|
||||||
|
|
||||||
|
# Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||||
|
#amp-upstream-api-key: ""
|
||||||
|
|
||||||
|
# Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||||
|
#amp-restrict-management-to-localhost: true
|
||||||
|
|
||||||
|
# Amp Model Mappings
|
||||||
|
# Route unavailable Amp models to alternative models available in your local proxy.
|
||||||
|
# Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||||
|
# but you have a similar model available (e.g., Claude Sonnet 4).
|
||||||
|
#amp-model-mappings:
|
||||||
|
# - from: "claude-opus-4.5" # Model requested by Amp CLI
|
||||||
|
# to: "claude-sonnet-4" # Route to this available model instead
|
||||||
|
# - from: "gpt-5"
|
||||||
|
# to: "gemini-2.5-pro"
|
||||||
|
# - from: "claude-3-opus-20240229"
|
||||||
|
# to: "claude-3-5-sonnet-20241022"
|
||||||
|
|
||||||
# Gemini API keys (preferred)
|
# Gemini API keys (preferred)
|
||||||
#gemini-api-key:
|
#gemini-api-key:
|
||||||
# - api-key: "AIzaSy...01"
|
# - api-key: "AIzaSy...01"
|
||||||
|
|||||||
@@ -23,11 +23,13 @@ type Option func(*AmpModule)
|
|||||||
// - Reverse proxy to Amp control plane for OAuth/management
|
// - Reverse proxy to Amp control plane for OAuth/management
|
||||||
// - Provider-specific route aliases (/api/provider/{provider}/...)
|
// - Provider-specific route aliases (/api/provider/{provider}/...)
|
||||||
// - Automatic gzip decompression for misconfigured upstreams
|
// - Automatic gzip decompression for misconfigured upstreams
|
||||||
|
// - Model mapping for routing unavailable models to alternatives
|
||||||
type AmpModule struct {
|
type AmpModule struct {
|
||||||
secretSource SecretSource
|
secretSource SecretSource
|
||||||
proxy *httputil.ReverseProxy
|
proxy *httputil.ReverseProxy
|
||||||
accessManager *sdkaccess.Manager
|
accessManager *sdkaccess.Manager
|
||||||
authMiddleware_ gin.HandlerFunc
|
authMiddleware_ gin.HandlerFunc
|
||||||
|
modelMapper *DefaultModelMapper
|
||||||
enabled bool
|
enabled bool
|
||||||
registerOnce sync.Once
|
registerOnce sync.Once
|
||||||
}
|
}
|
||||||
@@ -101,6 +103,9 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
|||||||
// Use registerOnce to ensure routes are only registered once
|
// Use registerOnce to ensure routes are only registered once
|
||||||
var regErr error
|
var regErr error
|
||||||
m.registerOnce.Do(func() {
|
m.registerOnce.Do(func() {
|
||||||
|
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
||||||
|
m.modelMapper = NewModelMapper(ctx.Config.AmpModelMappings)
|
||||||
|
|
||||||
// Always register provider aliases - these work without an upstream
|
// Always register provider aliases - these work without an upstream
|
||||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||||
|
|
||||||
@@ -159,8 +164,13 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
|
|||||||
// OnConfigUpdated handles configuration updates.
|
// OnConfigUpdated handles configuration updates.
|
||||||
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
|
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
|
||||||
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||||
|
// Update model mappings (hot-reload supported)
|
||||||
|
if m.modelMapper != nil {
|
||||||
|
m.modelMapper.UpdateMappings(cfg.AmpModelMappings)
|
||||||
|
}
|
||||||
|
|
||||||
if !m.enabled {
|
if !m.enabled {
|
||||||
log.Debug("Amp routing not enabled, skipping config update")
|
log.Debug("Amp routing not enabled, skipping other config updates")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,3 +191,8 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
log.Debug("Amp config updated (restart required for URL changes)")
|
log.Debug("Amp config updated (restart required for URL changes)")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||||
|
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||||
|
return m.modelMapper
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,16 +6,75 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AmpRouteType represents the type of routing decision made for an Amp request
|
||||||
|
type AmpRouteType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free)
|
||||||
|
RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER"
|
||||||
|
// RouteTypeModelMapping indicates the request was remapped to another available model (free)
|
||||||
|
RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING"
|
||||||
|
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits)
|
||||||
|
RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS"
|
||||||
|
// RouteTypeNoProvider indicates no provider or fallback available
|
||||||
|
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
||||||
|
)
|
||||||
|
|
||||||
|
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||||
|
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||||
|
fields := log.Fields{
|
||||||
|
"component": "amp-routing",
|
||||||
|
"route_type": string(routeType),
|
||||||
|
"requested_model": requestedModel,
|
||||||
|
"path": path,
|
||||||
|
"timestamp": time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
if resolvedModel != "" && resolvedModel != requestedModel {
|
||||||
|
fields["resolved_model"] = resolvedModel
|
||||||
|
}
|
||||||
|
if provider != "" {
|
||||||
|
fields["provider"] = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
switch routeType {
|
||||||
|
case RouteTypeLocalProvider:
|
||||||
|
fields["cost"] = "free"
|
||||||
|
fields["source"] = "local_oauth"
|
||||||
|
log.WithFields(fields).Infof("[AMP] Using local provider for model: %s", requestedModel)
|
||||||
|
|
||||||
|
case RouteTypeModelMapping:
|
||||||
|
fields["cost"] = "free"
|
||||||
|
fields["source"] = "local_oauth"
|
||||||
|
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
||||||
|
log.WithFields(fields).Infof("[AMP] Model mapped: %s -> %s", requestedModel, resolvedModel)
|
||||||
|
|
||||||
|
case RouteTypeAmpCredits:
|
||||||
|
fields["cost"] = "amp_credits"
|
||||||
|
fields["source"] = "ampcode.com"
|
||||||
|
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||||
|
log.WithFields(fields).Warnf("[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||||
|
|
||||||
|
case RouteTypeNoProvider:
|
||||||
|
fields["cost"] = "none"
|
||||||
|
fields["source"] = "error"
|
||||||
|
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||||
|
log.WithFields(fields).Warnf("[AMP] No provider available for model_id: %s", requestedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||||
// when the model's provider is not available in CLIProxyAPI
|
// when the model's provider is not available in CLIProxyAPI
|
||||||
type FallbackHandler struct {
|
type FallbackHandler struct {
|
||||||
getProxy func() *httputil.ReverseProxy
|
getProxy func() *httputil.ReverseProxy
|
||||||
|
modelMapper ModelMapper
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFallbackHandler creates a new fallback handler wrapper
|
// NewFallbackHandler creates a new fallback handler wrapper
|
||||||
@@ -26,10 +85,25 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||||
|
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler {
|
||||||
|
return &FallbackHandler{
|
||||||
|
getProxy: getProxy,
|
||||||
|
modelMapper: mapper,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelMapper sets the model mapper for this handler (allows late binding)
|
||||||
|
func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
||||||
|
fh.modelMapper = mapper
|
||||||
|
}
|
||||||
|
|
||||||
// WrapHandler wraps a gin.HandlerFunc with fallback logic
|
// WrapHandler wraps a gin.HandlerFunc with fallback logic
|
||||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
requestPath := c.Request.URL.Path
|
||||||
|
|
||||||
// Read the request body to extract the model name
|
// Read the request body to extract the model name
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -55,12 +129,33 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
// Check if we have providers for this model
|
// Check if we have providers for this model
|
||||||
providers := util.GetProviderName(normalizedModel)
|
providers := util.GetProviderName(normalizedModel)
|
||||||
|
|
||||||
|
// Track resolved model for logging (may change if mapping is applied)
|
||||||
|
resolvedModel := normalizedModel
|
||||||
|
usedMapping := false
|
||||||
|
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
// No providers configured - check if we have a proxy for fallback
|
// No providers configured - check if we have a model mapping
|
||||||
|
if fh.modelMapper != nil {
|
||||||
|
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||||
|
// Mapping found - rewrite the model in request body
|
||||||
|
bodyBytes = rewriteModelInBody(bodyBytes, mappedModel)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
resolvedModel = mappedModel
|
||||||
|
usedMapping = true
|
||||||
|
|
||||||
|
// Get providers for the mapped model
|
||||||
|
providers = util.GetProviderName(mappedModel)
|
||||||
|
|
||||||
|
// Continue to handler with remapped model
|
||||||
|
goto handleRequest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No mapping found - check if we have a proxy for fallback
|
||||||
proxy := fh.getProxy()
|
proxy := fh.getProxy()
|
||||||
if proxy != nil {
|
if proxy != nil {
|
||||||
// Fallback to ampcode.com
|
// Log: Forwarding to ampcode.com (uses Amp credits)
|
||||||
log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName)
|
logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath)
|
||||||
|
|
||||||
// Restore body again for the proxy
|
// Restore body again for the proxy
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
@@ -71,7 +166,23 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// No proxy available, let the normal handler return the error
|
// No proxy available, let the normal handler return the error
|
||||||
log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName)
|
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
handleRequest:
|
||||||
|
|
||||||
|
// Log the routing decision
|
||||||
|
providerName := ""
|
||||||
|
if len(providers) > 0 {
|
||||||
|
providerName = providers[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if usedMapping {
|
||||||
|
// Log: Model was mapped to another model
|
||||||
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
} else if len(providers) > 0 {
|
||||||
|
// Log: Using local provider (free)
|
||||||
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Providers available or no proxy for fallback, restore body and use normal handler
|
// Providers available or no proxy for fallback, restore body and use normal handler
|
||||||
@@ -91,6 +202,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rewriteModelInBody replaces the model name in a JSON request body
|
||||||
|
func rewriteModelInBody(body []byte, newModel string) []byte {
|
||||||
|
var payload map[string]interface{}
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := payload["model"]; exists {
|
||||||
|
payload["model"] = newModel
|
||||||
|
newBody, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return newBody
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
// extractModelFromRequest attempts to extract the model name from various request formats
|
// extractModelFromRequest attempts to extract the model name from various request formats
|
||||||
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
||||||
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||||
|
|||||||
113
internal/api/modules/amp/model_mapping.go
Normal file
113
internal/api/modules/amp/model_mapping.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// Package amp provides model mapping functionality for routing Amp CLI requests
|
||||||
|
// to alternative models when the requested model is not available locally.
|
||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelMapper provides model name mapping/aliasing for Amp CLI requests.
|
||||||
|
// When an Amp request comes in for a model that isn't available locally,
|
||||||
|
// this mapper can redirect it to an alternative model that IS available.
|
||||||
|
type ModelMapper interface {
|
||||||
|
// MapModel returns the target model name if a mapping exists and the target
|
||||||
|
// model has available providers. Returns empty string if no mapping applies.
|
||||||
|
MapModel(requestedModel string) string
|
||||||
|
|
||||||
|
// UpdateMappings refreshes the mapping configuration (for hot-reload).
|
||||||
|
UpdateMappings(mappings []config.AmpModelMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
||||||
|
type DefaultModelMapper struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
mappings map[string]string // from -> to (normalized lowercase keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||||
|
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||||
|
m := &DefaultModelMapper{
|
||||||
|
mappings: make(map[string]string),
|
||||||
|
}
|
||||||
|
m.UpdateMappings(mappings)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapModel checks if a mapping exists for the requested model and if the
|
||||||
|
// target model has available local providers. Returns the mapped model name
|
||||||
|
// or empty string if no valid mapping exists.
|
||||||
|
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||||
|
if requestedModel == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
// Normalize the requested model for lookup
|
||||||
|
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||||
|
|
||||||
|
// Check for direct mapping
|
||||||
|
targetModel, exists := m.mappings[normalizedRequest]
|
||||||
|
if !exists {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify target model has available providers
|
||||||
|
providers := util.GetProviderName(targetModel)
|
||||||
|
if len(providers) == 0 {
|
||||||
|
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||||
|
log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel)
|
||||||
|
return targetModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMappings refreshes the mapping configuration from config.
|
||||||
|
// This is called during initialization and on config hot-reload.
|
||||||
|
func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Clear and rebuild mappings
|
||||||
|
m.mappings = make(map[string]string, len(mappings))
|
||||||
|
|
||||||
|
for _, mapping := range mappings {
|
||||||
|
from := strings.TrimSpace(mapping.From)
|
||||||
|
to := strings.TrimSpace(mapping.To)
|
||||||
|
|
||||||
|
if from == "" || to == "" {
|
||||||
|
log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store with normalized lowercase key for case-insensitive lookup
|
||||||
|
normalizedFrom := strings.ToLower(from)
|
||||||
|
m.mappings[normalizedFrom] = to
|
||||||
|
|
||||||
|
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.mappings) > 0 {
|
||||||
|
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMappings returns a copy of current mappings (for debugging/status).
|
||||||
|
func (m *DefaultModelMapper) GetMappings() map[string]string {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make(map[string]string, len(m.mappings))
|
||||||
|
for k, v := range m.mappings {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
186
internal/api/modules/amp/model_mapping_test.go
Normal file
186
internal/api/modules/amp/model_mapping_test.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewModelMapper(t *testing.T) {
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||||
|
{From: "gpt-5", To: "gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
if mapper == nil {
|
||||||
|
t.Fatal("Expected non-nil mapper")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := mapper.GetMappings()
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Errorf("Expected 2 mappings, got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewModelMapper_Empty(t *testing.T) {
|
||||||
|
mapper := NewModelMapper(nil)
|
||||||
|
if mapper == nil {
|
||||||
|
t.Fatal("Expected non-nil mapper")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := mapper.GetMappings()
|
||||||
|
if len(result) != 0 {
|
||||||
|
t.Errorf("Expected 0 mappings, got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_MapModel_NoProvider(t *testing.T) {
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
// Without a registered provider for the target, mapping should return empty
|
||||||
|
result := mapper.MapModel("claude-opus-4.5")
|
||||||
|
if result != "" {
|
||||||
|
t.Errorf("Expected empty result when target has no provider, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
||||||
|
// Register a mock provider for the target model
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("test-client")
|
||||||
|
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
// With a registered provider, mapping should work
|
||||||
|
result := mapper.MapModel("claude-opus-4.5")
|
||||||
|
if result != "claude-sonnet-4" {
|
||||||
|
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
||||||
|
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("test-client2")
|
||||||
|
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "Claude-Opus-4.5", To: "claude-sonnet-4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
// Should match case-insensitively
|
||||||
|
result := mapper.MapModel("claude-opus-4.5")
|
||||||
|
if result != "claude-sonnet-4" {
|
||||||
|
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_MapModel_NotFound(t *testing.T) {
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
// Unknown model should return empty
|
||||||
|
result := mapper.MapModel("unknown-model")
|
||||||
|
if result != "" {
|
||||||
|
t.Errorf("Expected empty for unknown model, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_MapModel_EmptyInput(t *testing.T) {
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
result := mapper.MapModel("")
|
||||||
|
if result != "" {
|
||||||
|
t.Errorf("Expected empty for empty input, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_UpdateMappings(t *testing.T) {
|
||||||
|
mapper := NewModelMapper(nil)
|
||||||
|
|
||||||
|
// Initially empty
|
||||||
|
if len(mapper.GetMappings()) != 0 {
|
||||||
|
t.Error("Expected 0 initial mappings")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update with new mappings
|
||||||
|
mapper.UpdateMappings([]config.AmpModelMapping{
|
||||||
|
{From: "model-a", To: "model-b"},
|
||||||
|
{From: "model-c", To: "model-d"},
|
||||||
|
})
|
||||||
|
|
||||||
|
result := mapper.GetMappings()
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Errorf("Expected 2 mappings after update, got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update again should replace, not append
|
||||||
|
mapper.UpdateMappings([]config.AmpModelMapping{
|
||||||
|
{From: "model-x", To: "model-y"},
|
||||||
|
})
|
||||||
|
|
||||||
|
result = mapper.GetMappings()
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Errorf("Expected 1 mapping after second update, got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
|
||||||
|
mapper := NewModelMapper(nil)
|
||||||
|
|
||||||
|
mapper.UpdateMappings([]config.AmpModelMapping{
|
||||||
|
{From: "", To: "model-b"}, // Invalid: empty from
|
||||||
|
{From: "model-a", To: ""}, // Invalid: empty to
|
||||||
|
{From: " ", To: "model-b"}, // Invalid: whitespace from
|
||||||
|
{From: "model-c", To: "model-d"}, // Valid
|
||||||
|
})
|
||||||
|
|
||||||
|
result := mapper.GetMappings()
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Errorf("Expected 1 valid mapping, got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
|
||||||
|
mappings := []config.AmpModelMapping{
|
||||||
|
{From: "model-a", To: "model-b"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapper := NewModelMapper(mappings)
|
||||||
|
|
||||||
|
// Get mappings and modify the returned map
|
||||||
|
result := mapper.GetMappings()
|
||||||
|
result["new-key"] = "new-value"
|
||||||
|
|
||||||
|
// Original should be unchanged
|
||||||
|
original := mapper.GetMappings()
|
||||||
|
if len(original) != 1 {
|
||||||
|
t.Errorf("Expected original to have 1 mapping, got %d", len(original))
|
||||||
|
}
|
||||||
|
if _, exists := original["new-key"]; exists {
|
||||||
|
t.Error("Original map was modified")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -162,9 +162,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
|
|
||||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
||||||
// Uses lazy evaluation to access proxy (which is created after routes are registered)
|
// Uses lazy evaluation to access proxy (which is created after routes are registered)
|
||||||
fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy {
|
// Also includes model mapping support for routing unavailable models to alternatives
|
||||||
|
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||||
return m.proxy
|
return m.proxy
|
||||||
})
|
}, m.modelMapper)
|
||||||
|
|
||||||
// Provider-specific routes under /api/provider/:provider
|
// Provider-specific routes under /api/provider/:provider
|
||||||
ampProviders := engine.Group("/api/provider")
|
ampProviders := engine.Group("/api/provider")
|
||||||
|
|||||||
@@ -37,6 +37,12 @@ type Config struct {
|
|||||||
// browser attacks and remote access to management endpoints. Default: true (recommended).
|
// browser attacks and remote access to management endpoints. Default: true (recommended).
|
||||||
AmpRestrictManagementToLocalhost bool `yaml:"amp-restrict-management-to-localhost" json:"amp-restrict-management-to-localhost"`
|
AmpRestrictManagementToLocalhost bool `yaml:"amp-restrict-management-to-localhost" json:"amp-restrict-management-to-localhost"`
|
||||||
|
|
||||||
|
// AmpModelMappings defines model name mappings for Amp CLI requests.
|
||||||
|
// When Amp requests a model that isn't available locally, these mappings
|
||||||
|
// allow routing to an alternative model that IS available.
|
||||||
|
// Example: Map "claude-opus-4.5" -> "claude-sonnet-4" when opus isn't available.
|
||||||
|
AmpModelMappings []AmpModelMapping `yaml:"amp-model-mappings" json:"amp-model-mappings"`
|
||||||
|
|
||||||
// AuthDir is the directory where authentication token files are stored.
|
// AuthDir is the directory where authentication token files are stored.
|
||||||
AuthDir string `yaml:"auth-dir" json:"-"`
|
AuthDir string `yaml:"auth-dir" json:"-"`
|
||||||
|
|
||||||
@@ -115,6 +121,18 @@ type QuotaExceeded struct {
|
|||||||
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||||
|
// When Amp requests a model that isn't available locally, this mapping
|
||||||
|
// allows routing to an alternative model that IS available.
|
||||||
|
type AmpModelMapping struct {
|
||||||
|
// From is the model name that Amp CLI requests (e.g., "claude-opus-4.5").
|
||||||
|
From string `yaml:"from" json:"from"`
|
||||||
|
|
||||||
|
// To is the target model name to route to (e.g., "claude-sonnet-4").
|
||||||
|
// The target model must have available providers in the registry.
|
||||||
|
To string `yaml:"to" json:"to"`
|
||||||
|
}
|
||||||
|
|
||||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||||
type PayloadConfig struct {
|
type PayloadConfig struct {
|
||||||
// Default defines rules that only set parameters when they are missing in the payload.
|
// Default defines rules that only set parameters when they are missing in the payload.
|
||||||
|
|||||||
Reference in New Issue
Block a user