feature(ampcode): Improves AMP model mapping with alias support

Enhances the AMP model mapping functionality to support fallback mechanisms using .

This change allows the system to attempt alternative models (aliases) if the primary mapped model fails due to issues like quota exhaustion. It updates the model mapper to load and utilize the  configuration, enabling provider lookup via aliases. It also introduces context keys to pass fallback model names between handlers.

Additionally, this change introduces a fix to prevent ReverseProxy from panicking by swallowing ErrAbortHandler panics.

Amp-Thread-ID: https://ampcode.com/threads/T-019c0cd1-9e59-722b-83f0-e0582aba6914
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
이대희
2026-01-30 12:50:53 +09:00
committed by hkfires
parent 2854e04bbb
commit 09044e8ccc
5 changed files with 384 additions and 213 deletions

View File

@@ -562,192 +562,188 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
func (m *Manager) executeWithFallback(
ctx context.Context,
initialProviders []string,
req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
exec func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error,
) error {
routeModel := req.Model
providers := initialProviders
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
// Track fallback models from context (provided by Amp module fallback_models key)
var fallbacks []string
if v := ctx.Value("fallback_models"); v != nil {
if fs, ok := v.([]string); ok {
fallbacks = fs
}
}
fallbackIdx := -1
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
// No more auths for current model. Try next fallback model if available.
if fallbackIdx+1 < len(fallbacks) {
fallbackIdx++
routeModel = fallbacks[fallbackIdx]
log.Debugf("no more auths for current model, trying fallback model: %s (fallback %d/%d)", routeModel, fallbackIdx+1, len(fallbacks))
// Reset tried set for the new model and find its providers
tried = make(map[string]struct{})
providers = util.GetProviderName(thinking.ParseSuffix(routeModel).ModelName)
// Reset opts for the new model
opts = ensureRequestedModelMetadata(opts, routeModel)
if len(providers) == 0 {
log.Debugf("fallback model %s has no providers, skipping", routeModel)
continue // Try next fallback if this one has no providers
}
continue
}
if lastErr != nil {
return lastErr
}
return errPick
}
tried[auth.ID] = struct{}{}
if err := exec(ctx, executor, auth, provider, routeModel); err != nil {
if errCtx := ctx.Err(); errCtx != nil {
return errCtx
}
lastErr = err
continue
}
return nil
}
}
func (m *Manager) executeMixedAttempt(
ctx context.Context,
auth *Auth,
provider, routeModel string,
req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
exec func(ctx context.Context, execReq cliproxyexecutor.Request) error,
) error {
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
err := exec(execCtx, execReq)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: err == nil}
if err != nil {
result.Error = &Error{Message: err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(err, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(err); ra != nil {
result.RetryAfter = ra
}
}
m.MarkResult(execCtx, result)
return err
}
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
var resp cliproxyexecutor.Response
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
var errExec error
resp, errExec = executor.Execute(execCtx, auth, execReq, opts)
return errExec
})
})
return resp, err
}
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
var resp cliproxyexecutor.Response
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
var errExec error
resp, errExec = executor.CountTokens(execCtx, auth, execReq, opts)
return errExec
})
})
return resp, err
}
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return nil, lastErr
}
return nil, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
var chunks <-chan cliproxyexecutor.StreamChunk
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
var errExec error
chunks, errExec = executor.ExecuteStream(execCtx, auth, execReq, opts)
if errExec != nil {
return errExec
}
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errStream, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
forward := true
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(chunk.Err, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
forward := true
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(chunk.Err, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
continue
}
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
continue
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
}
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, chunks)
return out, nil
}
}(execCtx, auth.Clone(), provider, chunks)
chunks = out
return nil
})
})
return chunks, err
}
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {