From 527a269799d0931886c672a6d005ecb3a875db79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Sun, 1 Feb 2026 15:56:31 +0900 Subject: [PATCH] Refactors AMP model mapping and error handling Improves AMP request handling by consolidating model mapping logic into a helper function for better readability and maintainability. Enhances error handling for premature client connection closures during reverse proxy operations by explicitly acknowledging and swallowing the ErrAbortHandler panic, preventing noisy stack traces. Removes unused method `findProviderViaOAuthAlias` from the `DefaultModelMapper`. --- internal/api/modules/amp/fallback_handlers.go | 43 +++++++++---------- internal/api/modules/amp/model_mapping.go | 16 +------ 2 files changed, 21 insertions(+), 38 deletions(-) diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 940e9e3c..3b32d25e 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -122,7 +122,11 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) { // If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { - // Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces + // Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces. + // ReverseProxy raises this panic when the client connection is closed prematurely + // (e.g., user cancels request, network disconnect) or when ServeHTTP is called + // with a ResponseWriter that doesn't implement http.CloseNotifier. + // This is an expected error condition, not a bug, so we handle it gracefully. defer func() { if rec := recover(); rec != nil { if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) { @@ -219,6 +223,19 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc usedMapping := false var providers []string + // Helper to apply model mapping and update state + applyMapping := func(mappedModels []string, mappedProviders []string) { + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0]) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + c.Set(string(ctxkeys.MappedModel), mappedModels[0]) + if len(mappedModels) > 1 { + c.Set(string(ctxkeys.FallbackModels), mappedModels[1:]) + } + resolvedModel = mappedModels[0] + usedMapping = true + providers = mappedProviders + } + // Check if model mappings should be forced ahead of local API keys forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings() @@ -226,17 +243,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // This allows users to route Amp requests to their preferred OAuth providers if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0]) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model and fallbacks in context for handlers - c.Set(string(ctxkeys.MappedModel), mappedModels[0]) - if len(mappedModels) > 1 { - c.Set(string(ctxkeys.FallbackModels), mappedModels[1:]) - } - resolvedModel = mappedModels[0] - usedMapping = true - providers = mappedProviders + applyMapping(mappedModels, mappedProviders) } // If no mapping applied, check for local providers @@ -250,17 +257,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if len(providers) == 0 { // No providers configured - check if we have a model mapping if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0]) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model and fallbacks in context for handlers - c.Set(string(ctxkeys.MappedModel), mappedModels[0]) - if len(mappedModels) > 1 { - c.Set(string(ctxkeys.FallbackModels), mappedModels[1:]) - } - resolvedModel = mappedModels[0] - usedMapping = true - providers = mappedProviders + applyMapping(mappedModels, mappedProviders) } } } diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index 92599ebf..24b8cdcc 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -89,20 +89,6 @@ func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.O log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward)) } -// findProviderViaOAuthAlias checks if targetModel is an oauth-model-alias name -// and returns all aliases that have available providers. -// Returns the first alias and its providers for backward compatibility, -// and also populates allAliases with all available alias models. -func (m *DefaultModelMapper) findProviderViaOAuthAlias(targetModel string) (aliasModel string, providers []string) { - aliases := m.findAllAliasesWithProviders(targetModel) - if len(aliases) == 0 { - return "", nil - } - // Return first one for backward compatibility - first := aliases[0] - return first, util.GetProviderName(first) -} - // findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel // that have available providers. Useful for fallback when one alias is quota-exceeded. func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string { @@ -222,7 +208,7 @@ func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []stri if len(allAliases) == 1 { log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0]) } else { - log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)) + log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1) } // Apply suffix to all aliases