mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
Improves AMP request handling by consolidating model mapping logic into a helper function for better readability and maintainability. Enhances error handling for premature client connection closures during reverse proxy operations by explicitly acknowledging and swallowing the ErrAbortHandler panic, preventing noisy stack traces. Removes unused method `findProviderViaOAuthAlias` from the `DefaultModelMapper`.
375 lines
13 KiB
Go
375 lines
13 KiB
Go
package amp
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// AmpRouteType labels the routing decision taken for an Amp request. Its
// string value is what logAmpRouting records in the "route_type" field.
type AmpRouteType string

// Route types an Amp request can be classified as.
const (
	RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER" // served by a local OAuth provider (free)
	RouteTypeModelMapping  AmpRouteType = "MODEL_MAPPING"  // remapped to another locally available model (free)
	RouteTypeAmpCredits    AmpRouteType = "AMP_CREDITS"    // forwarded to ampcode.com (consumes Amp credits)
	RouteTypeNoProvider    AmpRouteType = "NO_PROVIDER"    // no local provider and no ampcode.com fallback
)
|
|
|
|
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
|
// Deprecated: Use ctxkeys.MappedModel instead.
|
|
const MappedModelContextKey = string(ctxkeys.MappedModel)
|
|
|
|
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
|
|
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
|
|
// Deprecated: Use ctxkeys.FallbackModels instead.
|
|
const FallbackModelsContextKey = string(ctxkeys.FallbackModels)
|
|
|
|
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
|
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
|
fields := log.Fields{
|
|
"component": "amp-routing",
|
|
"route_type": string(routeType),
|
|
"requested_model": requestedModel,
|
|
"path": path,
|
|
"timestamp": time.Now().Format(time.RFC3339),
|
|
}
|
|
|
|
if resolvedModel != "" && resolvedModel != requestedModel {
|
|
fields["resolved_model"] = resolvedModel
|
|
}
|
|
if provider != "" {
|
|
fields["provider"] = provider
|
|
}
|
|
|
|
switch routeType {
|
|
case RouteTypeLocalProvider:
|
|
fields["cost"] = "free"
|
|
fields["source"] = "local_oauth"
|
|
log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
|
|
|
|
case RouteTypeModelMapping:
|
|
fields["cost"] = "free"
|
|
fields["source"] = "local_oauth"
|
|
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
|
// model mapping already logged in mapper; avoid duplicate here
|
|
|
|
case RouteTypeAmpCredits:
|
|
fields["cost"] = "amp_credits"
|
|
fields["source"] = "ampcode.com"
|
|
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
|
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
|
|
|
case RouteTypeNoProvider:
|
|
fields["cost"] = "none"
|
|
fields["source"] = "error"
|
|
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
|
log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
|
|
}
|
|
}
|
|
|
|
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI. Optionally, it
// routes requests through configured model mappings to locally served models.
type FallbackHandler struct {
	// getProxy returns the reverse proxy used to forward requests to
	// ampcode.com. It is invoked on each request rather than at construction
	// time, so the proxy may be created after the routes. A nil result means
	// no fallback proxy is available.
	getProxy func() *httputil.ReverseProxy

	// modelMapper maps requested model names to alternative targets. When it
	// is nil, no model mapping is attempted.
	modelMapper ModelMapper

	// forceModelMappings reports whether model mappings should be consulted
	// before local providers. It is evaluated on each request; a nil func is
	// treated as false.
	forceModelMappings func() bool
}
|
|
|
|
// NewFallbackHandler creates a new fallback handler wrapper
|
|
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
|
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
|
return &FallbackHandler{
|
|
getProxy: getProxy,
|
|
forceModelMappings: func() bool { return false },
|
|
}
|
|
}
|
|
|
|
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
|
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
|
if forceModelMappings == nil {
|
|
forceModelMappings = func() bool { return false }
|
|
}
|
|
return &FallbackHandler{
|
|
getProxy: getProxy,
|
|
modelMapper: mapper,
|
|
forceModelMappings: forceModelMappings,
|
|
}
|
|
}
|
|
|
|
// SetModelMapper sets the model mapper for this handler. This allows late
// binding, so the mapper can be supplied after the handler and its wrapped
// routes are created. Passing nil disables model mapping.
//
// NOTE(review): the field is written without synchronization, and handlers
// returned by WrapHandler read it on every request. If this can be called
// while requests are in flight, callers must synchronize externally; confirm
// against call sites.
func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
	fh.modelMapper = mapper
}
|
|
|
|
// WrapHandler wraps a gin.HandlerFunc with fallback logic
|
|
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
|
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
|
|
// ReverseProxy raises this panic when the client connection is closed prematurely
|
|
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
|
|
// with a ResponseWriter that doesn't implement http.CloseNotifier.
|
|
// This is an expected error condition, not a bug, so we handle it gracefully.
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
|
return
|
|
}
|
|
panic(rec)
|
|
}
|
|
}()
|
|
|
|
requestPath := c.Request.URL.Path
|
|
|
|
// Read the request body to extract the model name
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
log.Errorf("amp fallback: failed to read request body: %v", err)
|
|
handler(c)
|
|
return
|
|
}
|
|
|
|
// Restore the body for the handler to read
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
// Try to extract model from request body or URL path (for Gemini)
|
|
modelName := extractModelFromRequest(bodyBytes, c)
|
|
if modelName == "" {
|
|
// Can't determine model, proceed with normal handler
|
|
handler(c)
|
|
return
|
|
}
|
|
|
|
// Normalize model (handles dynamic thinking suffixes)
|
|
suffixResult := thinking.ParseSuffix(modelName)
|
|
normalizedModel := suffixResult.ModelName
|
|
thinkingSuffix := ""
|
|
if suffixResult.HasSuffix {
|
|
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
|
}
|
|
|
|
// resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
|
|
resolveMappedModels := func() ([]string, []string) {
|
|
if fh.modelMapper == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
|
|
if !ok {
|
|
// Fallback to single model for non-DefaultModelMapper
|
|
mappedModel := fh.modelMapper.MapModel(modelName)
|
|
if mappedModel == "" {
|
|
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
|
}
|
|
if mappedModel == "" {
|
|
return nil, nil
|
|
}
|
|
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
|
mappedProviders := util.GetProviderName(mappedBaseModel)
|
|
if len(mappedProviders) == 0 {
|
|
return nil, nil
|
|
}
|
|
return []string{mappedModel}, mappedProviders
|
|
}
|
|
|
|
// Use MapModelWithFallbacks for DefaultModelMapper
|
|
mappedModels := mapper.MapModelWithFallbacks(modelName)
|
|
if len(mappedModels) == 0 {
|
|
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
|
|
}
|
|
if len(mappedModels) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
// Apply thinking suffix if needed
|
|
for i, model := range mappedModels {
|
|
if thinkingSuffix != "" {
|
|
suffixResult := thinking.ParseSuffix(model)
|
|
if !suffixResult.HasSuffix {
|
|
mappedModels[i] = model + thinkingSuffix
|
|
}
|
|
}
|
|
}
|
|
|
|
// Get providers for the first model
|
|
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
|
|
providers := util.GetProviderName(firstBaseModel)
|
|
if len(providers) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return mappedModels, providers
|
|
}
|
|
|
|
// Track resolved model for logging (may change if mapping is applied)
|
|
resolvedModel := normalizedModel
|
|
usedMapping := false
|
|
var providers []string
|
|
|
|
// Helper to apply model mapping and update state
|
|
applyMapping := func(mappedModels []string, mappedProviders []string) {
|
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
|
|
if len(mappedModels) > 1 {
|
|
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
|
|
}
|
|
resolvedModel = mappedModels[0]
|
|
usedMapping = true
|
|
providers = mappedProviders
|
|
}
|
|
|
|
// Check if model mappings should be forced ahead of local API keys
|
|
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
|
|
|
if forceMappings {
|
|
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
|
// This allows users to route Amp requests to their preferred OAuth providers
|
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
|
applyMapping(mappedModels, mappedProviders)
|
|
}
|
|
|
|
// If no mapping applied, check for local providers
|
|
if !usedMapping {
|
|
providers = util.GetProviderName(normalizedModel)
|
|
}
|
|
} else {
|
|
// DEFAULT MODE: Check local providers first, then mappings as fallback
|
|
providers = util.GetProviderName(normalizedModel)
|
|
|
|
if len(providers) == 0 {
|
|
// No providers configured - check if we have a model mapping
|
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
|
applyMapping(mappedModels, mappedProviders)
|
|
}
|
|
}
|
|
}
|
|
|
|
// If no providers available, fallback to ampcode.com
|
|
if len(providers) == 0 {
|
|
proxy := fh.getProxy()
|
|
if proxy != nil {
|
|
// Log: Forwarding to ampcode.com (uses Amp credits)
|
|
logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath)
|
|
|
|
// Restore body again for the proxy
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
// Forward to ampcode.com
|
|
proxy.ServeHTTP(c.Writer, c.Request)
|
|
return
|
|
}
|
|
|
|
// No proxy available, let the normal handler return the error
|
|
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
|
|
}
|
|
|
|
// Log the routing decision
|
|
providerName := ""
|
|
if len(providers) > 0 {
|
|
providerName = providers[0]
|
|
}
|
|
|
|
if usedMapping {
|
|
// Log: Model was mapped to another model
|
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
|
c.Writer = rewriter
|
|
// Filter Anthropic-Beta header only for local handling paths
|
|
filterAntropicBetaHeader(c)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
rewriter.Flush()
|
|
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
|
} else if len(providers) > 0 {
|
|
// Log: Using local provider (free)
|
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
|
// Filter Anthropic-Beta header only for local handling paths
|
|
filterAntropicBetaHeader(c)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
} else {
|
|
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription
|
|
// This is needed when using local providers (bypassing the Amp proxy)
|
|
func filterAntropicBetaHeader(c *gin.Context) {
|
|
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
|
if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" {
|
|
c.Request.Header.Set("Anthropic-Beta", filtered)
|
|
} else {
|
|
c.Request.Header.Del("Anthropic-Beta")
|
|
}
|
|
}
|
|
}
|
|
|
|
// rewriteModelInRequest replaces the model name in a JSON request body
|
|
func rewriteModelInRequest(body []byte, newModel string) []byte {
|
|
if !gjson.GetBytes(body, "model").Exists() {
|
|
return body
|
|
}
|
|
result, err := sjson.SetBytes(body, "model", newModel)
|
|
if err != nil {
|
|
log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
|
|
return body
|
|
}
|
|
return result
|
|
}
|
|
|
|
// extractModelFromRequest attempts to extract the model name from various request formats
|
|
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
|
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
|
// Check common model field names
|
|
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
|
return result.String()
|
|
}
|
|
|
|
// For Gemini requests, model is in the URL path
|
|
// Standard format: /models/{model}:generateContent -> :action parameter
|
|
if action := c.Param("action"); action != "" {
|
|
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
|
|
parts := strings.Split(action, ":")
|
|
if len(parts) > 0 && parts[0] != "" {
|
|
return parts[0]
|
|
}
|
|
}
|
|
|
|
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
|
|
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
|
if path := c.Param("path"); path != "" {
|
|
// Look for /models/{model}:method pattern
|
|
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
|
modelPart := path[idx+8:] // Skip "/models/"
|
|
// Split by colon to get model name
|
|
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
|
return modelPart[:colonIdx]
|
|
}
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|