mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 12:20:52 +08:00
feature(ampcode): Improves AMP model mapping with alias support
Enhances the AMP model mapping functionality to support fallback mechanisms using . This change allows the system to attempt alternative models (aliases) if the primary mapped model fails due to issues like quota exhaustion. It updates the model mapper to load and utilize the configuration, enabling provider lookup via aliases. It also introduces context keys to pass fallback model names between handlers. Additionally, this change introduces a fix to prevent ReverseProxy from panicking by swallowing ErrAbortHandler panics. Amp-Thread-ID: https://ampcode.com/threads/T-019c0cd1-9e59-722b-83f0-e0582aba6914 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -562,192 +562,188 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeWithFallback(
|
||||
ctx context.Context,
|
||||
initialProviders []string,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
exec func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error,
|
||||
) error {
|
||||
routeModel := req.Model
|
||||
providers := initialProviders
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
|
||||
// Track fallback models from context (provided by Amp module fallback_models key)
|
||||
var fallbacks []string
|
||||
if v := ctx.Value("fallback_models"); v != nil {
|
||||
if fs, ok := v.([]string); ok {
|
||||
fallbacks = fs
|
||||
}
|
||||
}
|
||||
fallbackIdx := -1
|
||||
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
// No more auths for current model. Try next fallback model if available.
|
||||
if fallbackIdx+1 < len(fallbacks) {
|
||||
fallbackIdx++
|
||||
routeModel = fallbacks[fallbackIdx]
|
||||
log.Debugf("no more auths for current model, trying fallback model: %s (fallback %d/%d)", routeModel, fallbackIdx+1, len(fallbacks))
|
||||
|
||||
// Reset tried set for the new model and find its providers
|
||||
tried = make(map[string]struct{})
|
||||
providers = util.GetProviderName(thinking.ParseSuffix(routeModel).ModelName)
|
||||
// Reset opts for the new model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("fallback model %s has no providers, skipping", routeModel)
|
||||
continue // Try next fallback if this one has no providers
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return errPick
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
if err := exec(ctx, executor, auth, provider, routeModel); err != nil {
|
||||
if errCtx := ctx.Err(); errCtx != nil {
|
||||
return errCtx
|
||||
}
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedAttempt(
|
||||
ctx context.Context,
|
||||
auth *Auth,
|
||||
provider, routeModel string,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
exec func(ctx context.Context, execReq cliproxyexecutor.Request) error,
|
||||
) error {
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
|
||||
err := exec(execCtx, execReq)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: err == nil}
|
||||
if err != nil {
|
||||
result.Error = &Error{Message: err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(err, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(err); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
var resp cliproxyexecutor.Response
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
resp, errExec = executor.Execute(execCtx, auth, execReq, opts)
|
||||
return errExec
|
||||
})
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
var resp cliproxyexecutor.Response
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
resp, errExec = executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
return errExec
|
||||
})
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
var chunks <-chan cliproxyexecutor.StreamChunk
|
||||
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
|
||||
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
|
||||
var errExec error
|
||||
chunks, errExec = executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errExec != nil {
|
||||
return errExec
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
forward := true
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
forward := true
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
continue
|
||||
}
|
||||
if streamCtx == nil {
|
||||
out <- chunk
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
forward = false
|
||||
case out <- chunk:
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
if !forward {
|
||||
continue
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
if streamCtx == nil {
|
||||
out <- chunk
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
forward = false
|
||||
case out <- chunk:
|
||||
}
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
chunks = out
|
||||
return nil
|
||||
})
|
||||
})
|
||||
return chunks, err
|
||||
}
|
||||
|
||||
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
||||
|
||||
Reference in New Issue
Block a user