Files
CLIProxyAPI/internal/routing/rewriter.go
이대희 9299897e04 Implements unified model routing
Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach.

This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings.
It provides a more flexible and maintainable routing mechanism by centralizing routing logic.

The changes include:
- Introducing new `routing` package with core routing logic.
- Creating characterization tests to capture existing behavior.
- Implementing model extraction and rewriting.
- Updating AMP module routes to utilize the new routing wrapper.
- Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
2026-02-01 16:58:32 +09:00

160 lines
4.8 KiB
Go

package routing
import (
"bytes"
"net/http"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
log "github.com/sirupsen/logrus"
)
// ModelRewriter rewrites model names in request bodies on the way out and in
// response bodies on the way back to the client.
type ModelRewriter interface {
// RewriteRequestBody rewrites the model field of a JSON request body to
// newModel. It returns the modified body, or the original body when no
// rewrite was needed.
RewriteRequestBody(body []byte, newModel string) ([]byte, error)
// WrapResponseWriter wraps w so that model names in the response are
// rewritten. It returns the wrapped writer together with a cleanup function
// that must be called once the response is complete.
WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
}
// DefaultModelRewriter is the standard implementation of ModelRewriter.
// It has no fields, so a single value carries no per-request state.
type DefaultModelRewriter struct{}
// NewModelRewriter returns a pointer to a freshly allocated DefaultModelRewriter.
func NewModelRewriter() *DefaultModelRewriter {
	return new(DefaultModelRewriter)
}
// RewriteRequestBody sets the "model" field of a JSON request body to newModel.
//
// A body without a "model" field is returned untouched with a nil error. If
// the field cannot be set, the original body is returned together with the
// error reported by sjson.
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
	if modelField := gjson.GetBytes(body, "model"); !modelField.Exists() {
		return body, nil
	}

	rewritten, err := sjson.SetBytes(body, "model", newModel)
	if err == nil {
		return rewritten, nil
	}
	return body, err
}
// WrapResponseWriter returns a writer that rewrites model names in everything
// written through it, plus a cleanup function.
//
// Non-streaming bodies are held in memory until the cleanup function runs, so
// the caller must invoke it after the handler completes or the buffered
// response is never sent.
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
	wrapped := new(responseRewriter)
	wrapped.ResponseWriter = w
	wrapped.body = new(bytes.Buffer)
	wrapped.requestedModel = requestedModel
	wrapped.resolvedModel = resolvedModel

	return wrapped, wrapped.flush
}
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
// Streaming responses are rewritten chunk by chunk as they are written;
// all other responses are buffered and rewritten once in flush.
type responseRewriter struct {
// Underlying writer that ultimately receives the (rewritten) response.
http.ResponseWriter
// body accumulates non-streaming writes until flush is called.
body *bytes.Buffer
// requestedModel is the model name written into the response.
requestedModel string
// resolvedModel is the model name the request was routed to; rewriting is
// skipped when it is empty or equal to requestedModel.
resolvedModel string
// isStreaming is decided in Write from the Content-Type header.
isStreaming bool
// wroteHeader records that WriteHeader has already been forwarded.
wroteHeader bool
// flushed makes flush a no-op after its first call.
flushed bool
}
// Write implements io.Writer for the wrapped response.
//
// Streaming responses (decided from the Content-Type header on the first
// write) are rewritten chunk by chunk, forwarded to the underlying writer and
// flushed immediately. All other responses are appended to the internal
// buffer and only sent when flush runs.
//
// On success Write always reports len(data). The rewritten chunk may be longer
// or shorter than the input, and returning the underlying writer's count would
// violate the io.Writer contract, which callers such as io.Copy treat as a
// short or invalid write.
func (rw *responseRewriter) Write(data []byte) (int, error) {
	// Mirror net/http: a Write without a prior WriteHeader implies 200 OK.
	if !rw.wroteHeader {
		rw.WriteHeader(http.StatusOK)
	}

	// Detect streaming while nothing has been buffered yet.
	// NOTE(review): the bare "stream" substring also matches content types
	// such as application/octet-stream — confirm that this is intended.
	if rw.body.Len() == 0 && !rw.isStreaming {
		contentType := rw.Header().Get("Content-Type")
		rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
			strings.Contains(contentType, "stream")
	}

	if !rw.isStreaming {
		return rw.body.Write(data)
	}

	n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
	if err != nil {
		// n counts bytes of the rewritten chunk; never report more than the
		// caller actually handed us.
		if n > len(data) {
			n = len(data)
		}
		return n, err
	}

	// Push the chunk to the client right away so streaming stays incremental.
	if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
	return len(data), nil
}
// WriteHeader forwards the first status code to the underlying writer and
// ignores any later call.
//
// When a model rewrite is active the response body may change length, so a
// Content-Length header set by the wrapped handler would no longer match the
// bytes actually written. net/http rejects writes beyond the declared length,
// and a body shorter than declared is an invalid response. The header is
// therefore removed before the header block is sent, leaving the server to
// frame the body itself.
func (rw *responseRewriter) WriteHeader(code int) {
	if rw.wroteHeader {
		return
	}
	rw.wroteHeader = true

	// Same condition under which the rewrite helpers modify the payload.
	if rw.requestedModel != "" && rw.resolvedModel != "" && rw.requestedModel != rw.resolvedModel {
		rw.Header().Del("Content-Length")
	}
	rw.ResponseWriter.WriteHeader(code)
}
// flush finishes the response and is safe to call more than once; only the
// first call has any effect.
//
// For streaming responses every chunk has already been forwarded, so this only
// flushes the underlying writer when it supports http.Flusher. Otherwise the
// buffered body is rewritten and written out in one piece; a write failure is
// logged rather than returned.
func (rw *responseRewriter) flush() {
	if rw.flushed {
		return
	}
	rw.flushed = true

	switch {
	case rw.isStreaming:
		if f, ok := rw.ResponseWriter.(http.Flusher); ok {
			f.Flush()
		}
	case rw.body.Len() > 0:
		rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
		_, err := rw.ResponseWriter.Write(rewritten)
		if err != nil {
			log.Warnf("response rewriter: failed to write rewritten response: %v", err)
		}
	}
}
// modelFieldPaths lists the JSON paths that rewriteModelInResponse checks for
// a model name. Each path that exists in a payload is overwritten; paths that
// are absent are left alone.
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
// rewriteModelInResponse sets every model field listed in modelFieldPaths that
// is present in data to the requested model name and returns the result.
//
// Fields are overwritten regardless of their current value. The payload is
// returned unchanged when either model name is empty or when both names are
// equal. Rewriting is best effort: if a field cannot be set, that field is
// skipped and the payload produced so far is kept.
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
	if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
		return data
	}

	for _, path := range modelFieldPaths {
		if !gjson.GetBytes(data, path).Exists() {
			continue
		}
		updated, err := sjson.SetBytes(data, path, rw.requestedModel)
		if err != nil {
			// Do not assign the result of a failed set: keep the last good
			// payload instead of depending on what sjson returns on error.
			continue
		}
		data = updated
	}
	return data
}
// rewriteStreamChunk rewrites model names inside one chunk of an SSE stream.
//
// The chunk is split on "\n". Every line of the form "data: {...}" has its JSON
// payload passed through rewriteModelInResponse; all other lines are kept
// verbatim. The chunk is returned unchanged when either model name is empty or
// when both names are equal.
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
	if rw.requestedModel == rw.resolvedModel || rw.requestedModel == "" || rw.resolvedModel == "" {
		return chunk
	}

	var (
		dataPrefix = []byte("data: ")
		newline    = []byte("\n")
	)

	segments := bytes.Split(chunk, newline)
	for idx := range segments {
		if !bytes.HasPrefix(segments[idx], dataPrefix) {
			continue
		}
		payload := segments[idx][len(dataPrefix):]
		// Only JSON objects are candidates; markers like "[DONE]" pass through.
		if len(payload) == 0 || payload[0] != '{' {
			continue
		}

		converted := rw.rewriteModelInResponse(payload)
		rebuilt := make([]byte, 0, len(dataPrefix)+len(converted))
		rebuilt = append(rebuilt, dataPrefix...)
		rebuilt = append(rebuilt, converted...)
		segments[idx] = rebuilt
	}
	return bytes.Join(segments, newline)
}