Files
CLIProxyAPI/internal/api/modules/amp/proxy_test.go
이대희 93147dddeb Improves error handling for canceled requests
Adds explicit handling for context.Canceled errors in the reverse proxy error handler to return 499 status code without logging, which is more appropriate for client-side cancellations during polling.

Also adds a test case to verify this behavior.
2026-02-12 10:39:45 +09:00

682 lines
19 KiB
Go

package amp
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Helper: compress data with gzip
func gzipBytes(b []byte) []byte {
var buf bytes.Buffer
zw := gzip.NewWriter(&buf)
zw.Write(b)
zw.Close()
return buf.Bytes()
}
// Helper: create a mock http.Response
func mkResp(status int, hdr http.Header, body []byte) *http.Response {
if hdr == nil {
hdr = http.Header{}
}
return &http.Response{
StatusCode: status,
Header: hdr,
Body: io.NopCloser(bytes.NewReader(body)),
ContentLength: int64(len(body)),
}
}
func TestCreateReverseProxy_ValidURL(t *testing.T) {
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key"))
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if proxy == nil {
t.Fatal("expected proxy to be created")
}
}
func TestCreateReverseProxy_InvalidURL(t *testing.T) {
_, err := createReverseProxy("://invalid", NewStaticSecretSource("key"))
if err == nil {
t.Fatal("expected error for invalid URL")
}
}
func TestModifyResponse_GzipScenarios(t *testing.T) {
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
if err != nil {
t.Fatal(err)
}
goodJSON := []byte(`{"ok":true}`)
good := gzipBytes(goodJSON)
truncated := good[:10]
corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...)
cases := []struct {
name string
header http.Header
body []byte
status int
wantBody []byte
wantCE string
}{
{
name: "decompresses_valid_gzip_no_header",
header: http.Header{},
body: good,
status: 200,
wantBody: goodJSON,
wantCE: "",
},
{
name: "skips_when_ce_present",
header: http.Header{"Content-Encoding": []string{"gzip"}},
body: good,
status: 200,
wantBody: good,
wantCE: "gzip",
},
{
name: "passes_truncated_unchanged",
header: http.Header{},
body: truncated,
status: 200,
wantBody: truncated,
wantCE: "",
},
{
name: "passes_corrupted_unchanged",
header: http.Header{},
body: corrupted,
status: 200,
wantBody: corrupted,
wantCE: "",
},
{
name: "non_gzip_unchanged",
header: http.Header{},
body: []byte("plain"),
status: 200,
wantBody: []byte("plain"),
wantCE: "",
},
{
name: "empty_body",
header: http.Header{},
body: []byte{},
status: 200,
wantBody: []byte{},
wantCE: "",
},
{
name: "single_byte_body",
header: http.Header{},
body: []byte{0x1f},
status: 200,
wantBody: []byte{0x1f},
wantCE: "",
},
{
name: "skips_non_2xx_status",
header: http.Header{},
body: good,
status: 404,
wantBody: good,
wantCE: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
resp := mkResp(tc.status, tc.header, tc.body)
if err := proxy.ModifyResponse(resp); err != nil {
t.Fatalf("ModifyResponse error: %v", err)
}
got, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if !bytes.Equal(got, tc.wantBody) {
t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got)
}
if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE {
t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce)
}
})
}
}
func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) {
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
if err != nil {
t.Fatal(err)
}
goodJSON := []byte(`{"message":"test response"}`)
gzipped := gzipBytes(goodJSON)
// Simulate upstream response with gzip body AND Content-Length header
// (this is the scenario the bot flagged - stale Content-Length after decompression)
resp := mkResp(200, http.Header{
"Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, // Compressed size
}, gzipped)
if err := proxy.ModifyResponse(resp); err != nil {
t.Fatalf("ModifyResponse error: %v", err)
}
// Verify body is decompressed
got, _ := io.ReadAll(resp.Body)
if !bytes.Equal(got, goodJSON) {
t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON)
}
// Verify Content-Length header is updated to decompressed size
wantCL := fmt.Sprintf("%d", len(goodJSON))
gotCL := resp.Header.Get("Content-Length")
if gotCL != wantCL {
t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL)
}
// Verify struct field also matches
if resp.ContentLength != int64(len(goodJSON)) {
t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength)
}
}
func TestModifyResponse_SkipsStreamingResponses(t *testing.T) {
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
if err != nil {
t.Fatal(err)
}
goodJSON := []byte(`{"ok":true}`)
gzipped := gzipBytes(goodJSON)
t.Run("sse_skips_decompression", func(t *testing.T) {
resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped)
if err := proxy.ModifyResponse(resp); err != nil {
t.Fatalf("ModifyResponse error: %v", err)
}
// SSE should NOT be decompressed
got, _ := io.ReadAll(resp.Body)
if !bytes.Equal(got, gzipped) {
t.Fatal("SSE response should not be decompressed")
}
})
}
func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) {
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
if err != nil {
t.Fatal(err)
}
goodJSON := []byte(`{"ok":true}`)
gzipped := gzipBytes(goodJSON)
t.Run("chunked_json_decompresses", func(t *testing.T) {
// Chunked JSON responses (like thread APIs) should be decompressed
resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped)
if err := proxy.ModifyResponse(resp); err != nil {
t.Fatalf("ModifyResponse error: %v", err)
}
// Should decompress because it's not SSE
got, _ := io.ReadAll(resp.Body)
if !bytes.Equal(got, goodJSON) {
t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON)
}
})
}
func TestReverseProxy_InjectsHeaders(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.ServeHTTP(w, r)
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "secret" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer secret" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_EmptySecret(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource(""))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.ServeHTTP(w, r)
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
// Should NOT inject headers when secret is empty
if hdr.Get("X-Api-Key") != "" {
t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key"))
}
if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " {
t.Fatalf("Authorization should not be set, got: %q", authVal)
}
}
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
type captured struct {
headers http.Header
query string
}
got := make(chan captured, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer client-key")
req.Header.Set("X-Api-Key", "client-key")
req.Header.Set("X-Goog-Api-Key", "client-key")
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
c := <-got
// These are client-provided credentials and must not reach the upstream.
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
}
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
}
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
}
// Query-based credentials should be stripped only when they match the authenticated client key.
// Should keep unrelated values and parameters.
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
}
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
}
}
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "u1" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer u1" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "default" {
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer default" {
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_ErrorHandler(t *testing.T) {
// Point proxy to a non-routable address to trigger error
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.ServeHTTP(w, r)
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/any")
if err != nil {
t.Fatal(err)
}
body, _ := io.ReadAll(res.Body)
res.Body.Close()
if res.StatusCode != http.StatusBadGateway {
t.Fatalf("want 502, got %d", res.StatusCode)
}
if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) {
t.Fatalf("unexpected body: %s", body)
}
if ct := res.Header.Get("Content-Type"); ct != "application/json" {
t.Fatalf("content-type: want application/json, got %s", ct)
}
}
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
// Test that context.Canceled errors return 499 without generic error response
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
if err != nil {
t.Fatal(err)
}
// Create a canceled context to trigger the cancellation path
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
rr := httptest.NewRecorder()
// Directly invoke the ErrorHandler with context.Canceled
proxy.ErrorHandler(rr, req, context.Canceled)
// Body should be empty for canceled requests (no JSON error response)
body := rr.Body.Bytes()
if len(body) > 0 {
t.Fatalf("expected empty body for canceled context, got: %s", body)
}
}
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
// Upstream returns gzipped JSON without Content-Encoding header
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write(gzipBytes([]byte(`{"upstream":"ok"}`)))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.ServeHTTP(w, r)
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
body, _ := io.ReadAll(res.Body)
res.Body.Close()
expected := []byte(`{"upstream":"ok"}`)
if !bytes.Equal(body, expected) {
t.Fatalf("want decompressed JSON, got: %s", body)
}
}
func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) {
// Upstream returns plain JSON
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
w.Write([]byte(`{"plain":"json"}`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.ServeHTTP(w, r)
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
body, _ := io.ReadAll(res.Body)
res.Body.Close()
expected := []byte(`{"plain":"json"}`)
if !bytes.Equal(body, expected) {
t.Fatalf("want plain JSON unchanged, got: %s", body)
}
}
func TestIsStreamingResponse(t *testing.T) {
cases := []struct {
name string
header http.Header
want bool
}{
{
name: "sse",
header: http.Header{"Content-Type": []string{"text/event-stream"}},
want: true,
},
{
name: "chunked_not_streaming",
header: http.Header{"Transfer-Encoding": []string{"chunked"}},
want: false, // Chunked is transport-level, not streaming
},
{
name: "normal_json",
header: http.Header{"Content-Type": []string{"application/json"}},
want: false,
},
{
name: "empty",
header: http.Header{},
want: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
resp := &http.Response{Header: tc.header}
got := isStreamingResponse(resp)
if got != tc.want {
t.Fatalf("want %v, got %v", tc.want, got)
}
})
}
}
func TestFilterBetaFeatures(t *testing.T) {
tests := []struct {
name string
header string
featureToRemove string
expected string
}{
{
name: "Remove context-1m from middle",
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
},
{
name: "Remove context-1m from start",
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14",
},
{
name: "Remove context-1m from end",
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14",
},
{
name: "Feature not present",
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
},
{
name: "Only feature to remove",
header: "context-1m-2025-08-07",
featureToRemove: "context-1m-2025-08-07",
expected: "",
},
{
name: "Empty header",
header: "",
featureToRemove: "context-1m-2025-08-07",
expected: "",
},
{
name: "Header with spaces",
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
featureToRemove: "context-1m-2025-08-07",
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterBetaFeatures(tt.header, tt.featureToRemove)
if result != tt.expected {
t.Errorf("filterBetaFeatures() = %q, want %q", result, tt.expected)
}
})
}
}