mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Refactors AMP model mapping and error handling
Improves AMP request handling by consolidating model mapping logic into a helper function for better readability and maintainability. Enhances error handling for premature client connection closures during reverse proxy operations by explicitly acknowledging and swallowing the ErrAbortHandler panic, preventing noisy stack traces. Removes unused method `findProviderViaOAuthAlias` from the `DefaultModelMapper`.
This commit is contained in:
@@ -122,7 +122,11 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
|||||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
|
||||||
|
// ReverseProxy raises this panic when the client connection is closed prematurely
|
||||||
|
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
|
||||||
|
// with a ResponseWriter that doesn't implement http.CloseNotifier.
|
||||||
|
// This is an expected error condition, not a bug, so we handle it gracefully.
|
||||||
defer func() {
|
defer func() {
|
||||||
if rec := recover(); rec != nil {
|
if rec := recover(); rec != nil {
|
||||||
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||||
@@ -219,17 +223,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
usedMapping := false
|
usedMapping := false
|
||||||
var providers []string
|
var providers []string
|
||||||
|
|
||||||
// Check if model mappings should be forced ahead of local API keys
|
// Helper to apply model mapping and update state
|
||||||
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
applyMapping := func(mappedModels []string, mappedProviders []string) {
|
||||||
|
|
||||||
if forceMappings {
|
|
||||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
|
||||||
// This allows users to route Amp requests to their preferred OAuth providers
|
|
||||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
|
||||||
// Mapping found and provider available - rewrite the model in request body
|
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
// Store mapped model and fallbacks in context for handlers
|
|
||||||
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
|
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
|
||||||
if len(mappedModels) > 1 {
|
if len(mappedModels) > 1 {
|
||||||
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
|
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
|
||||||
@@ -239,6 +236,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
providers = mappedProviders
|
providers = mappedProviders
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if model mappings should be forced ahead of local API keys
|
||||||
|
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
||||||
|
|
||||||
|
if forceMappings {
|
||||||
|
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||||
|
// This allows users to route Amp requests to their preferred OAuth providers
|
||||||
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||||
|
applyMapping(mappedModels, mappedProviders)
|
||||||
|
}
|
||||||
|
|
||||||
// If no mapping applied, check for local providers
|
// If no mapping applied, check for local providers
|
||||||
if !usedMapping {
|
if !usedMapping {
|
||||||
providers = util.GetProviderName(normalizedModel)
|
providers = util.GetProviderName(normalizedModel)
|
||||||
@@ -250,17 +257,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
// No providers configured - check if we have a model mapping
|
// No providers configured - check if we have a model mapping
|
||||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||||
// Mapping found and provider available - rewrite the model in request body
|
applyMapping(mappedModels, mappedProviders)
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
// Store mapped model and fallbacks in context for handlers
|
|
||||||
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
|
|
||||||
if len(mappedModels) > 1 {
|
|
||||||
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
|
|
||||||
}
|
|
||||||
resolvedModel = mappedModels[0]
|
|
||||||
usedMapping = true
|
|
||||||
providers = mappedProviders
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -89,20 +89,6 @@ func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.O
|
|||||||
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
|
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
|
||||||
}
|
}
|
||||||
|
|
||||||
// findProviderViaOAuthAlias checks if targetModel is an oauth-model-alias name
|
|
||||||
// and returns all aliases that have available providers.
|
|
||||||
// Returns the first alias and its providers for backward compatibility,
|
|
||||||
// and also populates allAliases with all available alias models.
|
|
||||||
func (m *DefaultModelMapper) findProviderViaOAuthAlias(targetModel string) (aliasModel string, providers []string) {
|
|
||||||
aliases := m.findAllAliasesWithProviders(targetModel)
|
|
||||||
if len(aliases) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
// Return first one for backward compatibility
|
|
||||||
first := aliases[0]
|
|
||||||
return first, util.GetProviderName(first)
|
|
||||||
}
|
|
||||||
|
|
||||||
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
|
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
|
||||||
// that have available providers. Useful for fallback when one alias is quota-exceeded.
|
// that have available providers. Useful for fallback when one alias is quota-exceeded.
|
||||||
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
|
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
|
||||||
@@ -222,7 +208,7 @@ func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []stri
|
|||||||
if len(allAliases) == 1 {
|
if len(allAliases) == 1 {
|
||||||
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
|
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases))
|
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply suffix to all aliases
|
// Apply suffix to all aliases
|
||||||
|
|||||||
Reference in New Issue
Block a user