feat: regex support for model-mappings

This commit is contained in:
altamash
2025-12-23 18:41:58 +05:30
parent e52b542e22
commit 5dcf7cb846
4 changed files with 149 additions and 20 deletions

View File

@@ -280,22 +280,34 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
} }
// Build map for efficient comparison // Build map for efficient comparison
oldMap := make(map[string]string, len(old.ModelMappings)) oldMap := make(map[string]string, len(old.ModelMappings))
for _, mapping := range old.ModelMappings { for _, mapping := range old.ModelMappings {
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To) from := strings.TrimSpace(mapping.From)
} to := strings.TrimSpace(mapping.To)
key := from
val := to + "|regex=" + boolTo01(mapping.Regex)
oldMap[key] = val
}
for _, mapping := range new.ModelMappings { for _, mapping := range new.ModelMappings {
from := strings.TrimSpace(mapping.From) from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To) to := strings.TrimSpace(mapping.To)
if oldTo, exists := oldMap[from]; !exists || oldTo != to { val := to + "|regex=" + boolTo01(mapping.Regex)
return true if oldVal, exists := oldMap[from]; !exists || oldVal != val {
} return true
} }
}
return false return false
} }
func boolTo01(b bool) string {
if b {
return "1"
}
return "0"
}
// hasAPIKeyChanged compares old and new API keys. // hasAPIKeyChanged compares old and new API keys.
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool { func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
oldKey := "" oldKey := ""

View File

@@ -3,6 +3,7 @@
package amp package amp
import ( import (
"regexp"
"strings" "strings"
"sync" "sync"
@@ -26,13 +27,15 @@ type ModelMapper interface {
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage. // DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
type DefaultModelMapper struct { type DefaultModelMapper struct {
mu sync.RWMutex mu sync.RWMutex
mappings map[string]string // from -> to (normalized lowercase keys) mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order
} }
// NewModelMapper creates a new model mapper with the given initial mappings. // NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{ m := &DefaultModelMapper{
mappings: make(map[string]string), mappings: make(map[string]string),
regexps: nil,
} }
m.UpdateMappings(mappings) m.UpdateMappings(mappings)
return m return m
@@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
// Check for direct mapping // Check for direct mapping
targetModel, exists := m.mappings[normalizedRequest] targetModel, exists := m.mappings[normalizedRequest]
if !exists { if !exists {
return "" // Try regex mappings in order
base, _ := util.NormalizeThinkingModel(requestedModel)
for _, rm := range m.regexps {
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
targetModel = rm.to
exists = true
break
}
}
if !exists {
return ""
}
} }
// Verify target model has available providers // Verify target model has available providers
@@ -77,7 +91,8 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
defer m.mu.Unlock() defer m.mu.Unlock()
// Clear and rebuild mappings // Clear and rebuild mappings
m.mappings = make(map[string]string, len(mappings)) m.mappings = make(map[string]string)
m.regexps = m.regexps[:0]
for _, mapping := range mappings { for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From) from := strings.TrimSpace(mapping.From)
@@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
continue continue
} }
// Store with normalized lowercase key for case-insensitive lookup if mapping.Regex {
normalizedFrom := strings.ToLower(from) // Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
m.mappings[normalizedFrom] = to pattern := "(?i)" + from
re, err := regexp.Compile(pattern)
log.Debugf("amp model mapping registered: %s -> %s", from, to) if err != nil {
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
continue
}
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
} else {
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
}
} }
if len(m.mappings) > 0 { if len(m.mappings) > 0 {
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
} }
if n := len(m.regexps); n > 0 {
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
}
} }
// GetMappings returns a copy of current mappings (for debugging/status). // GetMappings returns a copy of current mappings (for debugging/status).
@@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
} }
return result return result
} }
type regexMapping struct {
re *regexp.Regexp
to string
}

View File

@@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
t.Error("Original map was modified") t.Error("Original map was modified")
} }
} }
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-1")
mappings := []config.AmpModelMapping{
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
}
mapper := NewModelMapper(mappings)
// Incoming model has reasoning suffix but should match base via regex
result := mapper.MapModel("gpt-5(high)")
if result != "gemini-2.5-pro" {
t.Errorf("Expected gemini-2.5-pro, got %s", result)
}
}
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-2")
defer reg.UnregisterClient("test-client-regex-3")
mappings := []config.AmpModelMapping{
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
}
mapper := NewModelMapper(mappings)
// Exact match should win over regex
result := mapper.MapModel("gpt-5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
// Invalid regex should be skipped and not cause panic
mappings := []config.AmpModelMapping{
{From: "(", To: "target", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("anything")
if result != "" {
t.Errorf("Expected empty result due to invalid regex, got %s", result)
}
}
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-regex-4")
mappings := []config.AmpModelMapping{
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}

View File

@@ -144,6 +144,11 @@ type AmpModelMapping struct {
// To is the target model name to route to (e.g., "claude-sonnet-4"). // To is the target model name to route to (e.g., "claude-sonnet-4").
// The target model must have available providers in the registry. // The target model must have available providers in the registry.
To string `yaml:"to" json:"to"` To string `yaml:"to" json:"to"`
// Regex indicates whether the 'from' field should be interpreted as a regular
// expression for matching model names. When true, this mapping is evaluated
// after exact matches and in the order provided. Defaults to false (exact match).
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
} }
// AmpCode groups Amp CLI integration settings including upstream routing, // AmpCode groups Amp CLI integration settings including upstream routing,