mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-17 20:00:52 +08:00
Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach. This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings. Centralizing the routing logic makes the routing mechanism more flexible and maintainable.

The changes include:
- Introducing a new `routing` package with the core routing logic.
- Creating characterization tests to capture existing behavior.
- Implementing model extraction and rewriting.
- Updating AMP module routes to utilize the new routing wrapper.
- Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
271 lines
8.1 KiB
Go
271 lines
8.1 KiB
Go
package routing
|
|
|
|
import (
	"bufio"
	"bytes"
	"io"
	"net"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
	"github.com/sirupsen/logrus"
)
|
|
|
|
// ProxyFunc is the function type for proxying requests.
|
|
type ProxyFunc func(c *gin.Context)
|
|
|
|
// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic.
|
|
// It replaces the FallbackHandler logic with a Router-based approach.
|
|
type ModelRoutingWrapper struct {
|
|
router *Router
|
|
extractor ModelExtractor
|
|
rewriter ModelRewriter
|
|
proxyFunc ProxyFunc
|
|
logger *logrus.Logger
|
|
}
|
|
|
|
// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies.
|
|
// If extractor is nil, a DefaultModelExtractor is used.
|
|
// If rewriter is nil, a DefaultModelRewriter is used.
|
|
// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead.
|
|
func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper {
|
|
if extractor == nil {
|
|
extractor = NewModelExtractor()
|
|
}
|
|
if rewriter == nil {
|
|
rewriter = NewModelRewriter()
|
|
}
|
|
return &ModelRoutingWrapper{
|
|
router: router,
|
|
extractor: extractor,
|
|
rewriter: rewriter,
|
|
proxyFunc: proxyFunc,
|
|
logger: logrus.New(),
|
|
}
|
|
}
|
|
|
|
// SetLogger sets the logger for the wrapper.
|
|
func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) {
|
|
w.logger = logger
|
|
}
|
|
|
|
// Wrap wraps a gin.HandlerFunc with model routing logic.
|
|
// The returned handler will:
|
|
// 1. Extract the model from the request
|
|
// 2. Get a routing decision from the Router
|
|
// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS)
|
|
func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Read request body
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
w.logger.Errorf("routing wrapper: failed to read request body: %v", err)
|
|
handler(c)
|
|
return
|
|
}
|
|
|
|
// Extract model from request
|
|
ginParams := map[string]string{
|
|
"action": c.Param("action"),
|
|
"path": c.Param("path"),
|
|
}
|
|
modelName, err := w.extractor.Extract(bodyBytes, ginParams)
|
|
if err != nil {
|
|
w.logger.Warnf("routing wrapper: failed to extract model: %v", err)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
return
|
|
}
|
|
|
|
if modelName == "" {
|
|
// No model found, proceed with original handler
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
return
|
|
}
|
|
|
|
// Get routing decision
|
|
req := RoutingRequest{
|
|
RequestedModel: modelName,
|
|
PreferLocalProvider: true,
|
|
ForceModelMapping: false, // TODO: Get from config
|
|
}
|
|
decision := w.router.ResolveV2(req)
|
|
|
|
// Store decision in context for downstream handlers
|
|
c.Set(string(ctxkeys.RoutingDecision), decision)
|
|
|
|
// Handle based on route type
|
|
switch decision.RouteType {
|
|
case RouteTypeLocalProvider:
|
|
w.handleLocalProvider(c, handler, bodyBytes, decision)
|
|
case RouteTypeModelMapping:
|
|
w.handleModelMapping(c, handler, bodyBytes, decision)
|
|
case RouteTypeAmpCredits:
|
|
w.handleAmpCredits(c, handler, bodyBytes)
|
|
default:
|
|
// No provider available
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
handler(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleLocalProvider handles the LOCAL_PROVIDER route type.
|
|
func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
|
// Filter Anthropic-Beta header for local provider
|
|
filterAnthropicBetaHeader(c)
|
|
|
|
// Restore body with original content
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
// Call handler
|
|
handler(c)
|
|
}
|
|
|
|
// handleModelMapping handles the MODEL_MAPPING route type.
|
|
func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
|
// Rewrite request body with mapped model
|
|
rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel)
|
|
if err != nil {
|
|
w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err)
|
|
rewrittenBody = bodyBytes
|
|
}
|
|
_ = rewrittenBody
|
|
|
|
// Store mapped model in context
|
|
c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel)
|
|
|
|
// Store fallback models in context if present
|
|
if len(decision.FallbackModels) > 0 {
|
|
c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels)
|
|
}
|
|
|
|
// Filter Anthropic-Beta header for local provider
|
|
filterAnthropicBetaHeader(c)
|
|
|
|
// Restore body with rewritten content
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody))
|
|
|
|
// Wrap response writer to rewrite model back
|
|
wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel)
|
|
c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer}
|
|
|
|
// Call handler
|
|
handler(c)
|
|
|
|
// Cleanup (flush response rewriting)
|
|
cleanup()
|
|
}
|
|
|
|
// handleAmpCredits handles the AMP_CREDITS route type.
|
|
// It calls the proxy function directly if available, otherwise passes to handler.
|
|
// Does NOT filter headers or rewrite body - proxy handles everything.
|
|
func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) {
|
|
// Restore body with original content (no rewriting for proxy)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
// Call proxy function if available, otherwise fall back to handler
|
|
if w.proxyFunc != nil {
|
|
w.proxyFunc(c)
|
|
} else {
|
|
handler(c)
|
|
}
|
|
}
|
|
|
|
// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers.
|
|
func filterAnthropicBetaHeader(c *gin.Context) {
|
|
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
|
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
|
if filtered != "" {
|
|
c.Request.Header.Set("Anthropic-Beta", filtered)
|
|
} else {
|
|
c.Request.Header.Del("Anthropic-Beta")
|
|
}
|
|
}
|
|
}
|
|
|
|
// filterBetaFeatures removes featureToRemove from a comma-separated list of
// beta features.
//
// betaHeader is split on commas and each entry is compared, after trimming
// surrounding whitespace, against featureToRemove (exact, case-sensitive
// match). When at least one entry matches, the remaining non-empty entries
// are returned joined by "," with whitespace trimmed; the result is "" if
// nothing remains. When no entry matches, betaHeader is returned unchanged.
func filterBetaFeatures(betaHeader, featureToRemove string) string {
	// Fast path: the feature cannot be an entry if it is not even a substring.
	if !strings.Contains(betaHeader, featureToRemove) {
		return betaHeader
	}

	entries := strings.Split(betaHeader, ",")
	kept := make([]string, 0, len(entries))
	removed := false
	for _, entry := range entries {
		name := strings.TrimSpace(entry)
		switch {
		case name == "":
			// Skip empty entries such as those produced by ",,".
		case name == featureToRemove:
			removed = true
		default:
			kept = append(kept, name)
		}
	}

	// A substring hit that is not a whole entry (e.g. "foo-extended" vs
	// "foo") must leave the header exactly as it was.
	if !removed {
		return betaHeader
	}
	return strings.Join(kept, ",")
}
|
|
|
|
// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter.
|
|
type ginResponseWriterAdapter struct {
|
|
http.ResponseWriter
|
|
original gin.ResponseWriter
|
|
}
|
|
|
|
func (a *ginResponseWriterAdapter) WriteHeader(code int) {
|
|
a.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) {
|
|
return a.ResponseWriter.Write(data)
|
|
}
|
|
|
|
func (a *ginResponseWriterAdapter) Header() http.Header {
|
|
return a.ResponseWriter.Header()
|
|
}
|
|
|
|
// CloseNotify implements http.CloseNotifier.
|
|
func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool {
|
|
if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok {
|
|
return notifier.CloseNotify()
|
|
}
|
|
return a.original.CloseNotify()
|
|
}
|
|
|
|
// Flush implements http.Flusher.
|
|
func (a *ginResponseWriterAdapter) Flush() {
|
|
if flusher, ok := a.ResponseWriter.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
// Hijack implements http.Hijacker.
|
|
func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok {
|
|
return hijacker.Hijack()
|
|
}
|
|
return a.original.Hijack()
|
|
}
|
|
|
|
// Status returns the HTTP status code.
|
|
func (a *ginResponseWriterAdapter) Status() int {
|
|
return a.original.Status()
|
|
}
|
|
|
|
// Size returns the number of bytes already written into the response http body.
|
|
func (a *ginResponseWriterAdapter) Size() int {
|
|
return a.original.Size()
|
|
}
|
|
|
|
// Written returns whether or not the response for this context has been written.
|
|
func (a *ginResponseWriterAdapter) Written() bool {
|
|
return a.original.Written()
|
|
}
|
|
|
|
// WriteHeaderNow forces WriteHeader to be called.
|
|
func (a *ginResponseWriterAdapter) WriteHeaderNow() {
|
|
a.original.WriteHeaderNow()
|
|
}
|
|
|
|
// WriteString writes the given string into the response body.
|
|
func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) {
|
|
return a.Write([]byte(s))
|
|
}
|
|
|
|
// Pusher returns the http.Pusher for server push.
|
|
func (a *ginResponseWriterAdapter) Pusher() http.Pusher {
|
|
if pusher, ok := a.ResponseWriter.(http.Pusher); ok {
|
|
return pusher
|
|
}
|
|
return nil
|
|
}
|