Refactors AMP model mapping and error handling

Improves AMP request handling by consolidating model mapping logic into a helper function for better readability and maintainability.

Enhances error handling for premature client connection closures during reverse proxy operations by explicitly acknowledging and swallowing the ErrAbortHandler panic, preventing noisy stack traces.

Removes unused method `findProviderViaOAuthAlias` from the `DefaultModelMapper`.
This commit is contained in:
이대희
2026-02-01 15:56:31 +09:00
parent 2fe0b6cd2d
commit 527a269799
2 changed files with 21 additions and 38 deletions

View File

@@ -122,7 +122,11 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
// ReverseProxy raises this panic when the client connection is closed prematurely
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
// with a ResponseWriter that doesn't implement http.CloseNotifier.
// This is an expected error condition, not a bug, so we handle it gracefully.
defer func() {
if rec := recover(); rec != nil {
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
@@ -219,6 +223,19 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
usedMapping := false
var providers []string
// Helper to apply model mapping and update state
applyMapping := func(mappedModels []string, mappedProviders []string) {
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
if len(mappedModels) > 1 {
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
}
resolvedModel = mappedModels[0]
usedMapping = true
providers = mappedProviders
}
// Check if model mappings should be forced ahead of local API keys
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
@@ -226,17 +243,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
// This allows users to route Amp requests to their preferred OAuth providers
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model and fallbacks in context for handlers
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
if len(mappedModels) > 1 {
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
}
resolvedModel = mappedModels[0]
usedMapping = true
providers = mappedProviders
applyMapping(mappedModels, mappedProviders)
}
// If no mapping applied, check for local providers
@@ -250,17 +257,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
if len(providers) == 0 {
// No providers configured - check if we have a model mapping
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model and fallbacks in context for handlers
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
if len(mappedModels) > 1 {
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
}
resolvedModel = mappedModels[0]
usedMapping = true
providers = mappedProviders
applyMapping(mappedModels, mappedProviders)
}
}
}

View File

@@ -89,20 +89,6 @@ func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.O
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
}
// findProviderViaOAuthAlias checks if targetModel is an oauth-model-alias name
// and returns all aliases that have available providers.
// Returns the first alias and its providers for backward compatibility,
// and also populates allAliases with all available alias models.
func (m *DefaultModelMapper) findProviderViaOAuthAlias(targetModel string) (aliasModel string, providers []string) {
aliases := m.findAllAliasesWithProviders(targetModel)
if len(aliases) == 0 {
return "", nil
}
// Return first one for backward compatibility
first := aliases[0]
return first, util.GetProviderName(first)
}
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
// that have available providers. Useful for fallback when one alias is quota-exceeded.
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
@@ -222,7 +208,7 @@ func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []stri
if len(allAliases) == 1 {
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
} else {
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases))
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1)
}
// Apply suffix to all aliases