mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 20:30:51 +08:00
Implements unified model routing
Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach. This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings. It provides a more flexible and maintainable routing mechanism by centralizing routing logic. The changes include: - Introducing new `routing` package with core routing logic. - Creating characterization tests to capture existing behavior. - Implementing model extraction and rewriting. - Updating AMP module routes to utilize the new routing wrapper. - Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
This commit is contained in:
159
internal/routing/rewriter.go
Normal file
159
internal/routing/rewriter.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ModelRewriter handles model name rewriting in requests and responses.
|
||||
type ModelRewriter interface {
|
||||
// RewriteRequestBody rewrites the model field in a JSON request body.
|
||||
// Returns the modified body or the original if no rewrite was needed.
|
||||
RewriteRequestBody(body []byte, newModel string) ([]byte, error)
|
||||
|
||||
// WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response.
|
||||
// Returns the wrapped writer and a cleanup function that must be called after the response is complete.
|
||||
WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
|
||||
}
|
||||
|
||||
// DefaultModelRewriter is the standard implementation of ModelRewriter.
|
||||
type DefaultModelRewriter struct{}
|
||||
|
||||
// NewModelRewriter creates a new DefaultModelRewriter.
|
||||
func NewModelRewriter() *DefaultModelRewriter {
|
||||
return &DefaultModelRewriter{}
|
||||
}
|
||||
|
||||
// RewriteRequestBody replaces the model name in a JSON request body.
|
||||
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
|
||||
if !gjson.GetBytes(body, "model").Exists() {
|
||||
return body, nil
|
||||
}
|
||||
result, err := sjson.SetBytes(body, "model", newModel)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// WrapResponseWriter wraps a response writer to rewrite model names.
|
||||
// The cleanup function must be called after the handler completes to flush any buffered data.
|
||||
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
|
||||
rw := &responseRewriter{
|
||||
ResponseWriter: w,
|
||||
body: &bytes.Buffer{},
|
||||
requestedModel: requestedModel,
|
||||
resolvedModel: resolvedModel,
|
||||
}
|
||||
return rw, func() { rw.flush() }
|
||||
}
|
||||
|
||||
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
|
||||
type responseRewriter struct {
|
||||
http.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
requestedModel string
|
||||
resolvedModel string
|
||||
isStreaming bool
|
||||
wroteHeader bool
|
||||
flushed bool
|
||||
}
|
||||
|
||||
// Write intercepts response writes and buffers them for model name replacement.
|
||||
func (rw *responseRewriter) Write(data []byte) (int, error) {
|
||||
// Ensure header is written
|
||||
if !rw.wroteHeader {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// Detect streaming on first write
|
||||
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||
contentType := rw.Header().Get("Content-Type")
|
||||
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||
strings.Contains(contentType, "stream")
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||
func (rw *responseRewriter) WriteHeader(code int) {
|
||||
if !rw.wroteHeader {
|
||||
rw.wroteHeader = true
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
}
|
||||
|
||||
// flush writes the buffered response with model names rewritten.
|
||||
func (rw *responseRewriter) flush() {
|
||||
if rw.flushed {
|
||||
return
|
||||
}
|
||||
rw.flushed = true
|
||||
|
||||
if rw.isStreaming {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
if rw.body.Len() > 0 {
|
||||
data := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||
if _, err := rw.ResponseWriter.Write(data); err != nil {
|
||||
log.Warnf("response rewriter: failed to write rewritten response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// modelFieldPaths lists all JSON paths where model name may appear.
|
||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model.
|
||||
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||
return data
|
||||
}
|
||||
|
||||
for _, path := range modelFieldPaths {
|
||||
if gjson.GetBytes(data, path).Exists() {
|
||||
data, _ = sjson.SetBytes(data, path, rw.requestedModel)
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// rewriteStreamChunk rewrites model names in SSE stream chunks.
|
||||
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||
return chunk
|
||||
}
|
||||
|
||||
// SSE format: "data: {json}\n\n"
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
for i, line := range lines {
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||
// Rewrite JSON in the data line
|
||||
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||
lines[i] = append([]byte("data: "), rewritten...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bytes.Join(lines, []byte("\n"))
|
||||
}
|
||||
Reference in New Issue
Block a user