mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 04:10:51 +08:00
Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach. This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings. It provides a more flexible and maintainable routing mechanism by centralizing routing logic. The changes include: - Introducing new `routing` package with core routing logic. - Creating characterization tests to capture existing behavior. - Implementing model extraction and rewriting. - Updating AMP module routes to utilize the new routing wrapper. - Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
160 lines
4.8 KiB
Go
160 lines
4.8 KiB
Go
package routing
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// ModelRewriter handles model name rewriting in requests and responses.
|
|
type ModelRewriter interface {
|
|
// RewriteRequestBody rewrites the model field in a JSON request body.
|
|
// Returns the modified body or the original if no rewrite was needed.
|
|
RewriteRequestBody(body []byte, newModel string) ([]byte, error)
|
|
|
|
// WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response.
|
|
// Returns the wrapped writer and a cleanup function that must be called after the response is complete.
|
|
WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
|
|
}
|
|
|
|
// DefaultModelRewriter is the standard implementation of ModelRewriter.
|
|
type DefaultModelRewriter struct{}
|
|
|
|
// NewModelRewriter creates a new DefaultModelRewriter.
|
|
func NewModelRewriter() *DefaultModelRewriter {
|
|
return &DefaultModelRewriter{}
|
|
}
|
|
|
|
// RewriteRequestBody replaces the model name in a JSON request body.
|
|
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
|
|
if !gjson.GetBytes(body, "model").Exists() {
|
|
return body, nil
|
|
}
|
|
result, err := sjson.SetBytes(body, "model", newModel)
|
|
if err != nil {
|
|
return body, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// WrapResponseWriter wraps a response writer to rewrite model names.
|
|
// The cleanup function must be called after the handler completes to flush any buffered data.
|
|
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
|
|
rw := &responseRewriter{
|
|
ResponseWriter: w,
|
|
body: &bytes.Buffer{},
|
|
requestedModel: requestedModel,
|
|
resolvedModel: resolvedModel,
|
|
}
|
|
return rw, func() { rw.flush() }
|
|
}
|
|
|
|
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
|
|
type responseRewriter struct {
|
|
http.ResponseWriter
|
|
body *bytes.Buffer
|
|
requestedModel string
|
|
resolvedModel string
|
|
isStreaming bool
|
|
wroteHeader bool
|
|
flushed bool
|
|
}
|
|
|
|
// Write intercepts response writes and buffers them for model name replacement.
|
|
func (rw *responseRewriter) Write(data []byte) (int, error) {
|
|
// Ensure header is written
|
|
if !rw.wroteHeader {
|
|
rw.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
// Detect streaming on first write
|
|
if rw.body.Len() == 0 && !rw.isStreaming {
|
|
contentType := rw.Header().Get("Content-Type")
|
|
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
|
strings.Contains(contentType, "stream")
|
|
}
|
|
|
|
if rw.isStreaming {
|
|
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
|
if err == nil {
|
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
return n, err
|
|
}
|
|
return rw.body.Write(data)
|
|
}
|
|
|
|
// WriteHeader captures the status code and delegates to the underlying writer.
|
|
func (rw *responseRewriter) WriteHeader(code int) {
|
|
if !rw.wroteHeader {
|
|
rw.wroteHeader = true
|
|
rw.ResponseWriter.WriteHeader(code)
|
|
}
|
|
}
|
|
|
|
// flush writes the buffered response with model names rewritten.
|
|
func (rw *responseRewriter) flush() {
|
|
if rw.flushed {
|
|
return
|
|
}
|
|
rw.flushed = true
|
|
|
|
if rw.isStreaming {
|
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
return
|
|
}
|
|
if rw.body.Len() > 0 {
|
|
data := rw.rewriteModelInResponse(rw.body.Bytes())
|
|
if _, err := rw.ResponseWriter.Write(data); err != nil {
|
|
log.Warnf("response rewriter: failed to write rewritten response: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// modelFieldPaths lists all JSON paths where model name may appear.
|
|
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
|
|
|
// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model.
|
|
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
|
|
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
|
return data
|
|
}
|
|
|
|
for _, path := range modelFieldPaths {
|
|
if gjson.GetBytes(data, path).Exists() {
|
|
data, _ = sjson.SetBytes(data, path, rw.requestedModel)
|
|
}
|
|
}
|
|
return data
|
|
}
|
|
|
|
// rewriteStreamChunk rewrites model names in SSE stream chunks.
|
|
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
|
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
|
return chunk
|
|
}
|
|
|
|
// SSE format: "data: {json}\n\n"
|
|
lines := bytes.Split(chunk, []byte("\n"))
|
|
for i, line := range lines {
|
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
|
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
|
// Rewrite JSON in the data line
|
|
rewritten := rw.rewriteModelInResponse(jsonData)
|
|
lines[i] = append([]byte("data: "), rewritten...)
|
|
}
|
|
}
|
|
}
|
|
|
|
return bytes.Join(lines, []byte("\n"))
|
|
}
|