diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index c18657c9..2a2ccb13 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -279,19 +279,26 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp return true } - // Build map for efficient comparison - oldMap := make(map[string]string, len(old.ModelMappings)) - for _, mapping := range old.ModelMappings { - oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To) - } + // Build map for efficient and robust comparison + type mappingInfo struct { + to string + regex bool + } + oldMap := make(map[string]mappingInfo, len(old.ModelMappings)) + for _, mapping := range old.ModelMappings { + oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{ + to: strings.TrimSpace(mapping.To), + regex: mapping.Regex, + } + } - for _, mapping := range new.ModelMappings { - from := strings.TrimSpace(mapping.From) - to := strings.TrimSpace(mapping.To) - if oldTo, exists := oldMap[from]; !exists || oldTo != to { - return true - } - } + for _, mapping := range new.ModelMappings { + from := strings.TrimSpace(mapping.From) + to := strings.TrimSpace(mapping.To) + if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex { + return true + } + } return false } diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index bc31c4e5..0741a52c 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -3,6 +3,7 @@ package amp import ( + "regexp" "strings" "sync" @@ -26,13 +27,15 @@ type ModelMapper interface { // DefaultModelMapper implements ModelMapper with thread-safe mapping storage. type DefaultModelMapper struct { mu sync.RWMutex - mappings map[string]string // from -> to (normalized lowercase keys) + mappings map[string]string // exact: from -> to (normalized lowercase keys) + regexps []regexMapping // regex rules evaluated in order } // NewModelMapper creates a new model mapper with the given initial mappings. func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { m := &DefaultModelMapper{ - mappings: make(map[string]string), + mappings: make(map[string]string), + regexps: nil, } m.UpdateMappings(mappings) return m @@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string { // Check for direct mapping targetModel, exists := m.mappings[normalizedRequest] if !exists { - return "" + // Try regex mappings in order + base, _ := util.NormalizeThinkingModel(requestedModel) + for _, rm := range m.regexps { + if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) { + targetModel = rm.to + exists = true + break + } + } + if !exists { + return "" + } } // Verify target model has available providers @@ -77,7 +91,8 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { defer m.mu.Unlock() // Clear and rebuild mappings - m.mappings = make(map[string]string, len(mappings)) + m.mappings = make(map[string]string, len(mappings)) + m.regexps = make([]regexMapping, 0, len(mappings)) for _, mapping := range mappings { from := strings.TrimSpace(mapping.From) @@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { continue } - // Store with normalized lowercase key for case-insensitive lookup - normalizedFrom := strings.ToLower(from) - m.mappings[normalizedFrom] = to - - log.Debugf("amp model mapping registered: %s -> %s", from, to) + if mapping.Regex { + // Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups + pattern := "(?i)" + from + re, err := regexp.Compile(pattern) + if err != nil { + log.Warnf("amp model mapping: invalid regex %q: %v", from, err) + continue + } + m.regexps = append(m.regexps, regexMapping{re: re, to: to}) + log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to) + } else { + // Store with normalized lowercase key for case-insensitive lookup + normalizedFrom := strings.ToLower(from) + m.mappings[normalizedFrom] = to + log.Debugf("amp model mapping registered: %s -> %s", from, to) + } } if len(m.mappings) > 0 { log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) } + if n := len(m.regexps); n > 0 { + log.Infof("amp model mapping: loaded %d regex mapping(s)", n) + } } // GetMappings returns a copy of current mappings (for debugging/status). @@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string { } return result } + +type regexMapping struct { + re *regexp.Regexp + to string +} diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go index 664a17c5..f4691448 100644 --- a/internal/api/modules/amp/model_mapping_test.go +++ b/internal/api/modules/amp/model_mapping_test.go @@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) { t.Error("Original map was modified") } } + +func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{ + {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"}, + }) + defer reg.UnregisterClient("test-client-regex-1") + + mappings := []config.AmpModelMapping{ + {From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true}, + } + + mapper := NewModelMapper(mappings) + + // Incoming model has reasoning suffix but should match base via regex + result := mapper.MapModel("gpt-5(high)") + if result != "gemini-2.5-pro" { + t.Errorf("Expected gemini-2.5-pro, got %s", result) + } +} + +func TestModelMapper_Regex_ExactPrecedence(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{ + {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"}, + }) + defer reg.UnregisterClient("test-client-regex-2") + defer reg.UnregisterClient("test-client-regex-3") + + mappings := []config.AmpModelMapping{ + {From: "gpt-5", To: "claude-sonnet-4"}, // exact + {From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex + } + + mapper := NewModelMapper(mappings) + + // Exact match should win over regex + result := mapper.MapModel("gpt-5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} + +func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) { + // Invalid regex should be skipped and not cause panic + mappings := []config.AmpModelMapping{ + {From: "(", To: "target", Regex: true}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("anything") + if result != "" { + t.Errorf("Expected empty result due to invalid regex, got %s", result) + } +} + +func TestModelMapper_Regex_CaseInsensitive(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{ + {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"}, + }) + defer reg.UnregisterClient("test-client-regex-4") + + mappings := []config.AmpModelMapping{ + {From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("claude-opus-4.5") + if result != "claude-sonnet-4" { + t.Errorf("Expected claude-sonnet-4, got %s", result) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 6bd74c03..9d0ad606 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -144,6 +144,11 @@ type AmpModelMapping struct { // To is the target model name to route to (e.g., "claude-sonnet-4"). // The target model must have available providers in the registry. To string `yaml:"to" json:"to"` + + // Regex indicates whether the 'from' field should be interpreted as a regular + // expression for matching model names. When true, this mapping is evaluated + // after exact matches and in the order provided. Defaults to false (exact match). + Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"` } // AmpCode groups Amp CLI integration settings including upstream routing,