diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 59c83e22..46c3d3d9 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -2,7 +2,6 @@ package amp import ( "bytes" - "encoding/json" "io" "net/http/httputil" "strings" @@ -11,6 +10,8 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // AmpRouteType represents the type of routing decision made for an Amp request @@ -138,7 +139,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if fh.modelMapper != nil { if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { // Mapping found - rewrite the model in request body - bodyBytes = rewriteModelInBody(bodyBytes, mappedModel) + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) resolvedModel = mappedModel usedMapping = true @@ -180,58 +181,59 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if usedMapping { // Log: Model was mapped to another model logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) + rewriter := NewResponseRewriter(c.Writer, normalizedModel) + c.Writer = rewriter + // Filter Anthropic-Beta header only for local handling paths + filterAntropicBetaHeader(c) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + rewriter.Flush() + log.Debugf("amp response rewriter: rewrote model %s -> %s in response", resolvedModel, normalizedModel) } else if len(providers) > 0 { // Log: Using local provider (free) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) + // Filter Anthropic-Beta header only for local handling paths + filterAntropicBetaHeader(c) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + } else { + // No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response + handler(c) } - - // Providers available or no proxy for fallback, restore body and use normal handler - // Filter Anthropic-Beta header to remove features requiring special subscription - // This is needed when using local providers (bypassing the Amp proxy) - if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" { - filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07") - if filtered != "" { - c.Request.Header.Set("Anthropic-Beta", filtered) - } else { - c.Request.Header.Del("Anthropic-Beta") - } - } - - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - handler(c) } } -// rewriteModelInBody replaces the model name in a JSON request body -func rewriteModelInBody(body []byte, newModel string) []byte { - var payload map[string]interface{} - if err := json.Unmarshal(body, &payload); err != nil { - log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err) +// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription +// This is needed when using local providers (bypassing the Amp proxy) +func filterAntropicBetaHeader(c *gin.Context) { + if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" { + if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" { + c.Request.Header.Set("Anthropic-Beta", filtered) + } else { + c.Request.Header.Del("Anthropic-Beta") + } + } +} + +// rewriteModelInRequest replaces the model name in a JSON request body +func rewriteModelInRequest(body []byte, newModel string) []byte { + if !gjson.GetBytes(body, "model").Exists() { return body } - - if _, exists := payload["model"]; exists { - payload["model"] = newModel - newBody, err := json.Marshal(payload) - if err != nil { - log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err) - return body - } - return newBody + result, err := sjson.SetBytes(body, "model", newModel) + if err != nil { + log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err) + return body } - - return body + return result } // extractModelFromRequest attempts to extract the model name from various request formats func extractModelFromRequest(body []byte, c *gin.Context) string { // First try to parse from JSON body (OpenAI, Claude, etc.) - var payload map[string]interface{} - if err := json.Unmarshal(body, &payload); err == nil { - // Check common model field names - if model, ok := payload["model"].(string); ok { - return model - } + // Check common model field names + if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String { + return result.String() } // For Gemini requests, model is in the URL path diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go new file mode 100644 index 00000000..5be36078 --- /dev/null +++ b/internal/api/modules/amp/response_rewriter.go @@ -0,0 +1,108 @@ +package amp + +import ( + "bufio" + "bytes" + "net" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body +// It's used to rewrite model names in responses when model mapping is used +type ResponseRewriter struct { + gin.ResponseWriter + body *bytes.Buffer + originalModel string + isStreaming bool +} + +// NewResponseRewriter creates a new response rewriter for model name substitution +func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter { + return &ResponseRewriter{ + ResponseWriter: w, + body: &bytes.Buffer{}, + originalModel: originalModel, + } +} + +// Write intercepts response writes and buffers them for model name replacement +func (rw *ResponseRewriter) Write(data []byte) (int, error) { + // Detect streaming on first write + if rw.body.Len() == 0 && !rw.isStreaming { + contentType := rw.Header().Get("Content-Type") + rw.isStreaming = strings.Contains(contentType, "text/event-stream") || + strings.Contains(contentType, "stream") + } + + if rw.isStreaming { + return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) + } + return rw.body.Write(data) +} + +// Flush writes the buffered response with model names rewritten +func (rw *ResponseRewriter) Flush() { + if rw.isStreaming { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + return + } + if rw.body.Len() > 0 { + if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil { + log.Warnf("amp response rewriter: failed to write rewritten response: %v", err) + } + } +} + +// modelFieldPaths lists all JSON paths where model name may appear +var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"} + +// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON +func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { + if rw.originalModel == "" { + return data + } + for _, path := range modelFieldPaths { + if gjson.GetBytes(data, path).Exists() { + data, _ = sjson.SetBytes(data, path, rw.originalModel) + } + } + return data +} + +// rewriteStreamChunk rewrites model names in SSE stream chunks +func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { + if rw.originalModel == "" { + return chunk + } + + // SSE format: "data: {json}\n\n" + lines := bytes.Split(chunk, []byte("\n")) + for i, line := range lines { + if bytes.HasPrefix(line, []byte("data: ")) { + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) > 0 && jsonData[0] == '{' { + // Rewrite JSON in the data line + rewritten := rw.rewriteModelInResponse(jsonData) + lines[i] = append([]byte("data: "), rewritten...) + } + } + } + + return bytes.Join(lines, []byte("\n")) +} + +// Hijack implements http.Hijacker for WebSocket support +func (rw *ResponseRewriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, http.ErrNotSupported +}