mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
feature(ampcode): Improves AMP model mapping with alias support
Enhances the AMP model mapping functionality to support fallback mechanisms using `oauth-model-alias`. This change allows the system to attempt alternative models (aliases) if the primary mapped model fails due to issues like quota exhaustion. It updates the model mapper to load and utilize the `oauth-model-alias` configuration, enabling provider lookup via aliases. It also introduces context keys to pass fallback model names between handlers. Additionally, this change introduces a fix that prevents ReverseProxy's ErrAbortHandler panics from producing noisy stack traces by recovering and swallowing them in the fallback handler. Amp-Thread-ID: https://ampcode.com/threads/T-019c0cd1-9e59-722b-83f0-e0582aba6914 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -125,6 +125,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.registerOnce.Do(func() {
|
||||
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
// Load oauth-model-alias for provider lookup via aliases
|
||||
m.modelMapper.UpdateOAuthModelAlias(ctx.Config.OAuthModelAlias)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
@@ -212,6 +214,11 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Always update oauth-model-alias for model mapper (used for provider lookup)
|
||||
if m.modelMapper != nil {
|
||||
m.modelMapper.UpdateOAuthModelAlias(cfg.OAuthModelAlias)
|
||||
}
|
||||
|
||||
if m.enabled {
|
||||
// Check upstream URL change - now supports hot-reload
|
||||
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
||||
|
||||
@@ -2,7 +2,9 @@ package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -32,6 +34,10 @@ const (
|
||||
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||
const MappedModelContextKey = "mapped_model"
|
||||
|
||||
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
|
||||
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
|
||||
const FallbackModelsContextKey = "fallback_models"
|
||||
|
||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||
fields := log.Fields{
|
||||
@@ -113,6 +119,16 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||
return
|
||||
}
|
||||
panic(rec)
|
||||
}
|
||||
}()
|
||||
|
||||
requestPath := c.Request.URL.Path
|
||||
|
||||
// Read the request body to extract the model name
|
||||
@@ -142,36 +158,57 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
resolveMappedModel := func() (string, []string) {
|
||||
// resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
|
||||
resolveMappedModels := func() ([]string, []string) {
|
||||
if fh.modelMapper == nil {
|
||||
return "", nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
mappedModel = strings.TrimSpace(mappedModel)
|
||||
if mappedModel == "" {
|
||||
return "", nil
|
||||
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
|
||||
if !ok {
|
||||
// Fallback to single model for non-DefaultModelMapper
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
if mappedModel == "" {
|
||||
return nil, nil
|
||||
}
|
||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{mappedModel}, mappedProviders
|
||||
}
|
||||
|
||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||
// already specifies its own thinking suffix.
|
||||
if thinkingSuffix != "" {
|
||||
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
|
||||
if !mappedSuffixResult.HasSuffix {
|
||||
mappedModel += thinkingSuffix
|
||||
// Use MapModelWithFallbacks for DefaultModelMapper
|
||||
mappedModels := mapper.MapModelWithFallbacks(modelName)
|
||||
if len(mappedModels) == 0 {
|
||||
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
|
||||
}
|
||||
if len(mappedModels) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Apply thinking suffix if needed
|
||||
for i, model := range mappedModels {
|
||||
if thinkingSuffix != "" {
|
||||
suffixResult := thinking.ParseSuffix(model)
|
||||
if !suffixResult.HasSuffix {
|
||||
mappedModels[i] = model + thinkingSuffix
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return "", nil
|
||||
// Get providers for the first model
|
||||
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
|
||||
providers := util.GetProviderName(firstBaseModel)
|
||||
if len(providers) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return mappedModel, mappedProviders
|
||||
return mappedModels, providers
|
||||
}
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
@@ -185,13 +222,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
// Store mapped model and fallbacks in context for handlers
|
||||
c.Set(MappedModelContextKey, mappedModels[0])
|
||||
if len(mappedModels) > 1 {
|
||||
c.Set(FallbackModelsContextKey, mappedModels[1:])
|
||||
}
|
||||
resolvedModel = mappedModels[0]
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
@@ -206,13 +246,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
// Store mapped model and fallbacks in context for handlers
|
||||
c.Set(MappedModelContextKey, mappedModels[0])
|
||||
if len(mappedModels) > 1 {
|
||||
c.Set(FallbackModelsContextKey, mappedModels[1:])
|
||||
}
|
||||
resolvedModel = mappedModels[0]
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
|
||||
@@ -30,18 +30,112 @@ type DefaultModelMapper struct {
|
||||
mu sync.RWMutex
|
||||
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||
regexps []regexMapping // regex rules evaluated in order
|
||||
|
||||
// oauthAliasForward maps channel -> name (lower) -> []alias for oauth-model-alias lookup.
|
||||
// This allows model-mappings targets to find providers via their aliases.
|
||||
oauthAliasForward map[string]map[string][]string
|
||||
}
|
||||
|
||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
m := &DefaultModelMapper{
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
oauthAliasForward: nil,
|
||||
}
|
||||
m.UpdateMappings(mappings)
|
||||
return m
|
||||
}
|
||||
|
||||
// UpdateOAuthModelAlias updates the oauth-model-alias lookup table.
|
||||
// This is called during initialization and on config hot-reload.
|
||||
func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.OAuthModelAlias) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if len(aliases) == 0 {
|
||||
m.oauthAliasForward = nil
|
||||
return
|
||||
}
|
||||
|
||||
forward := make(map[string]map[string][]string, len(aliases))
|
||||
for rawChannel, entries := range aliases {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
channelMap := make(map[string][]string)
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
nameKey := strings.ToLower(name)
|
||||
channelMap[nameKey] = append(channelMap[nameKey], alias)
|
||||
}
|
||||
if len(channelMap) > 0 {
|
||||
forward[channel] = channelMap
|
||||
}
|
||||
}
|
||||
if len(forward) == 0 {
|
||||
m.oauthAliasForward = nil
|
||||
return
|
||||
}
|
||||
m.oauthAliasForward = forward
|
||||
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
|
||||
}
|
||||
|
||||
// findProviderViaOAuthAlias checks if targetModel is an oauth-model-alias name
|
||||
// and returns all aliases that have available providers.
|
||||
// Returns the first alias and its providers for backward compatibility,
|
||||
// and also populates allAliases with all available alias models.
|
||||
func (m *DefaultModelMapper) findProviderViaOAuthAlias(targetModel string) (aliasModel string, providers []string) {
|
||||
aliases := m.findAllAliasesWithProviders(targetModel)
|
||||
if len(aliases) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
// Return first one for backward compatibility
|
||||
first := aliases[0]
|
||||
return first, util.GetProviderName(first)
|
||||
}
|
||||
|
||||
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
|
||||
// that have available providers. Useful for fallback when one alias is quota-exceeded.
|
||||
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
|
||||
if m.oauthAliasForward == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetKey := strings.ToLower(strings.TrimSpace(targetModel))
|
||||
if targetKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
// Check all channels for this model name
|
||||
for _, channelMap := range m.oauthAliasForward {
|
||||
aliases := channelMap[targetKey]
|
||||
for _, alias := range aliases {
|
||||
aliasLower := strings.ToLower(alias)
|
||||
if _, exists := seen[aliasLower]; exists {
|
||||
continue
|
||||
}
|
||||
providers := util.GetProviderName(alias)
|
||||
if len(providers) > 0 {
|
||||
result = append(result, alias)
|
||||
seen[aliasLower] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MapModel checks if a mapping exists for the requested model and if the
|
||||
// target model has available local providers. Returns the mapped model name
|
||||
// or empty string if no valid mapping exists.
|
||||
@@ -51,9 +145,20 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
// However, if the mapping target already contains a suffix, the config suffix
|
||||
// takes priority over the user's suffix.
|
||||
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
if requestedModel == "" {
|
||||
models := m.MapModelWithFallbacks(requestedModel)
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return models[0]
|
||||
}
|
||||
|
||||
// MapModelWithFallbacks returns all possible target models for the requested model,
|
||||
// including fallback aliases from oauth-model-alias. The first model is the primary target,
|
||||
// and subsequent models are fallbacks to try if the primary is unavailable (e.g., quota exceeded).
|
||||
func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []string {
|
||||
if requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
@@ -78,34 +183,54 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if target model already has a thinking suffix (config priority)
|
||||
targetResult := thinking.ParseSuffix(targetModel)
|
||||
targetBase := targetResult.ModelName
|
||||
|
||||
// Helper to apply suffix to a model
|
||||
applySuffix := func(model string) string {
|
||||
modelResult := thinking.ParseSuffix(model)
|
||||
if modelResult.HasSuffix {
|
||||
return model
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return model + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// Verify target model has available providers (use base model for lookup)
|
||||
providers := util.GetProviderName(targetResult.ModelName)
|
||||
if len(providers) == 0 {
|
||||
providers := util.GetProviderName(targetBase)
|
||||
|
||||
// If direct provider available, return it as primary
|
||||
if len(providers) > 0 {
|
||||
return []string{applySuffix(targetModel)}
|
||||
}
|
||||
|
||||
// No direct providers - check oauth-model-alias for all aliases that have providers
|
||||
allAliases := m.findAllAliasesWithProviders(targetBase)
|
||||
if len(allAliases) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
|
||||
if targetResult.HasSuffix {
|
||||
// Config's "to" already contains a suffix - use it as-is (config priority)
|
||||
return targetModel
|
||||
// Log resolution
|
||||
if len(allAliases) == 1 {
|
||||
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
|
||||
} else {
|
||||
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases))
|
||||
}
|
||||
|
||||
// Preserve user's thinking suffix on the mapped model
|
||||
// (skip empty suffixes to avoid returning "model()")
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return targetModel + "(" + requestResult.RawSuffix + ")"
|
||||
// Apply suffix to all aliases
|
||||
result := make([]string, len(allAliases))
|
||||
for i, alias := range allAliases {
|
||||
result[i] = applySuffix(alias)
|
||||
}
|
||||
|
||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||
return targetModel
|
||||
return result
|
||||
}
|
||||
|
||||
// UpdateMappings refreshes the mapping configuration from config.
|
||||
|
||||
@@ -992,8 +992,8 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
}
|
||||
|
||||
// Notify Amp module only when Amp config has changed.
|
||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
|
||||
// Notify Amp module when Amp config or OAuth model aliases have changed.
|
||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) || !reflect.DeepEqual(oldCfg.OAuthModelAlias, cfg.OAuthModelAlias)
|
||||
if ampConfigChanged {
|
||||
if s.ampModule != nil {
|
||||
log.Debugf("triggering amp module config update")
|
||||
|
||||
@@ -562,192 +562,188 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeWithFallback(
|
||||
ctx context.Context,
|
||||
initialProviders []string,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
exec func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error,
|
||||
) error {
|
||||
routeModel := req.Model
|
||||
providers := initialProviders
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
|
||||
// Track fallback models from context (provided by Amp module fallback_models key)
|
||||
var fallbacks []string
|
||||
if v := ctx.Value("fallback_models"); v != nil {
|
||||
if fs, ok := v.([]string); ok {
|
||||
fallbacks = fs
|
||||
}
|
||||
}
|
||||
fallbackIdx := -1
|
||||
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
// No more auths for current model. Try next fallback model if available.
|
||||
if fallbackIdx+1 < len(fallbacks) {
|
||||
fallbackIdx++
|
||||
routeModel = fallbacks[fallbackIdx]
|
||||
log.Debugf("no more auths for current model, trying fallback model: %s (fallback %d/%d)", routeModel, fallbackIdx+1, len(fallbacks))
|
||||
|
||||
// Reset tried set for the new model and find its providers
|
||||
tried = make(map[string]struct{})
|
||||
providers = util.GetProviderName(thinking.ParseSuffix(routeModel).ModelName)
|
||||
// Reset opts for the new model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("fallback model %s has no providers, skipping", routeModel)
|
||||
continue // Try next fallback if this one has no providers
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return errPick
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
if err := exec(ctx, executor, auth, provider, routeModel); err != nil {
|
||||
if errCtx := ctx.Err(); errCtx != nil {
|
||||
return errCtx
|
||||
}
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedAttempt(
|
||||
ctx context.Context,
|
||||
auth *Auth,
|
||||
provider, routeModel string,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
exec func(ctx context.Context, execReq cliproxyexecutor.Request) error,
|
||||
) error {
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
|
||||
err := exec(execCtx, execReq)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: err == nil}
|
||||
if err != nil {
|
||||
result.Error = &Error{Message: err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(err, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(err); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
var resp cliproxyexecutor.Response
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
resp, errExec = executor.Execute(execCtx, auth, execReq, opts)
|
||||
return errExec
|
||||
})
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
var resp cliproxyexecutor.Response
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
resp, errExec = executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
return errExec
|
||||
})
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
var chunks <-chan cliproxyexecutor.StreamChunk
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
chunks, errExec = executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errExec != nil {
|
||||
return errExec
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
forward := true
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
forward := true
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
continue
|
||||
}
|
||||
if streamCtx == nil {
|
||||
out <- chunk
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
forward = false
|
||||
case out <- chunk:
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
continue
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
if streamCtx == nil {
|
||||
out <- chunk
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
forward = false
|
||||
case out <- chunk:
|
||||
}
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
chunks = out
|
||||
return nil
|
||||
})
|
||||
})
|
||||
return chunks, err
|
||||
}
|
||||
|
||||
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
||||
|
||||
Reference in New Issue
Block a user