feature(ampcode): Improves AMP model mapping with alias support

Enhances the AMP model mapping functionality to support fallback mechanisms using `oauth-model-alias`.

This change allows the system to attempt alternative models (aliases) if the primary mapped model fails due to issues like quota exhaustion. It updates the model mapper to load and utilize the `oauth-model-alias` configuration, enabling provider lookup via aliases. It also introduces context keys to pass fallback model names between handlers.

Additionally, this change introduces a fix to prevent ReverseProxy from panicking by swallowing ErrAbortHandler panics.

Amp-Thread-ID: https://ampcode.com/threads/T-019c0cd1-9e59-722b-83f0-e0582aba6914
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
이대희
2026-01-30 12:50:53 +09:00
committed by hkfires
parent 2854e04bbb
commit 09044e8ccc
5 changed files with 384 additions and 213 deletions

View File

@@ -125,6 +125,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
m.registerOnce.Do(func() { m.registerOnce.Do(func() {
// Initialize model mapper from config (for routing unavailable models to alternatives) // Initialize model mapper from config (for routing unavailable models to alternatives)
m.modelMapper = NewModelMapper(settings.ModelMappings) m.modelMapper = NewModelMapper(settings.ModelMappings)
// Load oauth-model-alias for provider lookup via aliases
m.modelMapper.UpdateOAuthModelAlias(ctx.Config.OAuthModelAlias)
// Store initial config for partial reload comparison // Store initial config for partial reload comparison
settingsCopy := settings settingsCopy := settings
@@ -212,6 +214,11 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
} }
} }
// Always update oauth-model-alias for model mapper (used for provider lookup)
if m.modelMapper != nil {
m.modelMapper.UpdateOAuthModelAlias(cfg.OAuthModelAlias)
}
if m.enabled { if m.enabled {
// Check upstream URL change - now supports hot-reload // Check upstream URL change - now supports hot-reload
if newUpstreamURL == "" && oldUpstreamURL != "" { if newUpstreamURL == "" && oldUpstreamURL != "" {

View File

@@ -2,7 +2,9 @@ package amp
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"net/http"
"net/http/httputil" "net/http/httputil"
"strings" "strings"
"time" "time"
@@ -32,6 +34,10 @@ const (
// MappedModelContextKey is the Gin context key for passing mapped model names. // MappedModelContextKey is the Gin context key for passing mapped model names.
const MappedModelContextKey = "mapped_model" const MappedModelContextKey = "mapped_model"
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
const FallbackModelsContextKey = "fallback_models"
// logAmpRouting logs the routing decision for an Amp request with structured fields // logAmpRouting logs the routing decision for an Amp request with structured fields
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
fields := log.Fields{ fields := log.Fields{
@@ -113,6 +119,16 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com // If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
defer func() {
if rec := recover(); rec != nil {
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
return
}
panic(rec)
}
}()
requestPath := c.Request.URL.Path requestPath := c.Request.URL.Path
// Read the request body to extract the model name // Read the request body to extract the model name
@@ -142,36 +158,57 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
thinkingSuffix = "(" + suffixResult.RawSuffix + ")" thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
} }
resolveMappedModel := func() (string, []string) { // resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
resolveMappedModels := func() ([]string, []string) {
if fh.modelMapper == nil { if fh.modelMapper == nil {
return "", nil return nil, nil
} }
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
if !ok {
// Fallback to single model for non-DefaultModelMapper
mappedModel := fh.modelMapper.MapModel(modelName) mappedModel := fh.modelMapper.MapModel(modelName)
if mappedModel == "" { if mappedModel == "" {
mappedModel = fh.modelMapper.MapModel(normalizedModel) mappedModel = fh.modelMapper.MapModel(normalizedModel)
} }
mappedModel = strings.TrimSpace(mappedModel)
if mappedModel == "" { if mappedModel == "" {
return "", nil return nil, nil
} }
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
// already specifies its own thinking suffix.
if thinkingSuffix != "" {
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
if !mappedSuffixResult.HasSuffix {
mappedModel += thinkingSuffix
}
}
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
mappedProviders := util.GetProviderName(mappedBaseModel) mappedProviders := util.GetProviderName(mappedBaseModel)
if len(mappedProviders) == 0 { if len(mappedProviders) == 0 {
return "", nil return nil, nil
}
return []string{mappedModel}, mappedProviders
} }
return mappedModel, mappedProviders // Use MapModelWithFallbacks for DefaultModelMapper
mappedModels := mapper.MapModelWithFallbacks(modelName)
if len(mappedModels) == 0 {
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
}
if len(mappedModels) == 0 {
return nil, nil
}
// Apply thinking suffix if needed
for i, model := range mappedModels {
if thinkingSuffix != "" {
suffixResult := thinking.ParseSuffix(model)
if !suffixResult.HasSuffix {
mappedModels[i] = model + thinkingSuffix
}
}
}
// Get providers for the first model
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
providers := util.GetProviderName(firstBaseModel)
if len(providers) == 0 {
return nil, nil
}
return mappedModels, providers
} }
// Track resolved model for logging (may change if mapping is applied) // Track resolved model for logging (may change if mapping is applied)
@@ -185,13 +222,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
if forceMappings { if forceMappings {
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
// This allows users to route Amp requests to their preferred OAuth providers // This allows users to route Amp requests to their preferred OAuth providers
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
// Mapping found and provider available - rewrite the model in request body // Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge) // Store mapped model and fallbacks in context for handlers
c.Set(MappedModelContextKey, mappedModel) c.Set(MappedModelContextKey, mappedModels[0])
resolvedModel = mappedModel if len(mappedModels) > 1 {
c.Set(FallbackModelsContextKey, mappedModels[1:])
}
resolvedModel = mappedModels[0]
usedMapping = true usedMapping = true
providers = mappedProviders providers = mappedProviders
} }
@@ -206,13 +246,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
if len(providers) == 0 { if len(providers) == 0 {
// No providers configured - check if we have a model mapping // No providers configured - check if we have a model mapping
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
// Mapping found and provider available - rewrite the model in request body // Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge) // Store mapped model and fallbacks in context for handlers
c.Set(MappedModelContextKey, mappedModel) c.Set(MappedModelContextKey, mappedModels[0])
resolvedModel = mappedModel if len(mappedModels) > 1 {
c.Set(FallbackModelsContextKey, mappedModels[1:])
}
resolvedModel = mappedModels[0]
usedMapping = true usedMapping = true
providers = mappedProviders providers = mappedProviders
} }

View File

@@ -30,6 +30,10 @@ type DefaultModelMapper struct {
mu sync.RWMutex mu sync.RWMutex
mappings map[string]string // exact: from -> to (normalized lowercase keys) mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order regexps []regexMapping // regex rules evaluated in order
// oauthAliasForward maps channel -> name (lower) -> []alias for oauth-model-alias lookup.
// This allows model-mappings targets to find providers via their aliases.
oauthAliasForward map[string]map[string][]string
} }
// NewModelMapper creates a new model mapper with the given initial mappings. // NewModelMapper creates a new model mapper with the given initial mappings.
@@ -37,11 +41,101 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{ m := &DefaultModelMapper{
mappings: make(map[string]string), mappings: make(map[string]string),
regexps: nil, regexps: nil,
oauthAliasForward: nil,
} }
m.UpdateMappings(mappings) m.UpdateMappings(mappings)
return m return m
} }
// UpdateOAuthModelAlias rebuilds the oauth-model-alias lookup table from aliases.
//
// The input is keyed by channel; each entry pairs a model Name with an Alias.
// The resulting table maps lower-cased channel -> lower-cased name -> aliases
// (aliases keep their original casing). Entries are dropped when the channel,
// name, or alias is blank after trimming, or when name and alias are equal
// ignoring case. If nothing usable remains the table is cleared to nil.
//
// The table is replaced while holding the write lock on m.mu.
func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.OAuthModelAlias) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Allocated lazily so an input with no usable entries leaves this nil.
	var table map[string]map[string][]string

	for rawChannel, entries := range aliases {
		channel := strings.ToLower(strings.TrimSpace(rawChannel))
		if channel == "" {
			continue
		}

		var byName map[string][]string
		for _, entry := range entries {
			name := strings.TrimSpace(entry.Name)
			alias := strings.TrimSpace(entry.Alias)
			// Skip incomplete pairs and self-aliases (name == alias, case-insensitively).
			if name == "" || alias == "" || strings.EqualFold(name, alias) {
				continue
			}
			if byName == nil {
				byName = make(map[string][]string)
			}
			key := strings.ToLower(name)
			byName[key] = append(byName[key], alias)
		}

		// Only channels that contributed at least one alias are recorded.
		if len(byName) == 0 {
			continue
		}
		if table == nil {
			table = make(map[string]map[string][]string, len(aliases))
		}
		table[channel] = byName
	}

	if len(table) == 0 {
		m.oauthAliasForward = nil
		return
	}

	m.oauthAliasForward = table
	log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(table))
}
// findProviderViaOAuthAlias resolves targetModel through the oauth-model-alias
// table and returns the first alias that has available providers, together
// with the providers util.GetProviderName reports for that alias.
//
// It returns ("", nil) when targetModel has no alias with providers. Only the
// first candidate from findAllAliasesWithProviders is returned; call that
// method directly to obtain the full list of usable aliases.
func (m *DefaultModelMapper) findProviderViaOAuthAlias(targetModel string) (aliasModel string, providers []string) {
	candidates := m.findAllAliasesWithProviders(targetModel)
	if len(candidates) > 0 {
		// The single-result contract keeps only the head of the candidate list.
		aliasModel = candidates[0]
		providers = util.GetProviderName(aliasModel)
	}
	return aliasModel, providers
}
// findAllAliasesWithProviders returns every oauth-model-alias alias registered
// for targetModel for which util.GetProviderName reports at least one provider.
// targetModel is matched case-insensitively after trimming, across all channels.
// Aliases are de-duplicated case-insensitively; the first spelling encountered
// is the one returned. The result is nil when the alias table is empty,
// targetModel is blank, or no alias has providers.
//
// Channels are visited in sorted name order so that the result order, and
// therefore which alias callers treat as the primary target versus a fallback,
// is stable across calls. Ranging over the map directly would yield a random
// order each time.
//
// This method reads m.oauthAliasForward without taking a lock; callers are
// expected to hold m.mu (MapModelWithFallbacks calls it under the read lock).
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
	if m.oauthAliasForward == nil {
		return nil
	}

	targetKey := strings.ToLower(strings.TrimSpace(targetModel))
	if targetKey == "" {
		return nil
	}

	// Collect the channels that define aliases for this model, kept sorted by
	// name via insertion. The channel count is small, and this avoids needing
	// an import the file may not have.
	channels := make([]string, 0, len(m.oauthAliasForward))
	for channel, channelMap := range m.oauthAliasForward {
		if len(channelMap[targetKey]) == 0 {
			continue
		}
		pos := len(channels)
		channels = append(channels, channel)
		for pos > 0 && channels[pos-1] > channel {
			channels[pos] = channels[pos-1]
			pos--
		}
		channels[pos] = channel
	}

	var result []string
	seen := make(map[string]struct{})

	for _, channel := range channels {
		for _, alias := range m.oauthAliasForward[channel][targetKey] {
			aliasLower := strings.ToLower(alias)
			if _, exists := seen[aliasLower]; exists {
				continue
			}
			// Only aliases that currently resolve to a provider are usable;
			// an alias is marked seen only once it has been accepted.
			if len(util.GetProviderName(alias)) == 0 {
				continue
			}
			result = append(result, alias)
			seen[aliasLower] = struct{}{}
		}
	}

	return result
}
// MapModel checks if a mapping exists for the requested model and if the // MapModel checks if a mapping exists for the requested model and if the
// target model has available local providers. Returns the mapped model name // target model has available local providers. Returns the mapped model name
// or empty string if no valid mapping exists. // or empty string if no valid mapping exists.
@@ -51,9 +145,20 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
// However, if the mapping target already contains a suffix, the config suffix // However, if the mapping target already contains a suffix, the config suffix
// takes priority over the user's suffix. // takes priority over the user's suffix.
func (m *DefaultModelMapper) MapModel(requestedModel string) string { func (m *DefaultModelMapper) MapModel(requestedModel string) string {
if requestedModel == "" { models := m.MapModelWithFallbacks(requestedModel)
if len(models) == 0 {
return "" return ""
} }
return models[0]
}
// MapModelWithFallbacks returns all possible target models for the requested model,
// including fallback aliases from oauth-model-alias. The first model is the primary target,
// and subsequent models are fallbacks to try if the primary is unavailable (e.g., quota exceeded).
func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []string {
if requestedModel == "" {
return nil
}
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
@@ -78,34 +183,54 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
} }
} }
if !exists { if !exists {
return "" return nil
} }
} }
// Check if target model already has a thinking suffix (config priority) // Check if target model already has a thinking suffix (config priority)
targetResult := thinking.ParseSuffix(targetModel) targetResult := thinking.ParseSuffix(targetModel)
targetBase := targetResult.ModelName
// Helper to apply suffix to a model
applySuffix := func(model string) string {
modelResult := thinking.ParseSuffix(model)
if modelResult.HasSuffix {
return model
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return model + "(" + requestResult.RawSuffix + ")"
}
return model
}
// Verify target model has available providers (use base model for lookup) // Verify target model has available providers (use base model for lookup)
providers := util.GetProviderName(targetResult.ModelName) providers := util.GetProviderName(targetBase)
if len(providers) == 0 {
// If direct provider available, return it as primary
if len(providers) > 0 {
return []string{applySuffix(targetModel)}
}
// No direct providers - check oauth-model-alias for all aliases that have providers
allAliases := m.findAllAliasesWithProviders(targetBase)
if len(allAliases) == 0 {
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
return "" return nil
} }
// Suffix handling: config suffix takes priority, otherwise preserve user suffix // Log resolution
if targetResult.HasSuffix { if len(allAliases) == 1 {
// Config's "to" already contains a suffix - use it as-is (config priority) log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
return targetModel } else {
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases))
} }
// Preserve user's thinking suffix on the mapped model // Apply suffix to all aliases
// (skip empty suffixes to avoid returning "model()") result := make([]string, len(allAliases))
if requestResult.HasSuffix && requestResult.RawSuffix != "" { for i, alias := range allAliases {
return targetModel + "(" + requestResult.RawSuffix + ")" result[i] = applySuffix(alias)
} }
return result
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
return targetModel
} }
// UpdateMappings refreshes the mapping configuration from config. // UpdateMappings refreshes the mapping configuration from config.

