mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(amp): add response rewriter for model name substitution in responses
This commit is contained in:
@@ -2,7 +2,6 @@ package amp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,6 +10,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AmpRouteType represents the type of routing decision made for an Amp request
|
// AmpRouteType represents the type of routing decision made for an Amp request
|
||||||
@@ -138,7 +139,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
if fh.modelMapper != nil {
|
if fh.modelMapper != nil {
|
||||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||||
// Mapping found - rewrite the model in request body
|
// Mapping found - rewrite the model in request body
|
||||||
bodyBytes = rewriteModelInBody(bodyBytes, mappedModel)
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
resolvedModel = mappedModel
|
resolvedModel = mappedModel
|
||||||
usedMapping = true
|
usedMapping = true
|
||||||
@@ -180,58 +181,59 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
if usedMapping {
|
if usedMapping {
|
||||||
// Log: Model was mapped to another model
|
// Log: Model was mapped to another model
|
||||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
|
||||||
|
c.Writer = rewriter
|
||||||
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
|
filterAntropicBetaHeader(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
rewriter.Flush()
|
||||||
|
log.Debugf("amp response rewriter: rewrote model %s -> %s in response", resolvedModel, normalizedModel)
|
||||||
} else if len(providers) > 0 {
|
} else if len(providers) > 0 {
|
||||||
// Log: Using local provider (free)
|
// Log: Using local provider (free)
|
||||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
|
filterAntropicBetaHeader(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
} else {
|
||||||
|
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
||||||
|
handler(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Providers available or no proxy for fallback, restore body and use normal handler
|
|
||||||
// Filter Anthropic-Beta header to remove features requiring special subscription
|
|
||||||
// This is needed when using local providers (bypassing the Amp proxy)
|
|
||||||
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
|
||||||
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
|
||||||
if filtered != "" {
|
|
||||||
c.Request.Header.Set("Anthropic-Beta", filtered)
|
|
||||||
} else {
|
|
||||||
c.Request.Header.Del("Anthropic-Beta")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
handler(c)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewriteModelInBody replaces the model name in a JSON request body
|
// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription
|
||||||
func rewriteModelInBody(body []byte, newModel string) []byte {
|
// This is needed when using local providers (bypassing the Amp proxy)
|
||||||
var payload map[string]interface{}
|
func filterAntropicBetaHeader(c *gin.Context) {
|
||||||
if err := json.Unmarshal(body, &payload); err != nil {
|
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||||
log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err)
|
if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" {
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||||
|
} else {
|
||||||
|
c.Request.Header.Del("Anthropic-Beta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteModelInRequest replaces the model name in a JSON request body
|
||||||
|
func rewriteModelInRequest(body []byte, newModel string) []byte {
|
||||||
|
if !gjson.GetBytes(body, "model").Exists() {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
result, err := sjson.SetBytes(body, "model", newModel)
|
||||||
if _, exists := payload["model"]; exists {
|
if err != nil {
|
||||||
payload["model"] = newModel
|
log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
|
||||||
newBody, err := json.Marshal(payload)
|
return body
|
||||||
if err != nil {
|
|
||||||
log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err)
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return newBody
|
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
return body
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractModelFromRequest attempts to extract the model name from various request formats
|
// extractModelFromRequest attempts to extract the model name from various request formats
|
||||||
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
||||||
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||||
var payload map[string]interface{}
|
// Check common model field names
|
||||||
if err := json.Unmarshal(body, &payload); err == nil {
|
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
||||||
// Check common model field names
|
return result.String()
|
||||||
if model, ok := payload["model"].(string); ok {
|
|
||||||
return model
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For Gemini requests, model is in the URL path
|
// For Gemini requests, model is in the URL path
|
||||||
|
|||||||
108
internal/api/modules/amp/response_rewriter.go
Normal file
108
internal/api/modules/amp/response_rewriter.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
||||||
|
// It's used to rewrite model names in responses when model mapping is used
|
||||||
|
type ResponseRewriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
body *bytes.Buffer
|
||||||
|
originalModel string
|
||||||
|
isStreaming bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponseRewriter creates a new response rewriter for model name substitution
|
||||||
|
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
||||||
|
return &ResponseRewriter{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: &bytes.Buffer{},
|
||||||
|
originalModel: originalModel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write intercepts response writes and buffers them for model name replacement
|
||||||
|
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||||
|
// Detect streaming on first write
|
||||||
|
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||||
|
contentType := rw.Header().Get("Content-Type")
|
||||||
|
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||||
|
strings.Contains(contentType, "stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rw.isStreaming {
|
||||||
|
return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||||
|
}
|
||||||
|
return rw.body.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush writes the buffered response with model names rewritten
|
||||||
|
func (rw *ResponseRewriter) Flush() {
|
||||||
|
if rw.isStreaming {
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rw.body.Len() > 0 {
|
||||||
|
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
||||||
|
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelFieldPaths lists all JSON paths where model name may appear
|
||||||
|
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||||
|
|
||||||
|
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||||
|
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
|
if rw.originalModel == "" {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
for _, path := range modelFieldPaths {
|
||||||
|
if gjson.GetBytes(data, path).Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
||||||
|
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||||
|
if rw.originalModel == "" {
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE format: "data: {json}\n\n"
|
||||||
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
|
for i, line := range lines {
|
||||||
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
// Rewrite JSON in the data line
|
||||||
|
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||||
|
lines[i] = append([]byte("data: "), rewritten...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Join(lines, []byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack implements http.Hijacker for WebSocket support
|
||||||
|
func (rw *ResponseRewriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok {
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
return nil, nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user