Files
CLIProxyAPI/internal/api/modules/amp/model_mapping.go
이대희 9299897e04 Implements unified model routing
Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach.

This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings.
It provides a more flexible and maintainable routing mechanism by centralizing routing logic.

The changes include:
- Introducing new `routing` package with core routing logic.
- Creating characterization tests to capture existing behavior.
- Implementing model extraction and rewriting.
- Updating AMP module routes to utilize the new routing wrapper.
- Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
2026-02-01 16:58:32 +09:00

299 lines
9.2 KiB
Go

// Package amp provides model mapping functionality for routing Amp CLI requests
// to alternative models when the requested model is not available locally.
package amp
import (
"regexp"
"strings"
"sync"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// ModelMapper provides model name mapping/aliasing for Amp CLI requests.
// When an Amp request comes in for a model that isn't available locally,
// this mapper can redirect it to an alternative model that IS available.
type ModelMapper interface {
// MapModel returns the target model name if a mapping exists and the target
// model has available providers. Returns empty string if no mapping applies.
MapModel(requestedModel string) string
// UpdateMappings refreshes the mapping configuration (for hot-reload).
UpdateMappings(mappings []config.AmpModelMapping)
}
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
type DefaultModelMapper struct {
mu sync.RWMutex
mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order
// oauthAliasForward maps channel -> name (lower) -> []alias for oauth-model-alias lookup.
// This allows model-mappings targets to find providers via their aliases.
oauthAliasForward map[string]map[string][]string
}
// NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{
mappings: make(map[string]string),
regexps: nil,
oauthAliasForward: nil,
}
m.UpdateMappings(mappings)
return m
}
// UpdateOAuthModelAlias updates the oauth-model-alias lookup table.
// This is called during initialization and on config hot-reload.
func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.OAuthModelAlias) {
m.mu.Lock()
defer m.mu.Unlock()
if len(aliases) == 0 {
m.oauthAliasForward = nil
return
}
forward := make(map[string]map[string][]string, len(aliases))
for rawChannel, entries := range aliases {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(entries) == 0 {
continue
}
channelMap := make(map[string][]string)
for _, entry := range entries {
name := strings.TrimSpace(entry.Name)
alias := strings.TrimSpace(entry.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
nameKey := strings.ToLower(name)
channelMap[nameKey] = append(channelMap[nameKey], alias)
}
if len(channelMap) > 0 {
forward[channel] = channelMap
}
}
if len(forward) == 0 {
m.oauthAliasForward = nil
return
}
m.oauthAliasForward = forward
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
}
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
// that have available providers. Useful for fallback when one alias is quota-exceeded.
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
if m.oauthAliasForward == nil {
return nil
}
targetKey := strings.ToLower(strings.TrimSpace(targetModel))
if targetKey == "" {
return nil
}
var result []string
seen := make(map[string]struct{})
// Check all channels for this model name
for _, channelMap := range m.oauthAliasForward {
aliases := channelMap[targetKey]
for _, alias := range aliases {
aliasLower := strings.ToLower(alias)
if _, exists := seen[aliasLower]; exists {
continue
}
providers := util.GetProviderName(alias)
if len(providers) > 0 {
result = append(result, alias)
seen[aliasLower] = struct{}{}
}
}
}
return result
}
// MapModel checks if a mapping exists for the requested model and if the
// target model has available local providers. Returns the mapped model name
// or empty string if no valid mapping exists.
//
// If the requested model contains a thinking suffix (e.g., "g25p(8192)"),
// the suffix is preserved in the returned model name (e.g., "gemini-2.5-pro(8192)").
// However, if the mapping target already contains a suffix, the config suffix
// takes priority over the user's suffix.
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
models := m.MapModelWithFallbacks(requestedModel)
if len(models) == 0 {
return ""
}
return models[0]
}
// MapModelWithFallbacks returns all possible target models for the requested model,
// including fallback aliases from oauth-model-alias. The first model is the primary target,
// and subsequent models are fallbacks to try if the primary is unavailable (e.g., quota exceeded).
func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []string {
if requestedModel == "" {
return nil
}
m.mu.RLock()
defer m.mu.RUnlock()
// Extract thinking suffix from requested model using ParseSuffix
requestResult := thinking.ParseSuffix(requestedModel)
baseModel := requestResult.ModelName
// Normalize the base model for lookup (case-insensitive)
normalizedBase := strings.ToLower(strings.TrimSpace(baseModel))
// Check for direct mapping using base model name
targetModel, exists := m.mappings[normalizedBase]
if !exists {
// Try regex mappings in order using base model only
// (suffix is handled separately via ParseSuffix)
for _, rm := range m.regexps {
if rm.re.MatchString(baseModel) {
targetModel = rm.to
exists = true
break
}
}
if !exists {
return nil
}
}
// Check if target model already has a thinking suffix (config priority)
targetResult := thinking.ParseSuffix(targetModel)
targetBase := targetResult.ModelName
// Helper to apply suffix to a model
applySuffix := func(model string) string {
modelResult := thinking.ParseSuffix(model)
if modelResult.HasSuffix {
return model
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return model + "(" + requestResult.RawSuffix + ")"
}
return model
}
// Verify target model has available providers (use base model for lookup)
providers := util.GetProviderName(targetBase)
// If direct provider available, return it as primary
if len(providers) > 0 {
return []string{applySuffix(targetModel)}
}
// No direct providers - check oauth-model-alias for all aliases that have providers
allAliases := m.findAllAliasesWithProviders(targetBase)
if len(allAliases) == 0 {
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
return nil
}
// Log resolution
if len(allAliases) == 1 {
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
} else {
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1)
}
// Apply suffix to all aliases
result := make([]string, len(allAliases))
for i, alias := range allAliases {
result[i] = applySuffix(alias)
}
return result
}
// UpdateMappings refreshes the mapping configuration from config.
// This is called during initialization and on config hot-reload.
func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
m.mu.Lock()
defer m.mu.Unlock()
// Clear and rebuild mappings
m.mappings = make(map[string]string, len(mappings))
m.regexps = make([]regexMapping, 0, len(mappings))
for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if from == "" || to == "" {
log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to)
continue
}
if mapping.Regex {
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
pattern := "(?i)" + from
re, err := regexp.Compile(pattern)
if err != nil {
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
continue
}
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
} else {
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
}
}
if len(m.mappings) > 0 {
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
}
if n := len(m.regexps); n > 0 {
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
}
}
// GetMappings returns a copy of current mappings (for debugging/status).
func (m *DefaultModelMapper) GetMappings() map[string]string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]string, len(m.mappings))
for k, v := range m.mappings {
result[k] = v
}
return result
}
// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice.
// Safe for concurrent use.
func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]config.AmpModelMapping, 0, len(m.mappings))
for from, to := range m.mappings {
result = append(result, config.AmpModelMapping{
From: from,
To: to,
})
}
return result
}
type regexMapping struct {
re *regexp.Regexp
to string
}