mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach. This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings. It provides a more flexible and maintainable routing mechanism by centralizing routing logic. The changes include: - Introducing new `routing` package with core routing logic. - Creating characterization tests to capture existing behavior. - Implementing model extraction and rewriting. - Updating AMP module routes to utilize the new routing wrapper. - Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
383 lines
14 KiB
Go
383 lines
14 KiB
Go
package amp
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// AmpRouteType names the routing decision taken for a single Amp request.
// Its string value is what logAmpRouting records in the "route_type" field.
type AmpRouteType string

// Routing decisions an Amp request can end up with.
const (
	// RouteTypeLocalProvider means a locally configured OAuth provider serves
	// the requested model directly (no Amp credits are spent).
	RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER"

	// RouteTypeModelMapping means the requested model was rewritten to another
	// model that a local provider can serve (no Amp credits are spent).
	RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING"

	// RouteTypeAmpCredits means the request is forwarded upstream to
	// ampcode.com, which consumes Amp credits.
	RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS"

	// RouteTypeNoProvider means neither a local provider nor an upstream
	// fallback could handle the request.
	RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
)
|
|
|
|
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
|
// Deprecated: Use ctxkeys.MappedModel instead.
|
|
const MappedModelContextKey = string(ctxkeys.MappedModel)
|
|
|
|
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
|
|
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
|
|
// Deprecated: Use ctxkeys.FallbackModels instead.
|
|
const FallbackModelsContextKey = string(ctxkeys.FallbackModels)
|
|
|
|
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
|
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
|
fields := log.Fields{
|
|
"component": "amp-routing",
|
|
"route_type": string(routeType),
|
|
"requested_model": requestedModel,
|
|
"path": path,
|
|
"timestamp": time.Now().Format(time.RFC3339),
|
|
}
|
|
|
|
if resolvedModel != "" && resolvedModel != requestedModel {
|
|
fields["resolved_model"] = resolvedModel
|
|
}
|
|
if provider != "" {
|
|
fields["provider"] = provider
|
|
}
|
|
|
|
switch routeType {
|
|
case RouteTypeLocalProvider:
|
|
fields["cost"] = "free"
|
|
fields["source"] = "local_oauth"
|
|
log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
|
|
|
|
case RouteTypeModelMapping:
|
|
fields["cost"] = "free"
|
|
fields["source"] = "local_oauth"
|
|
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
|
// model mapping already logged in mapper; avoid duplicate here
|
|
|
|
case RouteTypeAmpCredits:
|
|
fields["cost"] = "amp_credits"
|
|
fields["source"] = "ampcode.com"
|
|
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
|
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
|
|
|
case RouteTypeNoProvider:
|
|
fields["cost"] = "none"
|
|
fields["source"] = "error"
|
|
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
|
log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
|
|
}
|
|
}
|
|
|
|
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
//
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
// This type is kept for backward compatibility and test purposes.
type FallbackHandler struct {
	// getProxy is called per request to obtain the upstream reverse proxy;
	// a nil result is treated by WrapHandler as "no upstream fallback available".
	getProxy func() *httputil.ReverseProxy
	// modelMapper optionally rewrites requested model names; nil disables mapping.
	modelMapper ModelMapper
	// forceModelMappings reports whether model mappings are consulted before
	// local providers (true) or only when no local provider exists (false).
	forceModelMappings func() bool
}
|
|
|
|
// NewFallbackHandler creates a new fallback handler wrapper
|
|
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
|
//
|
|
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
|
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
|
return &FallbackHandler{
|
|
getProxy: getProxy,
|
|
forceModelMappings: func() bool { return false },
|
|
}
|
|
}
|
|
|
|
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
|
//
|
|
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
|
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
|
if forceModelMappings == nil {
|
|
forceModelMappings = func() bool { return false }
|
|
}
|
|
return &FallbackHandler{
|
|
getProxy: getProxy,
|
|
modelMapper: mapper,
|
|
forceModelMappings: forceModelMappings,
|
|
}
|
|
}
|
|
|
|
// SetModelMapper sets the model mapper for this handler (allows late binding)
|
|
func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
|
fh.modelMapper = mapper
|
|
}
|
|
|
|
// WrapHandler wraps a gin.HandlerFunc with fallback logic
|
|
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
|
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
|
|
// ReverseProxy raises this panic when the client connection is closed prematurely
|
|
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
|
|
// with a ResponseWriter that doesn't implement http.CloseNotifier.
|
|
// This is an expected error condition, not a bug, so we handle it gracefully.
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
|
return
|
|
}
|
|
panic(rec)
|
|
}
|
|
}()
|
|
|
|
requestPath := c.Request.URL.Path
|
|
|
|
// Read the request body to extract the model name
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
log.Errorf("amp fallback: failed to read request body: %v", err)
|
|
handler(c)
|
|
return
|
|
}
|
|
|
|
// Restore the body for the handler to read
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
// Try to extract model from request body or URL path (for Gemini)
|
|
modelName := extractModelFromRequest(bodyBytes, c)
|
|
if modelName == "" {
|
|
// Can't determine model, proceed with normal handler
|
|
handler(c)
|
|
return
|
|
}
|
|
|
|
// Normalize model (handles dynamic thinking suffixes)
|
|
suffixResult := thinking.ParseSuffix(modelName)
|
|
normalizedModel := suffixResult.ModelName
|
|
thinkingSuffix := ""
|
|
if suffixResult.HasSuffix {
|
|
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
|
}
|
|
|
|
// resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
|
|
resolveMappedModels := func() ([]string, []string) {
|
|
if fh.modelMapper == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
|
|
if !ok {
|
|
// Fallback to single model for non-DefaultModelMapper
|
|
mappedModel := fh.modelMapper.MapModel(modelName)
|
|
if mappedModel == "" {
|
|
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
|
}
|
|
if mappedModel == "" {
|
|
return nil, nil
|
|
}
|
|
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
|
mappedProviders := util.GetProviderName(mappedBaseModel)
|
|
if len(mappedProviders) == 0 {
|
|
return nil, nil
|
|
}
|
|
return []string{mappedModel}, mappedProviders
|
|
}
|
|
|
|
// Use MapModelWithFallbacks for DefaultModelMapper
|
|
mappedModels := mapper.MapModelWithFallbacks(modelName)
|
|
if len(mappedModels) == 0 {
|
|
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
|
|
}
|
|
if len(mappedModels) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
// Apply thinking suffix if needed
|
|
for i, model := range mappedModels {
|
|
if thinkingSuffix != "" {
|
|
suffixResult := thinking.ParseSuffix(model)
|
|
if !suffixResult.HasSuffix {
|
|
mappedModels[i] = model + thinkingSuffix
|
|
}
|
|
}
|
|
}
|
|
|
|
// Get providers for the first model
|
|
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
|
|
providers := util.GetProviderName(firstBaseModel)
|
|
if len(providers) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return mappedModels, providers
|
|
}
|
|
|
|
// Track resolved model for logging (may change if mapping is applied)
|
|
resolvedModel := normalizedModel
|
|
usedMapping := false
|
|
var providers []string
|
|
|
|
// Helper to apply model mapping and update state
|
|
applyMapping := func(mappedModels []string, mappedProviders []string) {
|
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
|
|
if len(mappedModels) > 1 {
|
|
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
|
|
}
|
|
resolvedModel = mappedModels[0]
|
|
usedMapping = true
|
|
providers = mappedProviders
|
|
}
|
|
|
|
// Check if model mappings should be forced ahead of local API keys
|
|
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
|
|
|
if forceMappings {
|
|
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
|
// This allows users to route Amp requests to their preferred OAuth providers
|
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
|
applyMapping(mappedModels, mappedProviders)
|
|
}
|
|
|
|
// If no mapping applied, check for local providers
|
|
if !usedMapping {
|
|
providers = util.GetProviderName(normalizedModel)
|
|
}
|
|
} else {
|
|
// DEFAULT MODE: Check local providers first, then mappings as fallback
|
|
providers = util.GetProviderName(normalizedModel)
|
|
|
|
if len(providers) == 0 {
|
|
// No providers configured - check if we have a model mapping
|
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
|
applyMapping(mappedModels, mappedProviders)
|
|
}
|
|
}
|
|
}
|
|
|
|
// If no providers available, fallback to ampcode.com
|
|
if len(providers) == 0 {
|
|
proxy := fh.getProxy()
|
|
if proxy != nil {
|
|
// Log: Forwarding to ampcode.com (uses Amp credits)
|
|
logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath)
|
|
|
|
// Restore body again for the proxy
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
// Forward to ampcode.com
|
|
proxy.ServeHTTP(c.Writer, c.Request)
|
|
return
|
|
}
|
|
|
|
// No proxy available, let the normal handler return the error
|
|
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
|
|
}
|
|
|
|
// Log the routing decision
|
|
providerName := ""
|
|
if len(providers) > 0 {
|
|
providerName = providers[0]
|
|
}
|
|
|
|
if usedMapping {
|
|
// Log: Model was mapped to another model
|
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
|
c.Writer = rewriter
|
|
// Filter Anthropic-Beta header only for local handling paths
|
|
filterAntropicBetaHeader(c)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
rewriter.Flush()
|
|
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
|
} else if len(providers) > 0 {
|
|
// Log: Using local provider (free)
|
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
|
// Filter Anthropic-Beta header only for local handling paths
|
|
filterAntropicBetaHeader(c)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
} else {
|
|
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription
|
|
// This is needed when using local providers (bypassing the Amp proxy)
|
|
func filterAntropicBetaHeader(c *gin.Context) {
|
|
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
|
if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" {
|
|
c.Request.Header.Set("Anthropic-Beta", filtered)
|
|
} else {
|
|
c.Request.Header.Del("Anthropic-Beta")
|
|
}
|
|
}
|
|
}
|
|
|
|
// rewriteModelInRequest replaces the model name in a JSON request body
|
|
func rewriteModelInRequest(body []byte, newModel string) []byte {
|
|
if !gjson.GetBytes(body, "model").Exists() {
|
|
return body
|
|
}
|
|
result, err := sjson.SetBytes(body, "model", newModel)
|
|
if err != nil {
|
|
log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
|
|
return body
|
|
}
|
|
return result
|
|
}
|
|
|
|
// extractModelFromRequest attempts to extract the model name from various request formats
|
|
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
|
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
|
// Check common model field names
|
|
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
|
return result.String()
|
|
}
|
|
|
|
// For Gemini requests, model is in the URL path
|
|
// Standard format: /models/{model}:generateContent -> :action parameter
|
|
if action := c.Param("action"); action != "" {
|
|
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
|
|
parts := strings.Split(action, ":")
|
|
if len(parts) > 0 && parts[0] != "" {
|
|
return parts[0]
|
|
}
|
|
}
|
|
|
|
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
|
|
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
|
if path := c.Param("path"); path != "" {
|
|
// Look for /models/{model}:method pattern
|
|
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
|
modelPart := path[idx+8:] // Skip "/models/"
|
|
// Split by colon to get model name
|
|
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
|
return modelPart[:colonIdx]
|
|
}
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|