View File

@@ -992,8 +992,8 @@ func (s *Server) UpdateClients(cfg *config.Config) {
s.mgmt.SetAuthManager(s.handlers.AuthManager) s.mgmt.SetAuthManager(s.handlers.AuthManager)
} }
// Notify Amp module only when Amp config has changed. // Notify Amp module when Amp config or OAuth model aliases have changed.
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) || !reflect.DeepEqual(oldCfg.OAuthModelAlias, cfg.OAuthModelAlias)
if ampConfigChanged { if ampConfigChanged {
if s.ampModule != nil { if s.ampModule != nil {
log.Debugf("triggering amp module config update") log.Debugf("triggering amp module config update")

View File

@@ -562,158 +562,151 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "auth_not_found", Message: "no auth available"} return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
} }
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (m *Manager) executeWithFallback(
if len(providers) == 0 { ctx context.Context,
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} initialProviders []string,
} req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
exec func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error,
) error {
routeModel := req.Model routeModel := req.Model
providers := initialProviders
opts = ensureRequestedModelMetadata(opts, routeModel) opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{}) tried := make(map[string]struct{})
var lastErr error var lastErr error
// Track fallback models from context (provided by Amp module fallback_models key)
var fallbacks []string
if v := ctx.Value("fallback_models"); v != nil {
if fs, ok := v.([]string); ok {
fallbacks = fs
}
}
fallbackIdx := -1
for { for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil { if errPick != nil {
if lastErr != nil { // No more auths for current model. Try next fallback model if available.
return cliproxyexecutor.Response{}, lastErr if fallbackIdx+1 < len(fallbacks) {
fallbackIdx++
routeModel = fallbacks[fallbackIdx]
log.Debugf("no more auths for current model, trying fallback model: %s (fallback %d/%d)", routeModel, fallbackIdx+1, len(fallbacks))
// Reset tried set for the new model and find its providers
tried = make(map[string]struct{})
providers = util.GetProviderName(thinking.ParseSuffix(routeModel).ModelName)
// Reset opts for the new model
opts = ensureRequestedModelMetadata(opts, routeModel)
if len(providers) == 0 {
log.Debugf("fallback model %s has no providers, skipping", routeModel)
continue // Try next fallback if this one has no providers
} }
return cliproxyexecutor.Response{}, errPick continue
} }
if lastErr != nil {
return lastErr
}
return errPick
}
tried[auth.ID] = struct{}{}
if err := exec(ctx, executor, auth, provider, routeModel); err != nil {
if errCtx := ctx.Err(); errCtx != nil {
return errCtx
}
lastErr = err
continue
}
return nil
}
}
func (m *Manager) executeMixedAttempt(
ctx context.Context,
auth *Auth,
provider, routeModel string,
req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
exec func(ctx context.Context, execReq cliproxyexecutor.Request) error,
) error {
entry := logEntryWithRequestID(ctx) entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model) debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil { if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
} }
execReq := req execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth) execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} err := exec(execCtx, execReq)
if errExec != nil { result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: err == nil}
if errCtx := execCtx.Err(); errCtx != nil { if err != nil {
return cliproxyexecutor.Response{}, errCtx result.Error = &Error{Message: err.Error()}
}
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil { if errors.As(err, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode() result.Error.HTTPStatus = se.StatusCode()
} }
if ra := retryAfterFromError(errExec); ra != nil { if ra := retryAfterFromError(err); ra != nil {
result.RetryAfter = ra result.RetryAfter = ra
} }
m.MarkResult(execCtx, result)
lastErr = errExec
continue
} }
m.MarkResult(execCtx, result) m.MarkResult(execCtx, result)
return resp, nil return err
}
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
var resp cliproxyexecutor.Response
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
var errExec error
resp, errExec = executor.Execute(execCtx, auth, execReq, opts)
return errExec
})
})
return resp, err
} }
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 { if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx) var resp cliproxyexecutor.Response
debugLogAuthSelection(entry, auth, provider, req.Model) err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
tried[auth.ID] = struct{}{} var errExec error
execCtx := ctx resp, errExec = executor.CountTokens(execCtx, auth, execReq, opts)
if rt := m.roundTripperFor(auth); rt != nil { return errExec
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) })
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) })
} return resp, err
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
} }
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
if len(providers) == 0 { if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel) var chunks <-chan cliproxyexecutor.StreamChunk
tried := make(map[string]struct{}) err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
var lastErr error return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
for { var errExec error
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) chunks, errExec = executor.ExecuteStream(execCtx, auth, execReq, opts)
if errPick != nil { if errExec != nil {
if lastErr != nil { return errExec
return nil, lastErr
}
return nil, errPick
} }
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errStream, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk) out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) { go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out) defer close(out)
@@ -746,8 +739,11 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
} }
}(execCtx, auth.Clone(), provider, chunks) }(execCtx, auth.Clone(), provider, chunks)
return out, nil chunks = out
} return nil
})
})
return chunks, err
} }
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options { func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {