feat: add prioritize-model-mappings config option

Add a configuration option to control whether model mappings take
precedence over local API keys for Amp CLI requests.

- Add PrioritizeModelMappings field to AmpCode config struct
- When false (default): Local API keys take precedence (original behavior)
- When true: Model mappings take precedence over local API keys
- Add management API endpoints GET/PUT /prioritize-model-mappings

This allows users who want mapping priority to enable it explicitly
while preserving backward compatibility.

Config example:
  ampcode:
    model-mappings:
      - from: claude-opus-4-5-20251101
        to: gemini-claude-opus-4-5-thinking
    prioritize-model-mappings: true
This commit is contained in:
huynhgiabuu
2025-12-07 22:47:43 +07:00
parent 6cf1d8a947
commit afcab5efda
6 changed files with 90 additions and 29 deletions

View File

@@ -241,3 +241,11 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) {
h.cfg.ProxyURL = "" h.cfg.ProxyURL = ""
h.persist(c) h.persist(c)
} }
// Prioritize Model Mappings (for Amp CLI)
func (h *Handler) GetPrioritizeModelMappings(c *gin.Context) {
c.JSON(200, gin.H{"prioritize-model-mappings": h.cfg.AmpCode.PrioritizeModelMappings})
}
func (h *Handler) PutPrioritizeModelMappings(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.PrioritizeModelMappings = v })
}

View File

@@ -100,6 +100,16 @@ func (m *AmpModule) Name() string {
return "amp-routing" return "amp-routing"
} }
// getPrioritizeModelMappings returns whether model mappings should take precedence over local API keys
func (m *AmpModule) getPrioritizeModelMappings() bool {
m.configMu.RLock()
defer m.configMu.RUnlock()
if m.lastConfig == nil {
return false
}
return m.lastConfig.PrioritizeModelMappings
}
// Register sets up Amp routes if configured. // Register sets up Amp routes if configured.
// This implements the RouteModuleV2 interface with Context. // This implements the RouteModuleV2 interface with Context.
// Routes are registered only once via sync.Once for idempotent behavior. // Routes are registered only once via sync.Once for idempotent behavior.

View File

@@ -77,23 +77,29 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com // FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI // when the model's provider is not available in CLIProxyAPI
type FallbackHandler struct { type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper modelMapper ModelMapper
getPrioritizeModelMappings func() bool
} }
// NewFallbackHandler creates a new fallback handler wrapper // NewFallbackHandler creates a new fallback handler wrapper
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
return &FallbackHandler{ return &FallbackHandler{
getProxy: getProxy, getProxy: getProxy,
getPrioritizeModelMappings: func() bool { return false },
} }
} }
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler { func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, getPrioritize func() bool) *FallbackHandler {
if getPrioritize == nil {
getPrioritize = func() bool { return false }
}
return &FallbackHandler{ return &FallbackHandler{
getProxy: getProxy, getProxy: getProxy,
modelMapper: mapper, modelMapper: mapper,
getPrioritizeModelMappings: getPrioritize,
} }
} }
@@ -130,34 +136,65 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Normalize model (handles Gemini thinking suffixes) // Normalize model (handles Gemini thinking suffixes)
normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName) normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName)
// Check if we have providers for this model
providers := util.GetProviderName(normalizedModel)
// Track resolved model for logging (may change if mapping is applied) // Track resolved model for logging (may change if mapping is applied)
resolvedModel := normalizedModel resolvedModel := normalizedModel
usedMapping := false usedMapping := false
var providers []string
if len(providers) == 0 { // Check if model mappings should take priority over local API keys
// No providers configured - check if we have a model mapping prioritizeMappings := fh.getPrioritizeModelMappings != nil && fh.getPrioritizeModelMappings()
if prioritizeMappings {
// PRIORITY MODE: Check model mappings FIRST (takes precedence over local API keys)
// This allows users to route Amp requests to their preferred OAuth providers
if fh.modelMapper != nil { if fh.modelMapper != nil {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
// Mapping found - rewrite the model in request body // Mapping found - check if we have a provider for the mapped model
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) mappedProviders := util.GetProviderName(mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) if len(mappedProviders) > 0 {
// Store mapped model in context for handlers that check it (like gemini bridge) // Mapping found and provider available - rewrite the model in request body
c.Set(MappedModelContextKey, mappedModel) bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
resolvedModel = mappedModel c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
usedMapping = true // Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
// Get providers for the mapped model resolvedModel = mappedModel
providers = util.GetProviderName(mappedModel) usedMapping = true
providers = mappedProviders
// Continue to handler with remapped model }
goto handleRequest
} }
} }
// No mapping found - check if we have a proxy for fallback // If no mapping applied, check for local providers
if !usedMapping {
providers = util.GetProviderName(normalizedModel)
}
} else {
// DEFAULT MODE: Check local providers first, then mappings as fallback
providers = util.GetProviderName(normalizedModel)
if len(providers) == 0 {
// No providers configured - check if we have a model mapping
if fh.modelMapper != nil {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
// Mapping found - check if we have a provider for the mapped model
mappedProviders := util.GetProviderName(mappedModel)
if len(mappedProviders) > 0 {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
}
}
}
}
// If no providers available, fallback to ampcode.com
if len(providers) == 0 {
proxy := fh.getProxy() proxy := fh.getProxy()
if proxy != nil { if proxy != nil {
// Log: Forwarding to ampcode.com (uses Amp credits) // Log: Forwarding to ampcode.com (uses Amp credits)
@@ -175,8 +212,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
} }
handleRequest:
// Log the routing decision // Log the routing decision
providerName := "" providerName := ""
if len(providers) > 0 { if len(providers) > 0 {

View File

@@ -171,7 +171,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy() return m.getProxy()
}, m.modelMapper) }, m.modelMapper, m.getPrioritizeModelMappings)
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
// Route POST model calls through Gemini bridge with FallbackHandler. // Route POST model calls through Gemini bridge with FallbackHandler.
@@ -209,7 +209,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
// Also includes model mapping support for routing unavailable models to alternatives // Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy() return m.getProxy()
}, m.modelMapper) }, m.modelMapper, m.getPrioritizeModelMappings)
// Provider-specific routes under /api/provider/:provider // Provider-specific routes under /api/provider/:provider
ampProviders := engine.Group("/api/provider") ampProviders := engine.Group("/api/provider")

View File

@@ -520,6 +520,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
mgmt.GET("/prioritize-model-mappings", s.mgmt.GetPrioritizeModelMappings)
mgmt.PUT("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings)
mgmt.PATCH("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry) mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)

View File

@@ -143,6 +143,10 @@ type AmpCode struct {
// When Amp requests a model that isn't available locally, these mappings // When Amp requests a model that isn't available locally, these mappings
// allow routing to an alternative model that IS available. // allow routing to an alternative model that IS available.
ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"` ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"`
// PrioritizeModelMappings when true, model mappings take precedence over local API keys.
// When false (default), local API keys are used first if available.
PrioritizeModelMappings bool `yaml:"prioritize-model-mappings" json:"prioritize-model-mappings"`
} }
// PayloadConfig defines default and override parameter rules applied to provider payloads. // PayloadConfig defines default and override parameter rules applied to provider payloads.