mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 04:10:51 +08:00
Migrates the AMP module to a new unified routing system, replacing the fallback handler with a router-based approach. This change introduces a `ModelRoutingWrapper` that handles model extraction, routing decisions, and proxying based on provider availability and model mappings. It provides a more flexible and maintainable routing mechanism by centralizing routing logic. The changes include: - Introducing new `routing` package with core routing logic. - Creating characterization tests to capture existing behavior. - Implementing model extraction and rewriting. - Updating AMP module routes to utilize the new routing wrapper. - Deprecating `FallbackHandler` in favor of the new `ModelRoutingWrapper`.
343 lines
10 KiB
Go
343 lines
10 KiB
Go
package routing
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestModelRewriter_RewriteRequestBody(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
|
|
tests := []struct {
|
|
name string
|
|
body []byte
|
|
newModel string
|
|
wantModel string
|
|
wantChange bool
|
|
}{
|
|
{
|
|
name: "rewrites model field in JSON body",
|
|
body: []byte(`{"model":"gpt-4.1","messages":[]}`),
|
|
newModel: "claude-local",
|
|
wantModel: "claude-local",
|
|
wantChange: true,
|
|
},
|
|
{
|
|
name: "rewrites with empty body returns empty",
|
|
body: []byte{},
|
|
newModel: "gpt-4",
|
|
wantModel: "",
|
|
wantChange: false,
|
|
},
|
|
{
|
|
name: "handles missing model field gracefully",
|
|
body: []byte(`{"messages":[{"role":"user"}]}`),
|
|
newModel: "gpt-4",
|
|
wantModel: "",
|
|
wantChange: false,
|
|
},
|
|
{
|
|
name: "preserves other fields when rewriting",
|
|
body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`),
|
|
newModel: "new-model",
|
|
wantModel: "new-model",
|
|
wantChange: true,
|
|
},
|
|
{
|
|
name: "handles nested JSON structure",
|
|
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`),
|
|
newModel: "claude-3-opus",
|
|
wantModel: "claude-3-opus",
|
|
wantChange: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel)
|
|
require.NoError(t, err)
|
|
|
|
if tt.wantChange {
|
|
assert.NotEqual(t, string(tt.body), string(result), "body should have been modified")
|
|
}
|
|
|
|
if tt.wantModel != "" {
|
|
// Parse result and check model field
|
|
model, _ := NewModelExtractor().Extract(result, nil)
|
|
assert.Equal(t, tt.wantModel, model)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestModelRewriter_WrapResponseWriter(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
|
|
t.Run("response writer wraps without error", func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
require.NotNil(t, wrapped)
|
|
require.NotNil(t, cleanup)
|
|
defer cleanup()
|
|
})
|
|
|
|
t.Run("rewrites model in non-streaming response", func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
// Write a response with the resolved model
|
|
response := []byte(`{"model":"claude-local","content":"hello"}`)
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
_, err := wrapped.Write(response)
|
|
require.NoError(t, err)
|
|
|
|
// Cleanup triggers the rewrite
|
|
cleanup()
|
|
|
|
// Check the response was rewritten to the requested model
|
|
body := recorder.Body.Bytes()
|
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
|
})
|
|
|
|
t.Run("no-op when requested equals resolved", func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4")
|
|
|
|
response := []byte(`{"model":"gpt-4","content":"hello"}`)
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
_, err := wrapped.Write(response)
|
|
require.NoError(t, err)
|
|
|
|
cleanup()
|
|
|
|
body := recorder.Body.Bytes()
|
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
|
})
|
|
|
|
t.Run("rewrites modelVersion field", func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
response := []byte(`{"modelVersion":"claude-local","content":"hello"}`)
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
_, err := wrapped.Write(response)
|
|
require.NoError(t, err)
|
|
|
|
cleanup()
|
|
|
|
body := recorder.Body.Bytes()
|
|
assert.Contains(t, string(body), `"modelVersion":"gpt-4"`)
|
|
})
|
|
|
|
t.Run("handles streaming responses", func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
// Set streaming content type
|
|
wrapped.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
// Write SSE chunks with resolved model
|
|
chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n")
|
|
_, err := wrapped.Write(chunk1)
|
|
require.NoError(t, err)
|
|
|
|
chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n")
|
|
_, err = wrapped.Write(chunk2)
|
|
require.NoError(t, err)
|
|
|
|
cleanup()
|
|
|
|
// For streaming, data is written immediately with rewrites
|
|
body := recorder.Body.Bytes()
|
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
|
})
|
|
|
|
t.Run("empty body handled gracefully", func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
// Don't write anything
|
|
|
|
cleanup()
|
|
|
|
body := recorder.Body.Bytes()
|
|
assert.Empty(t, body)
|
|
})
|
|
|
|
t.Run("preserves other JSON fields", func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`)
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
_, err := wrapped.Write(response)
|
|
require.NoError(t, err)
|
|
|
|
cleanup()
|
|
|
|
body := recorder.Body.Bytes()
|
|
assert.Contains(t, string(body), `"temperature":0.7`)
|
|
assert.Contains(t, string(body), `"prompt_tokens":10`)
|
|
})
|
|
}
|
|
|
|
func TestResponseRewriter_ImplementsInterfaces(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
defer cleanup()
|
|
|
|
// Should implement http.ResponseWriter
|
|
assert.Implements(t, (*http.ResponseWriter)(nil), wrapped)
|
|
|
|
// Should preserve header access
|
|
wrapped.Header().Set("X-Custom", "value")
|
|
assert.Equal(t, "value", recorder.Header().Get("X-Custom"))
|
|
|
|
// Should write status
|
|
wrapped.WriteHeader(http.StatusCreated)
|
|
assert.Equal(t, http.StatusCreated, recorder.Code)
|
|
}
|
|
|
|
func TestResponseRewriter_Flush(t *testing.T) {
|
|
t.Run("flush writes buffered content", func(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
response := []byte(`{"model":"claude-local","content":"test"}`)
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
wrapped.Write(response)
|
|
|
|
// Before cleanup, response should be empty (buffered)
|
|
assert.Empty(t, recorder.Body.Bytes())
|
|
|
|
// After cleanup, response should be written
|
|
cleanup()
|
|
assert.NotEmpty(t, recorder.Body.Bytes())
|
|
})
|
|
|
|
t.Run("multiple flush calls are safe", func(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
response := []byte(`{"model":"claude-local"}`)
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
wrapped.Write(response)
|
|
|
|
// First cleanup
|
|
cleanup()
|
|
firstBody := recorder.Body.Bytes()
|
|
|
|
// Second cleanup should not write again
|
|
cleanup()
|
|
secondBody := recorder.Body.Bytes()
|
|
|
|
assert.Equal(t, firstBody, secondBody)
|
|
})
|
|
}
|
|
|
|
func TestResponseRewriter_StreamingWithDataLines(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
wrapped.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
// SSE format with multiple data lines
|
|
chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n")
|
|
wrapped.Write(chunk)
|
|
|
|
cleanup()
|
|
|
|
body := recorder.Body.Bytes()
|
|
// Both data lines should have model rewritten
|
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
|
}
|
|
|
|
func TestModelRewriter_RoundTrip(t *testing.T) {
|
|
// Simulate a full request -> response cycle with model rewriting
|
|
rewriter := NewModelRewriter()
|
|
|
|
// Step 1: Rewrite request body
|
|
originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`)
|
|
rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local")
|
|
require.NoError(t, err)
|
|
|
|
// Verify request was rewritten
|
|
extractor := NewModelExtractor()
|
|
requestModel, _ := extractor.Extract(rewrittenRequest, nil)
|
|
assert.Equal(t, "claude-local", requestModel)
|
|
|
|
// Step 2: Simulate response with resolved model
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`)
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
wrapped.Write(response)
|
|
cleanup()
|
|
|
|
// Verify response was rewritten back
|
|
body, _ := io.ReadAll(recorder.Result().Body)
|
|
responseModel, _ := extractor.Extract(body, nil)
|
|
assert.Equal(t, "gpt-4", responseModel)
|
|
}
|
|
|
|
func TestModelRewriter_NonJSONBody(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
|
|
// Binary/non-JSON body should be returned unchanged
|
|
body := []byte{0x00, 0x01, 0x02, 0x03}
|
|
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, body, result)
|
|
}
|
|
|
|
func TestModelRewriter_InvalidJSON(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
|
|
// Invalid JSON without model field should be returned unchanged
|
|
body := []byte(`not valid json`)
|
|
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, body, result)
|
|
}
|
|
|
|
func TestResponseRewriter_StatusCodePreserved(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
wrapped.WriteHeader(http.StatusAccepted)
|
|
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
|
cleanup()
|
|
|
|
assert.Equal(t, http.StatusAccepted, recorder.Code)
|
|
}
|
|
|
|
func TestResponseRewriter_HeaderFlushed(t *testing.T) {
|
|
rewriter := NewModelRewriter()
|
|
recorder := httptest.NewRecorder()
|
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
|
|
|
wrapped.Header().Set("Content-Type", "application/json")
|
|
wrapped.Header().Set("X-Request-ID", "abc123")
|
|
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
|
cleanup()
|
|
|
|
result := recorder.Result()
|
|
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
|
|
assert.Equal(t, "abc123", result.Header.Get("X-Request-ID"))
|
|
}
|