mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
feature(ampcode): Improves AMP model mapping with alias support
Enhances the AMP model mapping functionality to support fallback mechanisms using . This change allows the system to attempt alternative models (aliases) if the primary mapped model fails due to issues like quota exhaustion. It updates the model mapper to load and utilize the configuration, enabling provider lookup via aliases. It also introduces context keys to pass fallback model names between handlers. Additionally, this change introduces a fix to prevent ReverseProxy from panicking by swallowing ErrAbortHandler panics. Amp-Thread-ID: https://ampcode.com/threads/T-019c0cd1-9e59-722b-83f0-e0582aba6914 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -125,6 +125,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
|||||||
m.registerOnce.Do(func() {
|
m.registerOnce.Do(func() {
|
||||||
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
||||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||||
|
// Load oauth-model-alias for provider lookup via aliases
|
||||||
|
m.modelMapper.UpdateOAuthModelAlias(ctx.Config.OAuthModelAlias)
|
||||||
|
|
||||||
// Store initial config for partial reload comparison
|
// Store initial config for partial reload comparison
|
||||||
settingsCopy := settings
|
settingsCopy := settings
|
||||||
@@ -212,6 +214,11 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always update oauth-model-alias for model mapper (used for provider lookup)
|
||||||
|
if m.modelMapper != nil {
|
||||||
|
m.modelMapper.UpdateOAuthModelAlias(cfg.OAuthModelAlias)
|
||||||
|
}
|
||||||
|
|
||||||
if m.enabled {
|
if m.enabled {
|
||||||
// Check upstream URL change - now supports hot-reload
|
// Check upstream URL change - now supports hot-reload
|
||||||
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package amp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -32,6 +34,10 @@ const (
|
|||||||
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||||
const MappedModelContextKey = "mapped_model"
|
const MappedModelContextKey = "mapped_model"
|
||||||
|
|
||||||
|
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
|
||||||
|
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
|
||||||
|
const FallbackModelsContextKey = "fallback_models"
|
||||||
|
|
||||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||||
fields := log.Fields{
|
fields := log.Fields{
|
||||||
@@ -113,6 +119,16 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
|||||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
panic(rec)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
requestPath := c.Request.URL.Path
|
requestPath := c.Request.URL.Path
|
||||||
|
|
||||||
// Read the request body to extract the model name
|
// Read the request body to extract the model name
|
||||||
@@ -142,36 +158,57 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||||
}
|
}
|
||||||
|
|
||||||
resolveMappedModel := func() (string, []string) {
|
// resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
|
||||||
|
resolveMappedModels := func() ([]string, []string) {
|
||||||
if fh.modelMapper == nil {
|
if fh.modelMapper == nil {
|
||||||
return "", nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
|
||||||
if mappedModel == "" {
|
if !ok {
|
||||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
// Fallback to single model for non-DefaultModelMapper
|
||||||
}
|
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||||
mappedModel = strings.TrimSpace(mappedModel)
|
if mappedModel == "" {
|
||||||
if mappedModel == "" {
|
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||||
return "", nil
|
}
|
||||||
|
if mappedModel == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||||
|
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||||
|
if len(mappedProviders) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return []string{mappedModel}, mappedProviders
|
||||||
}
|
}
|
||||||
|
|
||||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
// Use MapModelWithFallbacks for DefaultModelMapper
|
||||||
// already specifies its own thinking suffix.
|
mappedModels := mapper.MapModelWithFallbacks(modelName)
|
||||||
if thinkingSuffix != "" {
|
if len(mappedModels) == 0 {
|
||||||
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
|
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
|
||||||
if !mappedSuffixResult.HasSuffix {
|
}
|
||||||
mappedModel += thinkingSuffix
|
if len(mappedModels) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply thinking suffix if needed
|
||||||
|
for i, model := range mappedModels {
|
||||||
|
if thinkingSuffix != "" {
|
||||||
|
suffixResult := thinking.ParseSuffix(model)
|
||||||
|
if !suffixResult.HasSuffix {
|
||||||
|
mappedModels[i] = model + thinkingSuffix
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
// Get providers for the first model
|
||||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
|
||||||
if len(mappedProviders) == 0 {
|
providers := util.GetProviderName(firstBaseModel)
|
||||||
return "", nil
|
if len(providers) == 0 {
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return mappedModel, mappedProviders
|
return mappedModels, providers
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track resolved model for logging (may change if mapping is applied)
|
// Track resolved model for logging (may change if mapping is applied)
|
||||||
@@ -185,13 +222,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
if forceMappings {
|
if forceMappings {
|
||||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||||
// This allows users to route Amp requests to their preferred OAuth providers
|
// This allows users to route Amp requests to their preferred OAuth providers
|
||||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||||
// Mapping found and provider available - rewrite the model in request body
|
// Mapping found and provider available - rewrite the model in request body
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
// Store mapped model and fallbacks in context for handlers
|
||||||
c.Set(MappedModelContextKey, mappedModel)
|
c.Set(MappedModelContextKey, mappedModels[0])
|
||||||
resolvedModel = mappedModel
|
if len(mappedModels) > 1 {
|
||||||
|
c.Set(FallbackModelsContextKey, mappedModels[1:])
|
||||||
|
}
|
||||||
|
resolvedModel = mappedModels[0]
|
||||||
usedMapping = true
|
usedMapping = true
|
||||||
providers = mappedProviders
|
providers = mappedProviders
|
||||||
}
|
}
|
||||||
@@ -206,13 +246,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
|
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
// No providers configured - check if we have a model mapping
|
// No providers configured - check if we have a model mapping
|
||||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||||
// Mapping found and provider available - rewrite the model in request body
|
// Mapping found and provider available - rewrite the model in request body
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
// Store mapped model and fallbacks in context for handlers
|
||||||
c.Set(MappedModelContextKey, mappedModel)
|
c.Set(MappedModelContextKey, mappedModels[0])
|
||||||
resolvedModel = mappedModel
|
if len(mappedModels) > 1 {
|
||||||
|
c.Set(FallbackModelsContextKey, mappedModels[1:])
|
||||||
|
}
|
||||||
|
resolvedModel = mappedModels[0]
|
||||||
usedMapping = true
|
usedMapping = true
|
||||||
providers = mappedProviders
|
providers = mappedProviders
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,18 +30,112 @@ type DefaultModelMapper struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||||
regexps []regexMapping // regex rules evaluated in order
|
regexps []regexMapping // regex rules evaluated in order
|
||||||
|
|
||||||
|
// oauthAliasForward maps channel -> name (lower) -> []alias for oauth-model-alias lookup.
|
||||||
|
// This allows model-mappings targets to find providers via their aliases.
|
||||||
|
oauthAliasForward map[string]map[string][]string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||||
m := &DefaultModelMapper{
|
m := &DefaultModelMapper{
|
||||||
mappings: make(map[string]string),
|
mappings: make(map[string]string),
|
||||||
regexps: nil,
|
regexps: nil,
|
||||||
|
oauthAliasForward: nil,
|
||||||
}
|
}
|
||||||
m.UpdateMappings(mappings)
|
m.UpdateMappings(mappings)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateOAuthModelAlias updates the oauth-model-alias lookup table.
|
||||||
|
// This is called during initialization and on config hot-reload.
|
||||||
|
func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.OAuthModelAlias) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if len(aliases) == 0 {
|
||||||
|
m.oauthAliasForward = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
forward := make(map[string]map[string][]string, len(aliases))
|
||||||
|
for rawChannel, entries := range aliases {
|
||||||
|
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||||
|
if channel == "" || len(entries) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
channelMap := make(map[string][]string)
|
||||||
|
for _, entry := range entries {
|
||||||
|
name := strings.TrimSpace(entry.Name)
|
||||||
|
alias := strings.TrimSpace(entry.Alias)
|
||||||
|
if name == "" || alias == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(name, alias) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nameKey := strings.ToLower(name)
|
||||||
|
channelMap[nameKey] = append(channelMap[nameKey], alias)
|
||||||
|
}
|
||||||
|
if len(channelMap) > 0 {
|
||||||
|
forward[channel] = channelMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(forward) == 0 {
|
||||||
|
m.oauthAliasForward = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.oauthAliasForward = forward
|
||||||
|
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
|
||||||
|
}
|
||||||
|
|
||||||
|
// findProviderViaOAuthAlias checks if targetModel is an oauth-model-alias name
|
||||||
|
// and returns all aliases that have available providers.
|
||||||
|
// Returns the first alias and its providers for backward compatibility,
|
||||||
|
// and also populates allAliases with all available alias models.
|
||||||
|
func (m *DefaultModelMapper) findProviderViaOAuthAlias(targetModel string) (aliasModel string, providers []string) {
|
||||||
|
aliases := m.findAllAliasesWithProviders(targetModel)
|
||||||
|
if len(aliases) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
// Return first one for backward compatibility
|
||||||
|
first := aliases[0]
|
||||||
|
return first, util.GetProviderName(first)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
|
||||||
|
// that have available providers. Useful for fallback when one alias is quota-exceeded.
|
||||||
|
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
|
||||||
|
if m.oauthAliasForward == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
targetKey := strings.ToLower(strings.TrimSpace(targetModel))
|
||||||
|
if targetKey == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
|
||||||
|
// Check all channels for this model name
|
||||||
|
for _, channelMap := range m.oauthAliasForward {
|
||||||
|
aliases := channelMap[targetKey]
|
||||||
|
for _, alias := range aliases {
|
||||||
|
aliasLower := strings.ToLower(alias)
|
||||||
|
if _, exists := seen[aliasLower]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providers := util.GetProviderName(alias)
|
||||||
|
if len(providers) > 0 {
|
||||||
|
result = append(result, alias)
|
||||||
|
seen[aliasLower] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// MapModel checks if a mapping exists for the requested model and if the
|
// MapModel checks if a mapping exists for the requested model and if the
|
||||||
// target model has available local providers. Returns the mapped model name
|
// target model has available local providers. Returns the mapped model name
|
||||||
// or empty string if no valid mapping exists.
|
// or empty string if no valid mapping exists.
|
||||||
@@ -51,9 +145,20 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
|||||||
// However, if the mapping target already contains a suffix, the config suffix
|
// However, if the mapping target already contains a suffix, the config suffix
|
||||||
// takes priority over the user's suffix.
|
// takes priority over the user's suffix.
|
||||||
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||||
if requestedModel == "" {
|
models := m.MapModelWithFallbacks(requestedModel)
|
||||||
|
if len(models) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
return models[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapModelWithFallbacks returns all possible target models for the requested model,
|
||||||
|
// including fallback aliases from oauth-model-alias. The first model is the primary target,
|
||||||
|
// and subsequent models are fallbacks to try if the primary is unavailable (e.g., quota exceeded).
|
||||||
|
func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []string {
|
||||||
|
if requestedModel == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
@@ -78,34 +183,54 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !exists {
|
if !exists {
|
||||||
return ""
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if target model already has a thinking suffix (config priority)
|
// Check if target model already has a thinking suffix (config priority)
|
||||||
targetResult := thinking.ParseSuffix(targetModel)
|
targetResult := thinking.ParseSuffix(targetModel)
|
||||||
|
targetBase := targetResult.ModelName
|
||||||
|
|
||||||
|
// Helper to apply suffix to a model
|
||||||
|
applySuffix := func(model string) string {
|
||||||
|
modelResult := thinking.ParseSuffix(model)
|
||||||
|
if modelResult.HasSuffix {
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||||
|
return model + "(" + requestResult.RawSuffix + ")"
|
||||||
|
}
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
|
||||||
// Verify target model has available providers (use base model for lookup)
|
// Verify target model has available providers (use base model for lookup)
|
||||||
providers := util.GetProviderName(targetResult.ModelName)
|
providers := util.GetProviderName(targetBase)
|
||||||
if len(providers) == 0 {
|
|
||||||
|
// If direct provider available, return it as primary
|
||||||
|
if len(providers) > 0 {
|
||||||
|
return []string{applySuffix(targetModel)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No direct providers - check oauth-model-alias for all aliases that have providers
|
||||||
|
allAliases := m.findAllAliasesWithProviders(targetBase)
|
||||||
|
if len(allAliases) == 0 {
|
||||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||||
return ""
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
|
// Log resolution
|
||||||
if targetResult.HasSuffix {
|
if len(allAliases) == 1 {
|
||||||
// Config's "to" already contains a suffix - use it as-is (config priority)
|
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
|
||||||
return targetModel
|
} else {
|
||||||
|
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Preserve user's thinking suffix on the mapped model
|
// Apply suffix to all aliases
|
||||||
// (skip empty suffixes to avoid returning "model()")
|
result := make([]string, len(allAliases))
|
||||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
for i, alias := range allAliases {
|
||||||
return targetModel + "(" + requestResult.RawSuffix + ")"
|
result[i] = applySuffix(alias)
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
|
||||||
return targetModel
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateMappings refreshes the mapping configuration from config.
|
// UpdateMappings refreshes the mapping configuration from config.
|
||||||
|
|||||||
@@ -992,8 +992,8 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify Amp module only when Amp config has changed.
|
// Notify Amp module when Amp config or OAuth model aliases have changed.
|
||||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
|
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) || !reflect.DeepEqual(oldCfg.OAuthModelAlias, cfg.OAuthModelAlias)
|
||||||
if ampConfigChanged {
|
if ampConfigChanged {
|
||||||
if s.ampModule != nil {
|
if s.ampModule != nil {
|
||||||
log.Debugf("triggering amp module config update")
|
log.Debugf("triggering amp module config update")
|
||||||
|
|||||||
@@ -562,192 +562,188 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeWithFallback(
|
||||||
|
ctx context.Context,
|
||||||
|
initialProviders []string,
|
||||||
|
req cliproxyexecutor.Request,
|
||||||
|
opts cliproxyexecutor.Options,
|
||||||
|
exec func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error,
|
||||||
|
) error {
|
||||||
|
routeModel := req.Model
|
||||||
|
providers := initialProviders
|
||||||
|
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||||
|
tried := make(map[string]struct{})
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
// Track fallback models from context (provided by Amp module fallback_models key)
|
||||||
|
var fallbacks []string
|
||||||
|
if v := ctx.Value("fallback_models"); v != nil {
|
||||||
|
if fs, ok := v.([]string); ok {
|
||||||
|
fallbacks = fs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fallbackIdx := -1
|
||||||
|
|
||||||
|
for {
|
||||||
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||||
|
if errPick != nil {
|
||||||
|
// No more auths for current model. Try next fallback model if available.
|
||||||
|
if fallbackIdx+1 < len(fallbacks) {
|
||||||
|
fallbackIdx++
|
||||||
|
routeModel = fallbacks[fallbackIdx]
|
||||||
|
log.Debugf("no more auths for current model, trying fallback model: %s (fallback %d/%d)", routeModel, fallbackIdx+1, len(fallbacks))
|
||||||
|
|
||||||
|
// Reset tried set for the new model and find its providers
|
||||||
|
tried = make(map[string]struct{})
|
||||||
|
providers = util.GetProviderName(thinking.ParseSuffix(routeModel).ModelName)
|
||||||
|
// Reset opts for the new model
|
||||||
|
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||||
|
if len(providers) == 0 {
|
||||||
|
log.Debugf("fallback model %s has no providers, skipping", routeModel)
|
||||||
|
continue // Try next fallback if this one has no providers
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
return errPick
|
||||||
|
}
|
||||||
|
|
||||||
|
tried[auth.ID] = struct{}{}
|
||||||
|
if err := exec(ctx, executor, auth, provider, routeModel); err != nil {
|
||||||
|
if errCtx := ctx.Err(); errCtx != nil {
|
||||||
|
return errCtx
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeMixedAttempt(
|
||||||
|
ctx context.Context,
|
||||||
|
auth *Auth,
|
||||||
|
provider, routeModel string,
|
||||||
|
req cliproxyexecutor.Request,
|
||||||
|
opts cliproxyexecutor.Options,
|
||||||
|
exec func(ctx context.Context, execReq cliproxyexecutor.Request) error,
|
||||||
|
) error {
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
|
||||||
|
execCtx := ctx
|
||||||
|
if rt := m.roundTripperFor(auth); rt != nil {
|
||||||
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
execReq := req
|
||||||
|
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||||
|
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||||
|
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||||
|
|
||||||
|
err := exec(execCtx, execReq)
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: err == nil}
|
||||||
|
if err != nil {
|
||||||
|
result.Error = &Error{Message: err.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(err, &se) && se != nil {
|
||||||
|
result.Error.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
if ra := retryAfterFromError(err); ra != nil {
|
||||||
|
result.RetryAfter = ra
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
routeModel := req.Model
|
|
||||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
|
||||||
tried := make(map[string]struct{})
|
|
||||||
var lastErr error
|
|
||||||
for {
|
|
||||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
|
||||||
if errPick != nil {
|
|
||||||
if lastErr != nil {
|
|
||||||
return cliproxyexecutor.Response{}, lastErr
|
|
||||||
}
|
|
||||||
return cliproxyexecutor.Response{}, errPick
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := logEntryWithRequestID(ctx)
|
var resp cliproxyexecutor.Response
|
||||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||||
|
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||||
tried[auth.ID] = struct{}{}
|
var errExec error
|
||||||
execCtx := ctx
|
resp, errExec = executor.Execute(execCtx, auth, execReq, opts)
|
||||||
if rt := m.roundTripperFor(auth); rt != nil {
|
return errExec
|
||||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
})
|
||||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
})
|
||||||
}
|
return resp, err
|
||||||
execReq := req
|
|
||||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
|
||||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
|
||||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
|
||||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
|
||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
|
||||||
if errExec != nil {
|
|
||||||
if errCtx := execCtx.Err(); errCtx != nil {
|
|
||||||
return cliproxyexecutor.Response{}, errCtx
|
|
||||||
}
|
|
||||||
result.Error = &Error{Message: errExec.Error()}
|
|
||||||
var se cliproxyexecutor.StatusError
|
|
||||||
if errors.As(errExec, &se) && se != nil {
|
|
||||||
result.Error.HTTPStatus = se.StatusCode()
|
|
||||||
}
|
|
||||||
if ra := retryAfterFromError(errExec); ra != nil {
|
|
||||||
result.RetryAfter = ra
|
|
||||||
}
|
|
||||||
m.MarkResult(execCtx, result)
|
|
||||||
lastErr = errExec
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m.MarkResult(execCtx, result)
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
routeModel := req.Model
|
|
||||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
|
||||||
tried := make(map[string]struct{})
|
|
||||||
var lastErr error
|
|
||||||
for {
|
|
||||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
|
||||||
if errPick != nil {
|
|
||||||
if lastErr != nil {
|
|
||||||
return cliproxyexecutor.Response{}, lastErr
|
|
||||||
}
|
|
||||||
return cliproxyexecutor.Response{}, errPick
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := logEntryWithRequestID(ctx)
|
var resp cliproxyexecutor.Response
|
||||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||||
|
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||||
tried[auth.ID] = struct{}{}
|
var errExec error
|
||||||
execCtx := ctx
|
resp, errExec = executor.CountTokens(execCtx, auth, execReq, opts)
|
||||||
if rt := m.roundTripperFor(auth); rt != nil {
|
return errExec
|
||||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
})
|
||||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
})
|
||||||
}
|
return resp, err
|
||||||
execReq := req
|
|
||||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
|
||||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
|
||||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
|
||||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
|
||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
|
||||||
if errExec != nil {
|
|
||||||
if errCtx := execCtx.Err(); errCtx != nil {
|
|
||||||
return cliproxyexecutor.Response{}, errCtx
|
|
||||||
}
|
|
||||||
result.Error = &Error{Message: errExec.Error()}
|
|
||||||
var se cliproxyexecutor.StatusError
|
|
||||||
if errors.As(errExec, &se) && se != nil {
|
|
||||||
result.Error.HTTPStatus = se.StatusCode()
|
|
||||||
}
|
|
||||||
if ra := retryAfterFromError(errExec); ra != nil {
|
|
||||||
result.RetryAfter = ra
|
|
||||||
}
|
|
||||||
m.MarkResult(execCtx, result)
|
|
||||||
lastErr = errExec
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m.MarkResult(execCtx, result)
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
routeModel := req.Model
|
|
||||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
|
||||||
tried := make(map[string]struct{})
|
|
||||||
var lastErr error
|
|
||||||
for {
|
|
||||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
|
||||||
if errPick != nil {
|
|
||||||
if lastErr != nil {
|
|
||||||
return nil, lastErr
|
|
||||||
}
|
|
||||||
return nil, errPick
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := logEntryWithRequestID(ctx)
|
var chunks <-chan cliproxyexecutor.StreamChunk
|
||||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||||
|
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||||
|
var errExec error
|
||||||
|
chunks, errExec = executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||||
|
if errExec != nil {
|
||||||
|
return errExec
|
||||||
|
}
|
||||||
|
|
||||||
tried[auth.ID] = struct{}{}
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
execCtx := ctx
|
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||||
if rt := m.roundTripperFor(auth); rt != nil {
|
defer close(out)
|
||||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
var failed bool
|
||||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
forward := true
|
||||||
}
|
for chunk := range streamChunks {
|
||||||
execReq := req
|
if chunk.Err != nil && !failed {
|
||||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
failed = true
|
||||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
rerr := &Error{Message: chunk.Err.Error()}
|
||||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
var se cliproxyexecutor.StatusError
|
||||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
if errors.As(chunk.Err, &se) && se != nil {
|
||||||
if errStream != nil {
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
if errCtx := execCtx.Err(); errCtx != nil {
|
}
|
||||||
return nil, errCtx
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||||
}
|
}
|
||||||
rerr := &Error{Message: errStream.Error()}
|
if !forward {
|
||||||
var se cliproxyexecutor.StatusError
|
continue
|
||||||
if errors.As(errStream, &se) && se != nil {
|
}
|
||||||
rerr.HTTPStatus = se.StatusCode()
|
if streamCtx == nil {
|
||||||
}
|
out <- chunk
|
||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
continue
|
||||||
result.RetryAfter = retryAfterFromError(errStream)
|
}
|
||||||
m.MarkResult(execCtx, result)
|
select {
|
||||||
lastErr = errStream
|
case <-streamCtx.Done():
|
||||||
continue
|
forward = false
|
||||||
}
|
case out <- chunk:
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
|
||||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
|
||||||
defer close(out)
|
|
||||||
var failed bool
|
|
||||||
forward := true
|
|
||||||
for chunk := range streamChunks {
|
|
||||||
if chunk.Err != nil && !failed {
|
|
||||||
failed = true
|
|
||||||
rerr := &Error{Message: chunk.Err.Error()}
|
|
||||||
var se cliproxyexecutor.StatusError
|
|
||||||
if errors.As(chunk.Err, &se) && se != nil {
|
|
||||||
rerr.HTTPStatus = se.StatusCode()
|
|
||||||
}
|
}
|
||||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
|
||||||
}
|
}
|
||||||
if !forward {
|
if !failed {
|
||||||
continue
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||||
}
|
}
|
||||||
if streamCtx == nil {
|
}(execCtx, auth.Clone(), provider, chunks)
|
||||||
out <- chunk
|
chunks = out
|
||||||
continue
|
return nil
|
||||||
}
|
})
|
||||||
select {
|
})
|
||||||
case <-streamCtx.Done():
|
return chunks, err
|
||||||
forward = false
|
|
||||||
case out <- chunk:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !failed {
|
|
||||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
|
||||||
}
|
|
||||||
}(execCtx, auth.Clone(), provider, chunks)
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
||||||
|
|||||||
Reference in New Issue
Block a user