Files
CLIProxyAPI/internal/api/modules/amp/response_rewriter.go
이대희 31bd90c748 feature: Improves Amp client compatibility
Ensures compatibility with the Amp client by suppressing
"thinking" blocks when "tool_use" blocks are also present in
the response.

The Amp client has issues rendering both types of blocks
simultaneously. This change filters out "thinking" blocks in
such cases, preventing rendering problems.
2025-12-19 08:18:27 +09:00

130 lines
4.1 KiB
Go

package amp
import (
"bytes"
"net/http"
"strings"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ResponseRewriter wraps a gin.ResponseWriter so the response body can be
// inspected and modified before it reaches the client. Streaming responses are
// rewritten chunk by chunk as they are written; every other response is
// buffered and rewritten when Flush is called.
type ResponseRewriter struct {
	gin.ResponseWriter

	// body accumulates the bytes of a non-streaming response until Flush.
	body *bytes.Buffer
	// originalModel is the model name written back into the response. When it
	// is empty, model fields are left untouched.
	originalModel string
	// isStreaming is set by Write once the Content-Type header identifies the
	// response as a stream.
	isStreaming bool
}
// NewResponseRewriter returns a ResponseRewriter that wraps w and substitutes
// originalModel for the model name found in responses. The rewriter starts
// with an empty buffer and in non-streaming mode.
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
	rewriter := new(ResponseRewriter)
	rewriter.ResponseWriter = w
	rewriter.body = new(bytes.Buffer)
	rewriter.originalModel = originalModel
	return rewriter
}
// Write intercepts response writes.
//
// For streaming responses (detected from the Content-Type header on the first
// write) each chunk is rewritten, forwarded to the wrapped writer and flushed
// immediately. For all other responses the data is buffered so the complete
// body can be rewritten later by Flush.
//
// The returned count always refers to data, never to the rewritten chunk: the
// io.Writer contract requires 0 <= n <= len(data), and rewriting can make the
// forwarded chunk longer or shorter than the input.
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
	// Detect streaming while nothing has been buffered yet.
	if rw.body.Len() == 0 && !rw.isStreaming {
		contentType := rw.Header().Get("Content-Type")
		rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
			strings.Contains(contentType, "stream")
	}

	if !rw.isStreaming {
		return rw.body.Write(data)
	}

	n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
	if err != nil {
		// n counts rewritten bytes; clamp it so callers never see a count
		// larger than what they handed in.
		if n > len(data) {
			n = len(data)
		}
		return n, err
	}

	// Push the chunk to the client right away so streaming stays incremental.
	if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
	// The whole input was consumed, regardless of the rewritten length.
	return len(data), nil
}
// Flush sends any pending output to the client.
//
// In streaming mode chunks have already been forwarded by Write, so Flush only
// flushes the wrapped writer. Otherwise the buffered body is rewritten and
// written to the wrapped writer, after which the buffer is cleared so that a
// repeated Flush cannot emit the same body twice. A write failure is logged
// rather than returned because the method has no error result.
func (rw *ResponseRewriter) Flush() {
	if rw.isStreaming {
		if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
			flusher.Flush()
		}
		return
	}

	if rw.body.Len() == 0 {
		return
	}

	if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
		log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
	}
	// Reset only after the write: the rewritten slice may alias the buffer's
	// storage when no rewrite was necessary.
	rw.body.Reset()
}
// modelFieldPaths enumerates the JSON paths that are checked for a model name
// when a response is rewritten. Paths are processed in the order listed.
var modelFieldPaths = []string{
	"model",
	"modelVersion",
	"response.modelVersion",
	"message.model",
}
// rewriteModelInResponse returns data with two rewrites applied:
//
//  1. Amp compatibility: when the top-level "content" array contains a
//     "tool_use" block, every "thinking" block is removed from that array,
//     because the Amp client struggles when both kinds are present.
//  2. Model substitution: when rw.originalModel is non-empty, each path in
//     modelFieldPaths that exists in data is set to rw.originalModel.
//
// A rewrite that fails is logged and skipped; the payload from before the
// failed step is kept, so a failure never discards the response.
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
	if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
		filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
		originalCount := gjson.GetBytes(data, "content.#").Int()
		filteredCount := filtered.Get("#").Int()
		if filtered.Exists() && originalCount > filteredCount {
			// Splice the surviving blocks back in as raw JSON. Decoding and
			// re-encoding them would reorder object keys and push numbers
			// through float64.
			updated, err := sjson.SetRawBytes(data, "content", []byte(filtered.Raw))
			if err != nil {
				// sjson returns no usable payload on failure, so keep data.
				log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
			} else {
				data = updated
				log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
				// Log the result for verification
				log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
			}
		}
	}

	if rw.originalModel == "" {
		return data
	}

	for _, path := range modelFieldPaths {
		if !gjson.GetBytes(data, path).Exists() {
			continue
		}
		updated, err := sjson.SetBytes(data, path, rw.originalModel)
		if err != nil {
			// Leave this field as it was instead of clobbering the payload.
			log.Warnf("Amp ResponseRewriter: failed to rewrite model at %q: %v", path, err)
			continue
		}
		data = updated
	}
	return data
}
// rewriteStreamChunk applies rewriteModelInResponse to every SSE data line in
// chunk whose payload starts with '{'. All other lines pass through unchanged.
// When no original model is configured the chunk is returned as is.
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
	if rw.originalModel == "" {
		return chunk
	}

	// SSE format: "data: {json}\n\n"
	dataPrefix := []byte("data: ")
	newline := []byte("\n")

	segments := bytes.Split(chunk, newline)
	for idx := range segments {
		segment := segments[idx]
		if !bytes.HasPrefix(segment, dataPrefix) {
			continue
		}
		payload := segment[len(dataPrefix):]
		if len(payload) == 0 || payload[0] != '{' {
			continue
		}

		// Rebuild the line in a fresh slice: prefix followed by rewritten JSON.
		updated := rw.rewriteModelInResponse(payload)
		rebuilt := make([]byte, 0, len(dataPrefix)+len(updated))
		rebuilt = append(rebuilt, dataPrefix...)
		rebuilt = append(rebuilt, updated...)
		segments[idx] = rebuilt
	}
	return bytes.Join(segments, newline)
}