agent-framework/python/packages/ag-ui/tests/ag_ui/test_run.py
Evan Mattson 5e8fe0be1f Python: Stop emitting duplicate reasoning content from OpenAI response.reasoning_text.done and response.reasoning_summary_text.done events (#5162)
* Fix reasoning text done events duplicating streamed delta content (#5157)

The OpenAI Responses API sends both reasoning_text.delta (incremental
chunks) and reasoning_text.done (full accumulated text) events. The
chat client was emitting Content for both, causing ag-ui to append the
full done text onto already-accumulated delta text, producing
duplicated reasoning output.

Stop emitting Content for reasoning_text.done and
reasoning_summary_text.done events, matching how output_text.done is
already handled (not emitted). The deltas contain all the content;
the done event is redundant.
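
A toy consumer reproduces the duplication. This is a minimal sketch, not the chat client's code. It assumes a simplified stream of hypothetical `(event_type, text)` tuples in place of the OpenAI SDK's typed event objects, and `accumulate` is an illustrative helper. If the consumer appends text from both the delta and the done events, the reasoning text appears twice. If it skips the done event, the text appears once.

```python
# Hypothetical, simplified event stream: (event_type, text) tuples.
# The real OpenAI Responses API delivers typed event objects instead.
EVENTS = [
    ("response.reasoning_text.delta", "Let me "),
    ("response.reasoning_text.delta", "think."),
    ("response.reasoning_text.done", "Let me think."),  # full accumulated text
]


def accumulate(events, emit_done: bool) -> str:
    """Concatenate reasoning text the way a naive downstream consumer would."""
    text = ""
    for event_type, chunk in events:
        if event_type.endswith(".delta"):
            text += chunk
        elif event_type.endswith(".done") and emit_done:
            # The done payload repeats everything the deltas already carried.
            text += chunk
    return text


print(accumulate(EVENTS, emit_done=True))   # Let me think.Let me think.
print(accumulate(EVENTS, emit_done=False))  # Let me think.
```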

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix(openai): emit reasoning done content as fallback when no deltas observed (#5157)

Address PR review feedback:
- Track item_ids that received reasoning deltas via seen_reasoning_delta_item_ids set
- Emit content from done events only when no deltas were received for the
  item_id, preventing silent content loss on stream resumption
- Add comment documenting code_interpreter done event asymmetry
- Replace redundant ag-ui test with deduplication-focused test
- Add integration test for delta+done sequence in OpenAI chat client tests
- Add fallback path tests for done events without preceding deltas
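
The fallback rule can be sketched as follows. Only the `seen_reasoning_delta_item_ids` name comes from this change. The `ReasoningEvent` dataclass and the `reasoning_contents` function are hypothetical stand-ins, and the real chat client builds `Content` objects rather than returning strings.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass
class ReasoningEvent:
    """Hypothetical stand-in for an OpenAI Responses streaming event."""

    type: str  # e.g. "response.reasoning_text.delta" or "response.reasoning_text.done"
    item_id: str
    text: str


def reasoning_contents(events: List[ReasoningEvent]) -> List[str]:
    """Return the reasoning text chunks that would be surfaced downstream."""
    seen_reasoning_delta_item_ids: Set[str] = set()
    emitted: List[str] = []
    for event in events:
        if event.type.endswith(".delta"):
            seen_reasoning_delta_item_ids.add(event.item_id)
            emitted.append(event.text)
        elif event.type.endswith(".done"):
            # Fallback: surface the full text only if no deltas arrived for this
            # item (e.g. after stream resumption), so content is not silently lost.
            if event.item_id not in seen_reasoning_delta_item_ids:
                emitted.append(event.text)
    return emitted
```

When deltas are present, the done event adds nothing. When a resumed stream delivers only the done event for an item, its full text is still surfaced.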

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #5157: Python: [Bug]: "type": "response.reasoning_text.delta" and "response.reasoning_text.done" both get exposed as "text_reasoning"

* Fix AG-UI reasoning streaming to use proper Start/End pattern (#5157)

_emit_text_reasoning now follows the same streaming pattern as _emit_text:
- Emits ReasoningStartEvent/ReasoningMessageStartEvent only on the first
  delta for a given message_id
- Emits only ReasoningMessageContentEvent for subsequent deltas
- Defers ReasoningMessageEndEvent/ReasoningEndEvent until
  _close_reasoning_block is called (on content type switch or end-of-run)

This produces the correct protocol pattern:
  ReasoningStartEvent
    ReasoningMessageStartEvent
    ReasoningMessageContentEvent(delta1)
    ReasoningMessageContentEvent(delta2)
    ReasoningMessageEndEvent
  ReasoningEndEvent

Instead of wrapping every delta in a full Start→End sequence.

Backward compatibility is preserved: calling _emit_text_reasoning without
a flow argument still produces the full sequence per call.
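
A small state object can model the deferred Start/End pattern. This sketch is illustrative only. `SketchFlow`, `emit_reasoning`, and `close_reasoning` are hypothetical names standing in for `FlowState`, `_emit_text_reasoning`, and `_close_reasoning_block`. Events are represented as plain strings named after the `ag_ui.core` classes, not as real event objects.

```python
from __future__ import annotations

import uuid
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchFlow:
    """Hypothetical stand-in for FlowState: remembers the open reasoning message."""

    reasoning_message_id: Optional[str] = None


def emit_reasoning(delta: str, flow: Optional[SketchFlow] = None) -> list[str]:
    """Return event names for one reasoning delta."""
    if flow is None:
        # Backward-compatible path: a complete Start -> End sequence per call.
        return [
            "ReasoningStartEvent",
            "ReasoningMessageStartEvent",
            f"ReasoningMessageContentEvent({delta})",
            "ReasoningMessageEndEvent",
            "ReasoningEndEvent",
        ]
    events: list[str] = []
    if flow.reasoning_message_id is None:
        # First delta for this message: open the reasoning block and the message.
        flow.reasoning_message_id = str(uuid.uuid4())
        events += ["ReasoningStartEvent", "ReasoningMessageStartEvent"]
    events.append(f"ReasoningMessageContentEvent({delta})")
    return events


def close_reasoning(flow: SketchFlow) -> list[str]:
    """Emit the deferred end events, e.g. on a content-type switch or at end of run."""
    if flow.reasoning_message_id is None:
        return []
    flow.reasoning_message_id = None
    return ["ReasoningMessageEndEvent", "ReasoningEndEvent"]


flow = SketchFlow()
stream = emit_reasoning("delta1", flow) + emit_reasoning("delta2", flow) + close_reasoning(flow)
```

Calling `emit_reasoning` without a flow mirrors the backward-compatible path and returns all five events for a single delta.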

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix import ordering lint error in AG-UI test file (#5157)

Move inline import of TextMessageContentEvent to the top-level import
block and ensure alphabetical ordering to satisfy ruff I001 rule.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix mypy error: rename loop variable to avoid type conflict with WorkflowEvent

The 'event' variable was already typed as WorkflowEvent[Any] from the
async for loop at line 590. Reusing it in the _close_reasoning_block
loop (which returns list[BaseEvent]) caused an incompatible assignment
error. Renamed to 'reasoning_evt' to avoid the conflict.

Fixes #5162

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #5157: review comment fixes

* narrow test result reporting to explicit pytest JUnit XML

* Fix test args

* Fix pytest-results-action in merge workflow and remove committed test artifacts

Apply the same JUnit XML fix from python-tests.yml to python-merge-tests.yml:
add --junitxml=pytest.xml to all test commands and narrow the results action
path from ./python/**.xml to ./python/pytest.xml. Also remove accidentally
committed pytest.xml and python-coverage.xml and add them to .gitignore.

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-09 22:44:59 +00:00

1643 lines
62 KiB
Python

# Copyright (c) Microsoft. All rights reserved.

"""Tests for _agent_run.py helper functions and FlowState."""

import pytest
from ag_ui.core import (
    CustomEvent,
    ReasoningEncryptedValueEvent,
    ReasoningEndEvent,
    ReasoningMessageContentEvent,
    ReasoningMessageEndEvent,
    ReasoningMessageStartEvent,
    ReasoningStartEvent,
    TextMessageContentEvent,
    TextMessageEndEvent,
    TextMessageStartEvent,
    ToolCallArgsEvent,
)
from agent_framework import AgentResponseUpdate, Content, Message, ResponseStream
from agent_framework.exceptions import AgentInvalidResponseException
from agent_framework_ag_ui._agent_run import (
    _build_safe_metadata,
    _create_state_context_message,
    _inject_state_context,
    _normalize_response_stream,
    _resume_to_tool_messages,
    _should_suppress_intermediate_snapshot,
)
from agent_framework_ag_ui._run_common import (
    FlowState,
    _build_run_finished_event,
    _close_reasoning_block,
    _emit_approval_request,
    _emit_content,
    _emit_mcp_tool_call,
    _emit_mcp_tool_result,
    _emit_text,
    _emit_text_reasoning,
    _emit_tool_call,
    _emit_tool_result,
    _extract_resume_payload,
    _has_only_tool_calls,
)


class TestBuildSafeMetadata:
    """Tests for _build_safe_metadata function."""

    def test_none_metadata(self):
        """Returns empty dict for None."""
        result = _build_safe_metadata(None)
        assert result == {}

    def test_empty_metadata(self):
        """Returns empty dict for empty dict."""
        result = _build_safe_metadata({})
        assert result == {}

    def test_short_string_values(self):
        """Preserves short string values."""
        metadata = {"key1": "short", "key2": "value"}
        result = _build_safe_metadata(metadata)
        assert result == metadata

    def test_truncates_long_strings(self):
        """Truncates strings over 512 chars."""
        long_value = "x" * 1000
        metadata = {"key": long_value}
        result = _build_safe_metadata(metadata)
        assert len(result["key"]) == 512

    def test_serializes_non_strings(self):
        """Serializes non-string values to JSON."""
        metadata = {"count": 42, "items": [1, 2, 3]}
        result = _build_safe_metadata(metadata)
        assert result["count"] == "42"
        assert result["items"] == "[1, 2, 3]"

    def test_truncates_serialized_values(self):
        """Truncates serialized values over 512 chars."""
        long_list = list(range(200))
        metadata = {"data": long_list}
        result = _build_safe_metadata(metadata)
        assert len(result["data"]) == 512


class TestHasOnlyToolCalls:
    """Tests for _has_only_tool_calls function."""

    def test_only_tool_calls(self):
        """Returns True when only function_call content."""
        contents = [
            Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"),
        ]
        assert _has_only_tool_calls(contents) is True

    def test_tool_call_with_text(self):
        """Returns False when both tool call and text."""
        contents = [
            Content.from_text("Some text"),
            Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"),
        ]
        assert _has_only_tool_calls(contents) is False

    def test_only_text(self):
        """Returns False when only text."""
        contents = [Content.from_text("Just text")]
        assert _has_only_tool_calls(contents) is False

    def test_empty_contents(self):
        """Returns False for empty contents."""
        assert _has_only_tool_calls([]) is False

    def test_tool_call_with_empty_text(self):
        """Returns True when text content has empty text."""
        contents = [
            Content.from_text(""),
            Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"),
        ]
        assert _has_only_tool_calls(contents) is True


class TestShouldSuppressIntermediateSnapshot:
    """Tests for _should_suppress_intermediate_snapshot function."""

    def test_no_tool_name(self):
        """Returns False when no tool name."""
        result = _should_suppress_intermediate_snapshot(
            None, {"key": {"tool": "write_doc", "tool_argument": "content"}}, False
        )
        assert result is False

    def test_no_config(self):
        """Returns False when no config."""
        result = _should_suppress_intermediate_snapshot("write_doc", None, False)
        assert result is False

    def test_confirmation_required(self):
        """Returns False when confirmation is required."""
        config = {"key": {"tool": "write_doc", "tool_argument": "content"}}
        result = _should_suppress_intermediate_snapshot("write_doc", config, True)
        assert result is False

    def test_tool_not_in_config(self):
        """Returns False when tool not in config."""
        config = {"key": {"tool": "other_tool", "tool_argument": "content"}}
        result = _should_suppress_intermediate_snapshot("write_doc", config, False)
        assert result is False

    def test_suppresses_predictive_tool(self):
        """Returns True for predictive tool without confirmation."""
        config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
        result = _should_suppress_intermediate_snapshot("write_doc", config, False)
        assert result is True


class TestFlowState:
    """Tests for FlowState dataclass."""

    def test_default_values(self):
        """Tests default initialization."""
        flow = FlowState()
        assert flow.message_id is None
        assert flow.tool_call_id is None
        assert flow.tool_call_name is None
        assert flow.waiting_for_approval is False
        assert flow.current_state == {}
        assert flow.accumulated_text == ""
        assert flow.pending_tool_calls == []
        assert flow.tool_calls_by_id == {}
        assert flow.tool_results == []
        assert flow.tool_calls_ended == set()
        assert flow.interrupts == []

    def test_get_tool_name(self):
        """Tests get_tool_name method."""
        flow = FlowState()
        flow.tool_calls_by_id = {"call_123": {"function": {"name": "get_weather", "arguments": "{}"}}}
        assert flow.get_tool_name("call_123") == "get_weather"
        assert flow.get_tool_name("nonexistent") is None
        assert flow.get_tool_name(None) is None

    def test_get_tool_name_empty_name(self):
        """Tests get_tool_name with empty name."""
        flow = FlowState()
        flow.tool_calls_by_id = {"call_123": {"function": {"name": "", "arguments": "{}"}}}
        assert flow.get_tool_name("call_123") is None

    def test_get_pending_without_end(self):
        """Tests get_pending_without_end method."""
        flow = FlowState()
        flow.pending_tool_calls = [
            {"id": "call_1", "function": {"name": "tool1"}},
            {"id": "call_2", "function": {"name": "tool2"}},
            {"id": "call_3", "function": {"name": "tool3"}},
        ]
        flow.tool_calls_ended = {"call_1", "call_3"}
        result = flow.get_pending_without_end()
        assert len(result) == 1
        assert result[0]["id"] == "call_2"


class TestNormalizeResponseStream:
    """Tests for _normalize_response_stream helper."""

    async def test_accepts_response_stream(self):
        """Accept standard ResponseStream values."""

        async def _stream():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        stream = await _normalize_response_stream(ResponseStream(_stream()))
        updates = [update async for update in stream]
        assert len(updates) == 1
        assert updates[0].contents[0].text == "hello"

    async def test_accepts_async_iterable(self):
        """Accept workflow-style async generator streams."""

        async def _stream():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        stream = await _normalize_response_stream(_stream())
        updates = [update async for update in stream]
        assert len(updates) == 1
        assert updates[0].contents[0].text == "hello"

    async def test_accepts_awaitable_resolving_to_async_iterable(self):
        """Accept awaitables that resolve to async iterable streams."""

        async def _stream():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        async def _resolve():
            return _stream()

        stream = await _normalize_response_stream(_resolve())
        updates = [update async for update in stream]
        assert len(updates) == 1
        assert updates[0].contents[0].text == "hello"

    async def test_rejects_non_stream_values(self):
        """Reject unsupported stream return values."""
        with pytest.raises(AgentInvalidResponseException):
            await _normalize_response_stream("not-a-stream")


class TestCreateStateContextMessage:
    """Tests for _create_state_context_message function."""

    def test_no_state(self):
        """Returns None when no state."""
        result = _create_state_context_message({}, {"properties": {}})
        assert result is None

    def test_no_schema(self):
        """Returns None when no schema."""
        result = _create_state_context_message({"key": "value"}, {})
        assert result is None

    def test_creates_message(self):
        """Creates state context message."""
        state = {"document": "Hello world"}
        schema = {"properties": {"document": {"type": "string"}}}
        result = _create_state_context_message(state, schema)
        assert result is not None
        assert result.role == "system"
        assert len(result.contents) == 1
        assert "Hello world" in result.contents[0].text
        assert "Current state" in result.contents[0].text


class TestInjectStateContext:
    """Tests for _inject_state_context function."""

    def test_no_state_message(self):
        """Returns original messages when no state context needed."""
        messages = [Message(role="user", contents=[Content.from_text("Hello")])]
        result = _inject_state_context(messages, {}, {})
        assert result == messages

    def test_empty_messages(self):
        """Returns empty list for empty messages."""
        result = _inject_state_context([], {"key": "value"}, {"properties": {}})
        assert result == []

    def test_last_message_not_user(self):
        """Returns original messages when last message is not from user."""
        messages = [
            Message(role="user", contents=[Content.from_text("Hello")]),
            Message(role="assistant", contents=[Content.from_text("Hi")]),
        ]
        state = {"key": "value"}
        schema = {"properties": {"key": {"type": "string"}}}
        result = _inject_state_context(messages, state, schema)
        assert result == messages

    def test_injects_before_last_user_message(self):
        """Injects state context before last user message."""
        messages = [
            Message(role="system", contents=[Content.from_text("You are helpful")]),
            Message(role="user", contents=[Content.from_text("Hello")]),
        ]
        state = {"document": "content"}
        schema = {"properties": {"document": {"type": "string"}}}
        result = _inject_state_context(messages, state, schema)
        assert len(result) == 3

        # System message first
        assert result[0].role == "system"
        assert "helpful" in result[0].contents[0].text

        # State context second
        assert result[1].role == "system"
        assert "Current state" in result[1].contents[0].text

        # User message last
        assert result[2].role == "user"
        assert "Hello" in result[2].contents[0].text


# Additional tests for _agent_run.py functions


def test_emit_text_basic():
    """Test _emit_text emits correct events."""
    flow = FlowState()
    content = Content.from_text("Hello world")
    events = _emit_text(content, flow)
    assert len(events) == 2  # TextMessageStartEvent + TextMessageContentEvent
    assert flow.message_id is not None
    assert flow.accumulated_text == "Hello world"


def test_emit_text_skip_empty():
    """Test _emit_text skips empty text."""
    flow = FlowState()
    content = Content.from_text("")
    events = _emit_text(content, flow)
    assert len(events) == 0


def test_emit_text_continues_existing_message():
    """Test _emit_text continues existing message."""
    flow = FlowState()
    flow.message_id = "existing-id"
    content = Content.from_text("more text")
    events = _emit_text(content, flow)
    assert len(events) == 1  # Only TextMessageContentEvent, no new start
    assert flow.message_id == "existing-id"


def test_emit_text_skips_duplicate_full_message_delta():
    """Test _emit_text skips replayed full-message chunks on an open message."""
    flow = FlowState()
    flow.message_id = "existing-id"
    flow.accumulated_text = "Case complete."
    content = Content.from_text("Case complete.")
    events = _emit_text(content, flow)
    assert events == []
    assert flow.accumulated_text == "Case complete."


def test_emit_text_skips_when_waiting_for_approval():
    """Test _emit_text skips when waiting for approval."""
    flow = FlowState()
    flow.waiting_for_approval = True
    content = Content.from_text("should skip")
    events = _emit_text(content, flow)
    assert len(events) == 0


def test_emit_text_skips_when_skip_text_flag():
    """Test _emit_text skips with skip_text flag."""
    flow = FlowState()
    content = Content.from_text("should skip")
    events = _emit_text(content, flow, skip_text=True)
    assert len(events) == 0


def test_emit_tool_call_basic():
    """Test _emit_tool_call emits correct events."""
    flow = FlowState()
    content = Content.from_function_call(
        call_id="call_123",
        name="get_weather",
        arguments='{"city": "NYC"}',
    )
    events = _emit_tool_call(content, flow)
    assert len(events) >= 1  # At least ToolCallStartEvent
    assert flow.tool_call_id == "call_123"
    assert flow.tool_call_name == "get_weather"


def test_emit_tool_call_generates_id():
    """Test _emit_tool_call generates ID when not provided."""
    flow = FlowState()
    # Create content without call_id
    content = Content(type="function_call", name="test_tool", arguments="{}")
    events = _emit_tool_call(content, flow)
    assert len(events) >= 1
    assert flow.tool_call_id is not None  # ID should be generated


def test_emit_tool_call_skips_duplicate_full_arguments_replay():
    """Test _emit_tool_call skips replayed full-arguments on an existing tool call.

    This is a regression test for issue #4194 where some streaming providers
    send the full arguments string again after streaming deltas, causing the
    arguments to be doubled in MESSAGES_SNAPSHOT events.

    Mirrors test_emit_text_skips_duplicate_full_message_delta for consistency.
    """
    flow = FlowState()
    full_args = '{"city": "Seattle"}'

    # Step 1: Initial tool call with name + arguments (normal start)
    content_start = Content.from_function_call(
        call_id="call_dup",
        name="get_weather",
        arguments=full_args,
    )
    events_start = _emit_tool_call(content_start, flow)
    # Should emit ToolCallStartEvent + ToolCallArgsEvent
    assert any(isinstance(e, ToolCallArgsEvent) for e in events_start)
    assert flow.tool_calls_by_id["call_dup"]["function"]["arguments"] == full_args

    # Step 2: Provider replays the full arguments (duplicate)
    content_replay = Content(type="function_call", call_id="call_dup", arguments=full_args)
    events_replay = _emit_tool_call(content_replay, flow)
    # Should NOT emit any ToolCallArgsEvent (early return on replay)
    args_events = [e for e in events_replay if isinstance(e, ToolCallArgsEvent)]
    assert args_events == [], "Duplicate full-arguments replay should not emit ToolCallArgsEvent"
    # Accumulated arguments should remain unchanged
    assert flow.tool_calls_by_id["call_dup"]["function"]["arguments"] == full_args


def test_emit_tool_result_closes_open_message():
    """Test _emit_tool_result emits TextMessageEndEvent for open text message.

    This is a regression test for a bug where TEXT_MESSAGE_END was not
    emitted when using MCP tools because the message_id was reset without
    closing the message first.
    """
    flow = FlowState()
    # Simulate an open text message (e.g., from Feature #4 tool-only detection)
    flow.message_id = "open-msg-123"
    flow.tool_call_id = "call_456"
    content = Content.from_function_result(call_id="call_456", result="tool result")
    events = _emit_tool_result(content, flow, predictive_handler=None)

    # Should have: ToolCallEndEvent, ToolCallResultEvent, TextMessageEndEvent
    assert len(events) == 3

    # Verify TextMessageEndEvent is emitted for the open message
    text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)]
    assert len(text_end_events) == 1
    assert text_end_events[0].message_id == "open-msg-123"

    # Verify message_id is reset after
    assert flow.message_id is None


def test_emit_tool_result_no_open_message():
    """Test _emit_tool_result works when there's no open text message."""
    flow = FlowState()
    # No open message
    flow.message_id = None
    flow.tool_call_id = "call_456"
    content = Content.from_function_result(call_id="call_456", result="tool result")
    events = _emit_tool_result(content, flow, predictive_handler=None)

    # Should have: ToolCallEndEvent, ToolCallResultEvent (no TextMessageEndEvent)
    text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)]
    assert len(text_end_events) == 0


def test_emit_tool_result_serializes_non_string_result():
    """Non-string tool results should be serialized before emitting TOOL_CALL_RESULT."""
    flow = FlowState()
    content = Content.from_function_result(call_id="call_789", result={"ok": True, "items": [1, 2]})
    events = _emit_tool_result(content, flow, predictive_handler=None)
    result_event = next(event for event in events if getattr(event, "type", None) == "TOOL_CALL_RESULT")
    assert isinstance(result_event.content, str)
    assert '"ok": true' in result_event.content
    assert flow.tool_results[0]["content"] == result_event.content


def test_emit_content_usage_emits_custom_usage_event():
    """Usage content should be emitted as a custom usage event."""
    flow = FlowState()
    content = Content.from_usage({"input_token_count": 3, "output_token_count": 2, "total_token_count": 5})
    events = _emit_content(content, flow)
    assert len(events) == 1
    assert events[0].type == "CUSTOM"
    assert events[0].name == "usage"
    assert events[0].value["total_token_count"] == 5


def test_emit_approval_request_populates_interrupt_metadata():
    """Approval requests should populate FlowState interrupts for RUN_FINISHED metadata."""
    flow = FlowState(message_id="msg-1")
    function_call = Content.from_function_call(call_id="call_123", name="write_doc", arguments={"content": "x"})
    approval_content = Content.from_function_approval_request(id="approval_1", function_call=function_call)
    _emit_approval_request(approval_content, flow)
    assert flow.waiting_for_approval is True
    assert len(flow.interrupts) == 1
    assert flow.interrupts[0]["id"] == "call_123"
    assert flow.interrupts[0]["value"]["type"] == "function_approval_request"


def test_emit_approval_request_accumulates_multiple_interrupts():
    """Multiple approval requests in the same turn should accumulate in flow.interrupts."""
    flow = FlowState(message_id="msg-1")
    for i in range(1, 4):
        function_call = Content.from_function_call(
            call_id=f"call_{i}",
            name=f"tool_{i}",
            arguments={"arg": f"value_{i}"},
        )
        approval_content = Content.from_function_approval_request(
            id=f"approval_{i}",
            function_call=function_call,
        )
        _emit_approval_request(approval_content, flow)
    assert len(flow.interrupts) == 3
    interrupt_ids = {intr["id"] for intr in flow.interrupts}
    assert interrupt_ids == {"call_1", "call_2", "call_3"}


def test_resume_to_tool_messages_from_interrupts_payload():
    """Resume payload interrupt responses map to tool messages."""
    resume = {
        "interrupts": [
            {"id": "req_1", "value": {"accepted": True, "steps": []}},
            {"id": "req_2", "value": "plain value"},
        ]
    }
    messages = _resume_to_tool_messages(resume)
    assert len(messages) == 2
    assert messages[0]["role"] == "tool"
    assert messages[0]["toolCallId"] == "req_1"
    assert '"accepted": true' in messages[0]["content"]
    assert messages[1]["content"] == "plain value"


def test_extract_resume_payload_prefers_top_level_resume():
    """Top-level resume should take precedence over forwarded props."""
    payload = {
        "resume": {"interrupts": [{"id": "req_1", "value": "approved"}]},
        "forwarded_props": {"command": {"resume": "ignored"}},
    }
    result = _extract_resume_payload(payload)
    assert result == {"interrupts": [{"id": "req_1", "value": "approved"}]}


def test_extract_resume_payload_reads_forwarded_command_resume():
    """Forwarded command.resume should be treated as a resume payload."""
    payload = {
        "forwarded_props": {
            "command": {"resume": '{"airline":"KLM","departure":"Amsterdam (AMS)","arrival":"San Francisco (SFO)"}'}
        }
    }
    result = _extract_resume_payload(payload)
    assert isinstance(result, str)
    assert "KLM" in result


def test_build_run_finished_event_with_interrupt():
    """RUN_FINISHED helper should preserve interrupt payloads."""
    event = _build_run_finished_event("run-1", "thread-1", interrupts=[{"id": "req_1", "value": {"x": 1}}])
    dumped = event.model_dump()
    assert dumped["run_id"] == "run-1"
    assert dumped["thread_id"] == "thread-1"
    assert dumped["interrupt"] == [{"id": "req_1", "value": {"x": 1}}]


def test_extract_approved_state_updates_no_handler():
    """Test _extract_approved_state_updates returns empty with no handler."""
    from agent_framework_ag_ui._agent_run import _extract_approved_state_updates

    messages = [Message(role="user", contents=[Content.from_text("Hello")])]
    result = _extract_approved_state_updates(messages, None)
    assert result == {}


def test_extract_approved_state_updates_no_approval():
    """Test _extract_approved_state_updates returns empty when no approval content."""
    from agent_framework_ag_ui._agent_run import _extract_approved_state_updates
    from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler

    handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "content"}})
    messages = [Message(role="user", contents=[Content.from_text("Hello")])]
    result = _extract_approved_state_updates(messages, handler)
    assert result == {}


class TestBuildMessagesSnapshot:
    """Tests for _build_messages_snapshot function."""

    def test_tool_calls_and_text_are_separate_messages(self):
        """Test that tool calls and text content are emitted as separate messages.

        This is a regression test for issue #3619 where tool calls and content
        were incorrectly merged into a single assistant message.
        """
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        flow = FlowState()
        flow.message_id = "msg-123"
        flow.pending_tool_calls = [
            {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "NYC"}'}},
        ]
        flow.accumulated_text = "Here is the weather information."
        flow.tool_results = [{"id": "result-1", "role": "tool", "content": '{"temp": 72}', "toolCallId": "call_1"}]
        result = _build_messages_snapshot(flow, [])

        # Should have 3 messages: tool call msg, tool result, text content msg
        assert len(result.messages) == 3

        # First message: assistant with tool calls only (no content)
        assistant_tool_msg = result.messages[0]
        assert assistant_tool_msg.role == "assistant"
        assert assistant_tool_msg.tool_calls is not None
        assert len(assistant_tool_msg.tool_calls) == 1
        assert assistant_tool_msg.content is None

        # Second message: tool result
        tool_result_msg = result.messages[1]
        assert tool_result_msg.role == "tool"

        # Third message: assistant with content only (no tool calls)
        assistant_text_msg = result.messages[2]
        assert assistant_text_msg.role == "assistant"
        assert assistant_text_msg.content == "Here is the weather information."
        assert assistant_text_msg.tool_calls is None

        # The text message should have a different ID than the tool call message
        assert assistant_text_msg.id != assistant_tool_msg.id

    def test_only_tool_calls_no_text(self):
        """Test snapshot with only tool calls and no accumulated text."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        flow = FlowState()
        flow.message_id = "msg-123"
        flow.pending_tool_calls = [
            {"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
        ]
        flow.accumulated_text = ""
        flow.tool_results = []
        result = _build_messages_snapshot(flow, [])

        # Should have 1 message: tool call msg only
        assert len(result.messages) == 1
        assert result.messages[0].role == "assistant"
        assert result.messages[0].tool_calls is not None
        assert result.messages[0].content is None

    def test_only_text_no_tool_calls(self):
        """Test snapshot with only text and no tool calls."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        flow = FlowState()
        flow.message_id = "msg-123"
        flow.pending_tool_calls = []
        flow.accumulated_text = "Hello world"
        flow.tool_results = []
        result = _build_messages_snapshot(flow, [])

        # Should have 1 message: text content msg only
        assert len(result.messages) == 1
        assert result.messages[0].role == "assistant"
        assert result.messages[0].content == "Hello world"
        assert result.messages[0].tool_calls is None
        # Should use the existing message_id
        assert result.messages[0].id == "msg-123"

    def test_preserves_snapshot_messages(self):
        """Test that existing snapshot messages are preserved."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        flow = FlowState()
        flow.pending_tool_calls = []
        flow.accumulated_text = ""
        existing_messages = [
            {"id": "user-1", "role": "user", "content": "Hello"},
            {"id": "assist-1", "role": "assistant", "content": "Hi there"},
        ]
        result = _build_messages_snapshot(flow, existing_messages)
        assert len(result.messages) == 2
        assert result.messages[0].id == "user-1"
        assert result.messages[1].id == "assist-1"


def test_malformed_json_in_confirm_args_skips_confirmation():
    """Test that malformed JSON in tool arguments skips confirm_changes flow.

    This is a regression test to ensure that when tool arguments contain malformed
    JSON, the code skips the confirmation flow entirely rather than crashing or
    showing incomplete data to the user.
    """
    import json

    # Simulate the parsing logic - malformed JSON should trigger skip
    malformed_arguments = "{ invalid json }"
    tool_call = {"function": {"name": "write_doc", "arguments": malformed_arguments}}

    # This is what the code should do - detect parsing failure and skip
    should_skip_confirmation = False
    try:
        json.loads(tool_call.get("function", {}).get("arguments", "{}"))
    except json.JSONDecodeError:
        should_skip_confirmation = True

    # Should skip confirmation when JSON is malformed
    assert should_skip_confirmation is True

    # Valid JSON should proceed with confirmation
    valid_arguments = '{"content": "hello"}'
    tool_call_valid = {"function": {"name": "write_doc", "arguments": valid_arguments}}
    should_skip_confirmation = False
    try:
        function_arguments = json.loads(tool_call_valid.get("function", {}).get("arguments", "{}"))
    except json.JSONDecodeError:
        should_skip_confirmation = True
    assert should_skip_confirmation is False
    assert function_arguments == {"content": "hello"}
class TestTextMessageEventBalancing:
"""Tests for proper TEXT_MESSAGE_START/END event balancing.
These tests verify that the streaming flow produces balanced pairs of
TextMessageStartEvent and TextMessageEndEvent, especially when tool
execution is involved.
"""
def test_tool_only_flow_produces_balanced_events(self):
"""Test that a tool-only response produces balanced TEXT_MESSAGE events.
This simulates the scenario where the LLM immediately calls a tool
without any initial text, then returns text after the tool result.
"""
flow = FlowState()
all_events: list = []
# Step 1: LLM outputs function_call only (no text)
func_call_content = Content.from_function_call(
call_id="call_weather",
name="get_weather",
arguments='{"city": "Seattle"}',
)
# Feature #4 check: this should trigger TextMessageStartEvent
contents = [func_call_content]
if not flow.message_id and _has_only_tool_calls(contents):
flow.message_id = "tool-msg-1"
all_events.append(TextMessageStartEvent(message_id=flow.message_id, role="assistant"))
# Emit tool call events
all_events.extend(_emit_content(func_call_content, flow))
# Step 2: Tool executes and returns result
func_result_content = Content.from_function_result(
call_id="call_weather",
result='{"temp": 55, "conditions": "rainy"}',
)
# This should close the text message
all_events.extend(_emit_tool_result(func_result_content, flow))
# Verify message_id was reset
assert flow.message_id is None, "message_id should be reset after tool result"
# Step 3: LLM outputs text response
text_content = Content.from_text("The weather in Seattle is 55°F and rainy.")
# Since message_id is None, _emit_text should create a new one
for event in _emit_content(text_content, flow):
all_events.append(event)
# Step 4: End of stream - emit final TextMessageEndEvent
if flow.message_id:
all_events.append(TextMessageEndEvent(message_id=flow.message_id))
# Verify event counts
start_events = [e for e in all_events if isinstance(e, TextMessageStartEvent)]
end_events = [e for e in all_events if isinstance(e, TextMessageEndEvent)]
# Should have 2 TextMessageStartEvent and 2 TextMessageEndEvent
assert len(start_events) == 2, f"Expected 2 start events, got {len(start_events)}"
assert len(end_events) == 2, f"Expected 2 end events, got {len(end_events)}"
# Verify order: first message should start and end before second starts
# Find indices
start_indices = [i for i, e in enumerate(all_events) if isinstance(e, TextMessageStartEvent)]
end_indices = [i for i, e in enumerate(all_events) if isinstance(e, TextMessageEndEvent)]
# First end should come before second start
assert end_indices[0] < start_indices[1], (
f"First TextMessageEndEvent (index {end_indices[0]}) should come "
f"before second TextMessageStartEvent (index {start_indices[1]})"
)
def test_text_then_tool_flow(self):
"""Test flow where LLM outputs text first, then calls a tool.
This simulates: "Let me check the weather..." -> tool call -> tool result -> "The weather is..."
"""
flow = FlowState()
all_events: list = []
# Step 1: LLM outputs text first
text1 = Content.from_text("Let me check the weather for you.")
all_events.extend(_emit_content(text1, flow))
# Verify message_id is set
assert flow.message_id is not None, "message_id should be set after text"
first_msg_id = flow.message_id
# Step 2: LLM outputs function_call
func_call = Content.from_function_call(
call_id="call_1",
name="get_weather",
arguments="{}",
)
all_events.extend(_emit_content(func_call, flow))
# Step 3: Tool result comes back
func_result = Content.from_function_result(call_id="call_1", result="sunny")
all_events.extend(_emit_tool_result(func_result, flow))
# Verify message_id was reset and first message was closed
assert flow.message_id is None
end_events_so_far = [e for e in all_events if isinstance(e, TextMessageEndEvent)]
assert len(end_events_so_far) == 1
assert end_events_so_far[0].message_id == first_msg_id
# Step 4: LLM outputs follow-up text
text2 = Content.from_text("The weather is sunny!")
all_events.extend(_emit_content(text2, flow))
# Step 5: End of stream
if flow.message_id:
all_events.append(TextMessageEndEvent(message_id=flow.message_id))
# Verify balance
start_events = [e for e in all_events if isinstance(e, TextMessageStartEvent)]
end_events = [e for e in all_events if isinstance(e, TextMessageEndEvent)]
assert len(start_events) == 2
assert len(end_events) == 2
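# Hypothetical helper, not part of agent_framework_ag_ui: a minimal sketch of the
# invariant the two flow tests above check by counting and indexing events. It
# walks ("start" | "end", message_id) pairs and reports whether every text message
# is closed before the next one opens and nothing is left open at end of stream.
# The pair encoding is an assumption made to keep the sketch self-contained.
def _sketch_text_messages_balanced(pairs: list[tuple[str, str]]) -> bool:
    open_id = None  # message_id of the currently open text message, if any
    for kind, message_id in pairs:
        if kind == "start":
            if open_id is not None:
                return False  # a new message started before the previous one ended
            open_id = message_id
        elif kind == "end":
            if open_id != message_id:
                return False  # end event does not match the open message
            open_id = None
    return open_id is None  # every started message must have been ended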
async def test_run_agent_stream_accumulates_multiple_confirm_interrupts():
"""Multiple predictive tool calls in a single streaming run should accumulate interrupts.
This exercises the confirm_changes path in run_agent_stream (_agent_run.py),
ensuring that flow.interrupts.append() works correctly for multiple tool calls
and all interrupts appear in the RUN_FINISHED event.
"""
import json
from conftest import StubAgent
from agent_framework_ag_ui import AgentFrameworkAgent
predict_config = {
"tasks": {"tool": "generate_tasks", "tool_argument": "steps"},
"notes": {"tool": "generate_notes", "tool_argument": "items"},
}
state_schema = {
"tasks": {"type": "array", "items": {"type": "object"}},
"notes": {"type": "array", "items": {"type": "object"}},
}
updates = [
AgentResponseUpdate(
contents=[
Content.from_function_call(
name="generate_tasks",
call_id="call-tasks",
arguments=json.dumps({"steps": [{"description": "Task 1"}]}),
),
Content.from_function_call(
name="generate_notes",
call_id="call-notes",
arguments=json.dumps({"items": [{"description": "Note 1"}]}),
),
],
role="assistant",
),
]
stub = StubAgent(updates=updates)
agent = AgentFrameworkAgent(
agent=stub,
state_schema=state_schema,
predict_state_config=predict_config,
require_confirmation=True,
)
payload = {
"thread_id": "thread-multi",
"run_id": "run-multi",
"messages": [{"role": "user", "content": "Generate tasks and notes"}],
"state": {"tasks": [], "notes": []},
}
events = [event async for event in agent.run(payload)]
# Find RUN_FINISHED event and verify multiple interrupts
finished_events = [
e
for e in events
if getattr(e, "type", None) == "RUN_FINISHED"
or getattr(getattr(e, "type", None), "value", None) == "RUN_FINISHED"
]
assert finished_events, f"Expected RUN_FINISHED event. Types: {[getattr(e, 'type', None) for e in events]}"
finished = finished_events[-1]
interrupt = getattr(finished, "interrupt", None)
assert interrupt is not None, "Expected interrupt metadata in RUN_FINISHED"
assert len(interrupt) == 2, f"Expected 2 interrupts (one per tool), got {len(interrupt)}"
# Verify both tool calls are represented in interrupt metadata
interrupt_tool_names = {i["value"]["function_call"]["name"] for i in interrupt}
assert interrupt_tool_names == {"generate_tasks", "generate_notes"}
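# Hypothetical helper, not part of agent_framework_ag_ui: the RUN_FINISHED lookup
# above accepts an event whose `type` is either a plain string or an enum member
# carrying the string in `.value`. A minimal sketch that folds both cases into one
# comparable value (None when the object has no `type` attribute), so a filter could
# read: [e for e in events if _sketch_event_type_name(e) == "RUN_FINISHED"].
def _sketch_event_type_name(event: object) -> object:
    event_type = getattr(event, "type", None)
    # Enum members expose the wire string via `.value`; plain strings pass through.
    return getattr(event_type, "value", event_type)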
def test_emit_oauth_consent_request():
"""Test that oauth_consent_request content emits a CustomEvent."""
content = Content.from_oauth_consent_request(
consent_link="https://login.microsoftonline.com/consent",
)
flow = FlowState()
events = _emit_content(content, flow)
assert len(events) == 1
assert isinstance(events[0], CustomEvent)
assert events[0].name == "oauth_consent_request"
assert events[0].value == {"consent_link": "https://login.microsoftonline.com/consent"}
def test_emit_oauth_consent_request_no_link():
"""Test that oauth_consent_request without a consent_link emits no events."""
content = Content("oauth_consent_request")
flow = FlowState()
events = _emit_content(content, flow)
assert len(events) == 0
# ============================================================================
# Tests for MCP tool call, MCP tool result, and text reasoning event emission
# ============================================================================
class TestEmitMcpToolCall:
"""Tests for _emit_mcp_tool_call function."""
def test_produces_start_and_args_events(self):
"""MCP tool call emits ToolCallStart + ToolCallArgs events."""
flow = FlowState()
content = Content.from_mcp_server_tool_call(
call_id="mcp_call_1",
tool_name="search",
server_name="brave",
arguments={"query": "weather"},
)
events = _emit_mcp_tool_call(content, flow)
assert len(events) == 2
assert events[0].type == "TOOL_CALL_START"
assert events[0].tool_call_id == "mcp_call_1"
assert events[0].tool_call_name == "search"
assert events[1].type == "TOOL_CALL_ARGS"
assert events[1].tool_call_id == "mcp_call_1"
assert "weather" in events[1].delta
def test_tracks_in_flow_state(self):
"""MCP tool call is tracked in flow.pending_tool_calls and tool_calls_by_id."""
flow = FlowState()
content = Content.from_mcp_server_tool_call(
call_id="mcp_call_2",
tool_name="get_file",
arguments='{"path": "/tmp/test.txt"}',
)
_emit_mcp_tool_call(content, flow)
assert len(flow.pending_tool_calls) == 1
assert flow.pending_tool_calls[0]["id"] == "mcp_call_2"
assert "mcp_call_2" in flow.tool_calls_by_id
assert flow.tool_calls_by_id["mcp_call_2"]["function"]["name"] == "get_file"
assert flow.tool_calls_by_id["mcp_call_2"]["function"]["arguments"] == '{"path": "/tmp/test.txt"}'
def test_no_server_name_uses_tool_name_only(self):
"""Without server_name, display name is just tool_name."""
flow = FlowState()
content = Content.from_mcp_server_tool_call(
call_id="mcp_call_3",
tool_name="list_files",
)
events = _emit_mcp_tool_call(content, flow)
assert events[0].tool_call_name == "list_files"
def test_no_arguments_skips_args_event(self):
"""No arguments produces only ToolCallStart, no ToolCallArgs."""
flow = FlowState()
content = Content.from_mcp_server_tool_call(
call_id="mcp_call_4",
tool_name="ping",
)
events = _emit_mcp_tool_call(content, flow)
assert len(events) == 1
assert events[0].type == "TOOL_CALL_START"
def test_generates_id_when_missing(self):
"""A tool_call_id is generated when call_id is None."""
flow = FlowState()
content = Content(type="mcp_server_tool_call", tool_name="test_tool")
events = _emit_mcp_tool_call(content, flow)
assert len(events) >= 1
assert events[0].tool_call_id is not None
assert events[0].tool_call_id != ""
assert events[0].tool_call_name == "test_tool"
def test_missing_tool_name_falls_back_to_mcp_tool(self):
"""When tool_name is None, the fallback 'mcp_tool' is used."""
flow = FlowState()
content = Content(type="mcp_server_tool_call")
events = _emit_mcp_tool_call(content, flow)
assert len(events) >= 1
assert events[0].tool_call_name == "mcp_tool"
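# Hypothetical helper, not part of agent_framework_ag_ui: a minimal sketch of the
# field normalization the TestEmitMcpToolCall cases above pin down. A missing call_id
# is replaced by a generated id (uuid4 is an assumption; the tests only require a
# non-empty id), a missing tool_name falls back to "mcp_tool", string arguments pass
# through verbatim, and None means no ToolCallArgs event. JSON-encoding non-string
# arguments is also an assumption; the tests only check that "weather" is in the delta.
def _sketch_mcp_tool_call_fields(call_id, tool_name, arguments):
    import json
    import uuid

    tool_call_id = call_id or str(uuid.uuid4())
    tool_call_name = tool_name or "mcp_tool"
    if arguments is None:
        args_delta = None  # no ToolCallArgs event would be emitted
    elif isinstance(arguments, str):
        args_delta = arguments  # already serialized; stored as-is
    else:
        args_delta = json.dumps(arguments)
    return tool_call_id, tool_call_name, args_delta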
class TestEmitMcpToolResult:
"""Tests for _emit_mcp_tool_result function."""
def test_produces_end_and_result_events(self):
"""MCP tool result emits ToolCallEnd + ToolCallResult events."""
flow = FlowState()
content = Content.from_mcp_server_tool_result(
call_id="mcp_call_1",
output={"results": [{"title": "Weather", "url": "https://example.com"}]},
)
events = _emit_mcp_tool_result(content, flow)
assert len(events) == 2
assert events[0].type == "TOOL_CALL_END"
assert events[0].tool_call_id == "mcp_call_1"
assert events[1].type == "TOOL_CALL_RESULT"
assert events[1].tool_call_id == "mcp_call_1"
assert "Weather" in events[1].content
def test_tracks_in_flow_state(self):
"""MCP tool result is tracked in flow.tool_results and tool_calls_ended."""
flow = FlowState()
content = Content.from_mcp_server_tool_result(
call_id="mcp_call_5",
output="Success",
)
_emit_mcp_tool_result(content, flow)
assert "mcp_call_5" in flow.tool_calls_ended
assert len(flow.tool_results) == 1
assert flow.tool_results[0]["toolCallId"] == "mcp_call_5"
assert flow.tool_results[0]["content"] == "Success"
def test_no_call_id_returns_empty(self):
"""Missing call_id returns empty events list with a warning."""
flow = FlowState()
content = Content(type="mcp_server_tool_result", output="data")
events = _emit_mcp_tool_result(content, flow)
assert events == []
def test_serializes_non_string_output(self):
"""Non-string output is serialized to JSON."""
flow = FlowState()
content = Content.from_mcp_server_tool_result(
call_id="mcp_call_6",
output={"key": "value", "count": 42},
)
events = _emit_mcp_tool_result(content, flow)
result_event = events[1]
assert isinstance(result_event.content, str)
assert '"key": "value"' in result_event.content
def test_output_none_falls_back_to_empty_string(self):
"""When output is None (default), the result content is an empty string."""
flow = FlowState()
content = Content(type="mcp_server_tool_result", call_id="mcp_call_none")
events = _emit_mcp_tool_result(content, flow)
assert len(events) == 2
assert events[1].type == "TOOL_CALL_RESULT"
assert events[1].content == ""
def test_resets_flow_state_like_emit_tool_result(self):
"""MCP tool result performs same FlowState cleanup as _emit_tool_result."""
flow = FlowState()
flow.tool_call_id = "mcp_call_7"
flow.tool_call_name = "brave/search"
flow.message_id = "open-msg-456"
flow.accumulated_text = "Let me search for that..."
content = Content.from_mcp_server_tool_result(
call_id="mcp_call_7",
output="search results",
)
events = _emit_mcp_tool_result(content, flow)
assert flow.tool_call_id is None
assert flow.tool_call_name is None
assert flow.message_id is None
assert flow.accumulated_text == ""
text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)]
assert len(text_end_events) == 1
assert text_end_events[0].message_id == "open-msg-456"
def test_no_open_message_skips_text_end(self):
"""MCP tool result without open text message skips TextMessageEndEvent."""
flow = FlowState()
flow.message_id = None
content = Content.from_mcp_server_tool_result(
call_id="mcp_call_8",
output="result",
)
events = _emit_mcp_tool_result(content, flow)
text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)]
assert len(text_end_events) == 0
def test_predictive_handler_emits_state_snapshot(self):
"""MCP tool result applies pending updates and emits StateSnapshotEvent when predictive_handler is set."""
from unittest.mock import MagicMock
from ag_ui.core import StateSnapshotEvent
flow = FlowState()
flow.current_state = {"doc": "hello"}
content = Content.from_mcp_server_tool_result(
call_id="mcp_call_9",
output="done",
)
handler = MagicMock()
events = _emit_mcp_tool_result(content, flow, predictive_handler=handler)
handler.apply_pending_updates.assert_called_once()
snapshot_events = [e for e in events if isinstance(e, StateSnapshotEvent)]
assert len(snapshot_events) == 1
assert snapshot_events[0].snapshot == {"doc": "hello"}
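# Hypothetical helper, not part of agent_framework_ag_ui: a minimal sketch of how the
# TestEmitMcpToolResult cases above expect `output` to become the TOOL_CALL_RESULT
# content string. None becomes "", strings pass through, and anything else is
# JSON-encoded; json.dumps defaults use ", " and ": " separators, which is why a
# substring check for '"key": "value"' works on the serialized dict.
def _sketch_mcp_result_content(output: object) -> str:
    import json

    if output is None:
        return ""  # mirrors test_output_none_falls_back_to_empty_string
    if isinstance(output, str):
        return output  # e.g. "Success" is stored verbatim
    return json.dumps(output)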
class TestEmitTextReasoning:
"""Tests for _emit_text_reasoning function."""
def test_produces_reasoning_events(self):
"""Without a FlowState, text reasoning emits the full five-event sequence in one call."""
content = Content.from_text_reasoning(
id="reason_1",
text="The user is asking about weather, so I should call the weather tool.",
)
events = _emit_text_reasoning(content)
assert len(events) == 5
assert isinstance(events[0], ReasoningStartEvent)
assert events[0].message_id == "reason_1"
assert isinstance(events[1], ReasoningMessageStartEvent)
assert events[1].message_id == "reason_1"
assert events[1].role == "assistant"
assert isinstance(events[2], ReasoningMessageContentEvent)
assert events[2].message_id == "reason_1"
assert events[2].delta == "The user is asking about weather, so I should call the weather tool."
assert isinstance(events[3], ReasoningMessageEndEvent)
assert events[3].message_id == "reason_1"
assert isinstance(events[4], ReasoningEndEvent)
assert events[4].message_id == "reason_1"
def test_protected_data_emits_encrypted_value_event(self):
"""protected_data is emitted as a ReasoningEncryptedValueEvent."""
content = Content.from_text_reasoning(
id="reason_2",
text="visible reasoning",
protected_data="encrypted metadata",
)
events = _emit_text_reasoning(content)
encrypted_events = [e for e in events if isinstance(e, ReasoningEncryptedValueEvent)]
assert len(encrypted_events) == 1
assert encrypted_events[0].subtype == "message"
assert encrypted_events[0].entity_id == "reason_2"
assert encrypted_events[0].encrypted_value == "encrypted metadata"
def test_protected_data_only_emits_event(self):
"""Content with only protected_data (no text) still emits reasoning events."""
content = Content.from_text_reasoning(
protected_data="encrypted reasoning content",
)
events = _emit_text_reasoning(content)
# Should have start, msg_start, msg_end, encrypted_value, end (no content event)
assert len(events) == 5
assert isinstance(events[0], ReasoningStartEvent)
assert isinstance(events[1], ReasoningMessageStartEvent)
assert isinstance(events[2], ReasoningMessageEndEvent)
assert isinstance(events[3], ReasoningEncryptedValueEvent)
assert events[3].encrypted_value == "encrypted reasoning content"
assert isinstance(events[4], ReasoningEndEvent)
def test_empty_text_and_no_protected_data_returns_empty(self):
"""Empty text and no protected_data returns no events."""
content = Content.from_text_reasoning()
events = _emit_text_reasoning(content)
assert events == []
def test_generates_message_id_when_missing(self):
"""When id is None, a message_id is generated."""
content = Content.from_text_reasoning(text="thinking...")
events = _emit_text_reasoning(content)
assert len(events) == 5
assert events[0].message_id is not None
assert events[0].message_id != ""
# All events share the same message_id
assert all(e.message_id == events[0].message_id for e in events)
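# Hypothetical model, not the library's code. The tests above call
# _emit_text_reasoning without a FlowState and get the whole sequence at once. With a
# FlowState, the routing and snapshot tests below expect a streaming pattern instead:
# Start and MessageStart on the first delta for a message_id, one Content per delta,
# and MessageEnd and End only on an explicit close or when a delta with a different
# id arrives. Events are modelled as (kind, payload) tuples; payload is the message_id,
# except for "content" events, which carry the delta text.
class _SketchReasoningStream:
    def __init__(self) -> None:
        self.open_id: str | None = None  # message_id of the open reasoning block
        self.accumulated: dict[str, str] = {}  # message_id -> concatenated deltas

    def delta(self, message_id: str, text: str) -> list[tuple[str, str]]:
        events: list[tuple[str, str]] = []
        if self.open_id is not None and self.open_id != message_id:
            events.extend(self.close())  # a new id closes the previous block first
        if self.open_id is None:
            events += [("start", message_id), ("message_start", message_id)]
            self.open_id = message_id
        events.append(("content", text))
        self.accumulated[message_id] = self.accumulated.get(message_id, "") + text
        return events

    def close(self) -> list[tuple[str, str]]:
        if self.open_id is None:
            return []  # closing when no block is open is a no-op
        events = [("message_end", self.open_id), ("end", self.open_id)]
        self.open_id = None
        return events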
class TestEmitContentMcpRouting:
"""Tests that _emit_content correctly routes MCP and reasoning types."""
def test_routes_mcp_server_tool_call(self):
"""_emit_content dispatches mcp_server_tool_call to _emit_mcp_tool_call."""
flow = FlowState()
content = Content.from_mcp_server_tool_call(
call_id="route_test_1",
tool_name="test_tool",
server_name="test_server",
)
events = _emit_content(content, flow)
assert len(events) >= 1
assert events[0].type == "TOOL_CALL_START"
assert events[0].tool_call_name == "test_tool"
def test_routes_mcp_server_tool_result(self):
"""_emit_content dispatches mcp_server_tool_result to _emit_mcp_tool_result."""
flow = FlowState()
content = Content.from_mcp_server_tool_result(
call_id="route_test_2",
output="result data",
)
events = _emit_content(content, flow)
assert len(events) == 2
assert events[0].type == "TOOL_CALL_END"
assert events[1].type == "TOOL_CALL_RESULT"
def test_routes_text_reasoning(self):
"""_emit_content dispatches text_reasoning to _emit_text_reasoning."""
flow = FlowState()
content = Content.from_text_reasoning(text="I need to think about this...")
events = _emit_content(content, flow)
# Streaming pattern: Start + MessageStart + Content (no End events yet)
assert len(events) == 3
assert isinstance(events[0], ReasoningStartEvent)
assert isinstance(events[1], ReasoningMessageStartEvent)
assert isinstance(events[2], ReasoningMessageContentEvent)
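# Hypothetical sketch, not this package's code: the upstream rule the OpenAI chat
# client is described as following for reasoning text. Each ".delta" event is
# forwarded and its item_id recorded; a ".done" event, which repeats the full text,
# is forwarded only as a fallback when no delta was seen for that item_id (for
# example after stream resumption). Events are plain dicts here, with field names
# loosely modelled on the Responses API ("delta" on deltas, "text" on done events);
# treat that shape as an assumption.
def _sketch_reasoning_texts_to_emit(events: list[dict]) -> list[str]:
    seen_delta_item_ids: set[str] = set()
    texts: list[str] = []
    for event in events:
        event_type = event["type"]
        if event_type.endswith(".delta"):
            seen_delta_item_ids.add(event["item_id"])
            texts.append(event["delta"])
        elif event_type.endswith(".done") and event["item_id"] not in seen_delta_item_ids:
            texts.append(event["text"])  # fallback: no deltas arrived for this item
    return texts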
class TestReasoningInSnapshot:
"""Tests for reasoning message inclusion in MESSAGES_SNAPSHOT."""
def test_reasoning_persisted_to_flow_state(self):
"""_emit_text_reasoning with flow persists reasoning into flow.reasoning_messages."""
flow = FlowState()
content = Content.from_text_reasoning(
id="reason_persist",
text="Let me think step by step.",
)
_emit_text_reasoning(content, flow)
assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["id"] == "reason_persist"
assert flow.reasoning_messages[0]["role"] == "reasoning"
assert flow.reasoning_messages[0]["content"] == "Let me think step by step."
assert "encryptedValue" not in flow.reasoning_messages[0]
def test_reasoning_with_encrypted_value_persisted(self):
"""Reasoning with protected_data preserves encryptedValue in flow state."""
flow = FlowState()
content = Content.from_text_reasoning(
id="reason_enc",
text="visible reasoning",
protected_data="encrypted-data-123",
)
_emit_text_reasoning(content, flow)
assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-data-123"
def test_snapshot_includes_reasoning(self):
"""_build_messages_snapshot includes reasoning messages from flow state."""
from agent_framework_ag_ui._agent_run import _build_messages_snapshot
flow = FlowState()
flow.accumulated_text = "Here is my answer."
flow.reasoning_messages = [
{"id": "r1", "role": "reasoning", "content": "Thinking..."},
]
snapshot = _build_messages_snapshot(flow, [])
roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages]
assert "reasoning" in roles
def test_snapshot_preserves_reasoning_encrypted_value(self):
"""Snapshot reasoning with encryptedValue is preserved end-to-end."""
from agent_framework_ag_ui._agent_run import _build_messages_snapshot
flow = FlowState()
content = Content.from_text_reasoning(
id="reason_e2e",
text="visible",
protected_data="secret-data",
)
_emit_text_reasoning(content, flow)
text_content = Content.from_text("Final answer.")
_emit_text(text_content, flow)
snapshot = _build_messages_snapshot(flow, [])
reasoning_msgs = [
m
for m in snapshot.messages
if (m.get("role") if isinstance(m, dict) else getattr(m, "role", None)) == "reasoning"
]
assert len(reasoning_msgs) == 1
msg = reasoning_msgs[0]
if isinstance(msg, dict):
assert msg["content"] == "visible"
assert msg["encryptedValue"] == "secret-data"
def test_emit_content_routes_reasoning_with_flow(self):
"""_emit_content passes flow to _emit_text_reasoning for persistence."""
flow = FlowState()
content = Content.from_text_reasoning(text="routed reasoning")
_emit_content(content, flow)
assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["content"] == "routed reasoning"
def test_reasoning_without_flow_does_not_error(self):
"""Calling _emit_text_reasoning without flow still works (backward compat)."""
content = Content.from_text_reasoning(text="no flow")
events = _emit_text_reasoning(content)
assert len(events) == 5
assert isinstance(events[0], ReasoningStartEvent)
def test_snapshot_reasoning_ordering(self):
"""Reasoning messages appear after assistant text in snapshot."""
from agent_framework_ag_ui._agent_run import _build_messages_snapshot
flow = FlowState()
reasoning_content = Content.from_text_reasoning(id="r1", text="Thinking...")
_emit_text_reasoning(reasoning_content, flow)
text_content = Content.from_text("Answer")
_emit_text(text_content, flow)
snapshot = _build_messages_snapshot(flow, [{"id": "u1", "role": "user", "content": "Hi"}])
# user -> assistant text -> reasoning
assert len(snapshot.messages) == 3
roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages]
assert roles == ["user", "assistant", "reasoning"]
def test_reasoning_accumulates_incremental_deltas(self):
"""Multiple reasoning deltas with the same id accumulate into one entry."""
flow = FlowState()
content1 = Content.from_text_reasoning(id="reason_inc", text="First ")
content2 = Content.from_text_reasoning(id="reason_inc", text="second ")
content3 = Content.from_text_reasoning(id="reason_inc", text="third.")
_emit_text_reasoning(content1, flow)
_emit_text_reasoning(content2, flow)
_emit_text_reasoning(content3, flow)
assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["id"] == "reason_inc"
assert flow.reasoning_messages[0]["content"] == "First second third."
def test_reasoning_accumulates_distinct_message_ids(self):
"""Reasoning entries with different ids are stored separately."""
flow = FlowState()
content_a = Content.from_text_reasoning(id="a", text="alpha")
content_b = Content.from_text_reasoning(id="b", text="beta")
_emit_text_reasoning(content_a, flow)
_emit_text_reasoning(content_b, flow)
assert len(flow.reasoning_messages) == 2
assert flow.reasoning_messages[0]["content"] == "alpha"
assert flow.reasoning_messages[1]["content"] == "beta"
def test_reasoning_encrypted_value_updated_on_later_delta(self):
"""encryptedValue is set even when it arrives with a later delta."""
flow = FlowState()
content1 = Content.from_text_reasoning(id="enc_late", text="part1 ")
content2 = Content.from_text_reasoning(id="enc_late", text="part2", protected_data="encrypted-payload")
_emit_text_reasoning(content1, flow)
_emit_text_reasoning(content2, flow)
assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["content"] == "part1 part2"
assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-payload"
def test_reasoning_done_after_deltas_does_not_duplicate(self):
"""Reasoning deltas for one id accumulate to exactly their concatenation.
Duplicate text from reasoning "done" events is prevented upstream: the OpenAI
chat client skips a done event when deltas for the same item preceded it.
This test pins down the ag-ui side of that contract: two deltas are stored
once, as "Hello world", with nothing appended.
"""
flow = FlowState()
msg_id = "reason_dedup"
delta1 = Content.from_text_reasoning(id=msg_id, text="Hello ")
delta2 = Content.from_text_reasoning(id=msg_id, text="world")
_emit_text_reasoning(delta1, flow)
_emit_text_reasoning(delta2, flow)
# Accumulated text should equal the concatenation of deltas only
assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["content"] == "Hello world"
assert flow.reasoning_messages[0]["id"] == msg_id
def test_reasoning_deltas_emit_one_content_event_each(self):
"""Each reasoning delta emits exactly one ReasoningMessageContentEvent
within a single Start/End sequence (streaming pattern)."""
flow = FlowState()
msg_id = "reason_evt"
delta1 = Content.from_text_reasoning(id=msg_id, text="Think ")
delta2 = Content.from_text_reasoning(id=msg_id, text="hard")
events1 = _emit_text_reasoning(delta1, flow)
events2 = _emit_text_reasoning(delta2, flow)
close_events = _close_reasoning_block(flow)
all_events = events1 + events2 + close_events
content_events = [e for e in all_events if isinstance(e, ReasoningMessageContentEvent)]
assert len(content_events) == 2
assert content_events[0].delta == "Think "
assert content_events[1].delta == "hard"
# Streaming pattern: one Start/End sequence wrapping both content events
start_events = [e for e in all_events if isinstance(e, ReasoningStartEvent)]
end_events = [e for e in all_events if isinstance(e, ReasoningEndEvent)]
msg_start_events = [e for e in all_events if isinstance(e, ReasoningMessageStartEvent)]
msg_end_events = [e for e in all_events if isinstance(e, ReasoningMessageEndEvent)]
assert len(start_events) == 1
assert len(end_events) == 1
assert len(msg_start_events) == 1
assert len(msg_end_events) == 1
def test_reasoning_streaming_event_order(self):
"""Streaming reasoning emits Start once, then Content per delta, then End on close."""
flow = FlowState()
msg_id = "reason_order"
d1 = Content.from_text_reasoning(id=msg_id, text="A ")
d2 = Content.from_text_reasoning(id=msg_id, text="B ")
d3 = Content.from_text_reasoning(id=msg_id, text="C")
events = []
events.extend(_emit_text_reasoning(d1, flow))
events.extend(_emit_text_reasoning(d2, flow))
events.extend(_emit_text_reasoning(d3, flow))
events.extend(_close_reasoning_block(flow))
assert isinstance(events[0], ReasoningStartEvent)
assert isinstance(events[1], ReasoningMessageStartEvent)
assert isinstance(events[2], ReasoningMessageContentEvent)
assert events[2].delta == "A "
assert isinstance(events[3], ReasoningMessageContentEvent)
assert events[3].delta == "B "
assert isinstance(events[4], ReasoningMessageContentEvent)
assert events[4].delta == "C"
assert isinstance(events[5], ReasoningMessageEndEvent)
assert isinstance(events[6], ReasoningEndEvent)
assert len(events) == 7
def test_close_reasoning_block_noop_when_not_open(self):
"""_close_reasoning_block returns empty list when no reasoning block is open."""
flow = FlowState()
assert _close_reasoning_block(flow) == []
def test_close_reasoning_block_resets_state(self):
"""_close_reasoning_block clears reasoning_message_id."""
flow = FlowState()
_emit_text_reasoning(Content.from_text_reasoning(id="r1", text="x"), flow)
assert flow.reasoning_message_id == "r1"
_close_reasoning_block(flow)
assert flow.reasoning_message_id is None
def test_emit_content_closes_reasoning_on_text(self):
"""Switching from reasoning to text content auto-closes reasoning block."""
flow = FlowState()
reasoning = Content.from_text_reasoning(id="r1", text="thinking")
text = Content.from_text("answer")
r_events = _emit_content(reasoning, flow)
t_events = _emit_content(text, flow)
# reasoning events: Start + MsgStart + Content
assert isinstance(r_events[0], ReasoningStartEvent)
# text events should start with reasoning End events
assert isinstance(t_events[0], ReasoningMessageEndEvent)
assert isinstance(t_events[1], ReasoningEndEvent)
# then text start
assert isinstance(t_events[2], TextMessageStartEvent)
assert isinstance(t_events[3], TextMessageContentEvent)
def test_reasoning_distinct_ids_close_previous_block(self):
"""Emitting reasoning with a new message_id auto-closes the previous block."""
flow = FlowState()
c1 = Content.from_text_reasoning(id="block1", text="first")
c2 = Content.from_text_reasoning(id="block2", text="second")
events1 = _emit_text_reasoning(c1, flow)
events2 = _emit_text_reasoning(c2, flow)
close = _close_reasoning_block(flow)
# events1: Start(block1) + MsgStart(block1) + Content(block1)
assert events1[0].message_id == "block1"
# events2: MsgEnd(block1) + End(block1) + Start(block2) + MsgStart(block2) + Content(block2)
assert isinstance(events2[0], ReasoningMessageEndEvent)
assert events2[0].message_id == "block1"
assert isinstance(events2[1], ReasoningEndEvent)
assert events2[1].message_id == "block1"
assert isinstance(events2[2], ReasoningStartEvent)
assert events2[2].message_id == "block2"
# close: MsgEnd(block2) + End(block2)
assert isinstance(close[0], ReasoningMessageEndEvent)
assert close[0].message_id == "block2"