From acf483c9e6cd5af8b91f2b670d67575bac99628e Mon Sep 17 00:00:00 2001 From: canxin121 Date: Tue, 24 Feb 2026 01:42:54 +0800 Subject: [PATCH] fix(responses): reject invalid SSE data JSON Guard the openai-response streaming path against truncated/invalid SSE data payloads by validating data: JSON before forwarding; surface a 502 terminal error instead of letting clients crash with JSON parse errors. --- sdk/api/handlers/handlers.go | 35 ++++++++ .../handlers_stream_bootstrap_test.go | 83 +++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 68859853..0e490e32 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -716,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return } if len(chunk.Payload) > 0 { + if handlerType == "openai-response" { + if err := validateSSEDataJSON(chunk.Payload); err != nil { + _ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}) + return + } + } sentPayload = true if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData { return @@ -727,6 +733,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return dataChan, upstreamHeaders, errChan } +func validateSSEDataJSON(chunk []byte) error { + for _, line := range bytes.Split(chunk, []byte("\n")) { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + data := bytes.TrimSpace(line[5:]) + if len(data) == 0 { + continue + } + if bytes.Equal(data, []byte("[DONE]")) { + continue + } + if json.Valid(data) { + continue + } + const max = 512 + preview := data + if len(preview) > max { + preview = preview[:max] + } + return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview) + } + return nil +} + func statusFromError(err error) int { if err == nil { return 0 diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index ba9dcac5..b08e3a99 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -134,6 +134,37 @@ type authAwareStreamExecutor struct { authIDs []string } +type invalidJSONStreamExecutor struct{} + +func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" } + +func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { + ch := make(chan coreexecutor.StreamChunk, 1) + ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")} + close(ch) + return &coreexecutor.StreamResult{Chunks: ch}, nil +} + +func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) { + return nil, &coreauth.Error{ + Code: "not_implemented", + Message: "HttpRequest not implemented", + HTTPStatus: http.StatusNotImplemented, + } +} + func (e *authAwareStreamExecutor) Identifier() string { return "codex" } func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2") } } + +func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) { + executor := &invalidJSONStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + if len(got) != 0 { + t.Fatalf("expected empty payload, got %q", string(got)) + } + + gotErr := false + for msg := range errChan { + if msg == nil { + continue + } + if msg.StatusCode != http.StatusBadGateway { + t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode) + } + if msg.Error == nil { + t.Fatalf("expected error") + } + gotErr = true + } + if !gotErr { + t.Fatalf("expected terminal error") + } +}