mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Ensures compatibility with the Amp client by suppressing "thinking" blocks when "tool_use" blocks are also present in the response. The Amp client has issues rendering both types of blocks simultaneously. This change filters out "thinking" blocks in such cases, preventing rendering problems.
130 lines
4.1 KiB
Go
130 lines
4.1 KiB
Go
package amp
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body.
//
// It rewrites model names in responses when model mapping is used, and it also
// removes "thinking" content blocks from payloads that contain "tool_use"
// blocks (see rewriteModelInResponse). Non-streaming bodies are buffered by
// Write and emitted, rewritten, by Flush; streaming bodies are rewritten chunk
// by chunk as they are written.
type ResponseRewriter struct {
	gin.ResponseWriter

	// body accumulates non-streaming response data until Flush writes it out.
	body *bytes.Buffer
	// originalModel is the model name written back into responses; when it is
	// empty, model fields are left untouched.
	originalModel string
	// isStreaming is decided on the first Write from the Content-Type header.
	isStreaming bool
}
|
|
|
|
// NewResponseRewriter creates a new response rewriter for model name substitution
|
|
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
|
return &ResponseRewriter{
|
|
ResponseWriter: w,
|
|
body: &bytes.Buffer{},
|
|
originalModel: originalModel,
|
|
}
|
|
}
|
|
|
|
// Write intercepts response writes and buffers them for model name replacement
|
|
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
|
// Detect streaming on first write
|
|
if rw.body.Len() == 0 && !rw.isStreaming {
|
|
contentType := rw.Header().Get("Content-Type")
|
|
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
|
strings.Contains(contentType, "stream")
|
|
}
|
|
|
|
if rw.isStreaming {
|
|
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
|
if err == nil {
|
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
return n, err
|
|
}
|
|
return rw.body.Write(data)
|
|
}
|
|
|
|
// Flush writes the buffered response with model names rewritten
|
|
func (rw *ResponseRewriter) Flush() {
|
|
if rw.isStreaming {
|
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
return
|
|
}
|
|
if rw.body.Len() > 0 {
|
|
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
|
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// modelFieldPaths holds the JSON paths (gjson/sjson syntax) at which a model
// name may appear in a response payload. rewriteModelInResponse overwrites
// each path that is present with the original model name.
var modelFieldPaths = []string{
	"model",
	"modelVersion",
	"response.modelVersion",
	"message.model",
}
|
|
|
|
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
|
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
|
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
|
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
|
// The Amp client struggles when both thinking and tool_use blocks are present
|
|
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
|
// The Amp client struggles when both thinking and tool_use blocks are present
|
|
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
|
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
|
if filtered.Exists() {
|
|
originalCount := gjson.GetBytes(data, "content.#").Int()
|
|
filteredCount := filtered.Get("#").Int()
|
|
|
|
if originalCount > filteredCount {
|
|
var err error
|
|
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
|
if err != nil {
|
|
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
|
} else {
|
|
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
|
// Log the result for verification
|
|
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if rw.originalModel == "" {
|
|
return data
|
|
}
|
|
for _, path := range modelFieldPaths {
|
|
if gjson.GetBytes(data, path).Exists() {
|
|
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
|
}
|
|
}
|
|
return data
|
|
}
|
|
|
|
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
|
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
|
if rw.originalModel == "" {
|
|
return chunk
|
|
}
|
|
|
|
// SSE format: "data: {json}\n\n"
|
|
lines := bytes.Split(chunk, []byte("\n"))
|
|
for i, line := range lines {
|
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
|
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
|
// Rewrite JSON in the data line
|
|
rewritten := rw.rewriteModelInResponse(jsonData)
|
|
lines[i] = append([]byte("data: "), rewritten...)
|
|
}
|
|
}
|
|
}
|
|
|
|
return bytes.Join(lines, []byte("\n"))
|
|
}
|