mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
5e8fe0be1f
* Fix reasoning text done events duplicating streamed delta content (#5157) The OpenAI Responses API sends both reasoning_text.delta (incremental chunks) and reasoning_text.done (full accumulated text) events. The chat client was emitting Content for both, causing ag-ui to append the full done text onto already-accumulated delta text, producing duplicated reasoning output. Stop emitting Content for reasoning_text.done and reasoning_summary_text.done events, matching how output_text.done is already handled (not emitted). The deltas contain all the content; the done event is redundant. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix(openai): emit reasoning done content as fallback when no deltas observed (#5157) Address PR review feedback: - Track item_ids that received reasoning deltas via seen_reasoning_delta_item_ids set - Emit content from done events only when no deltas were received for the item_id, preventing silent content loss on stream resumption - Add comment documenting code_interpreter done event asymmetry - Replace redundant ag-ui test with deduplication-focused test - Add integration test for delta+done sequence in OpenAI chat client tests - Add fallback path tests for done events without preceding deltas Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #5157: Python: [Bug]: "type": "response.reasoning_text.delta" and "response.reasoning_text.done" both get exposed as "text_reasoning" * Fix AG-UI reasoning streaming to use proper Start/End pattern (#5157) _emit_text_reasoning now follows the same streaming pattern as _emit_text: - Emits ReasoningStartEvent/ReasoningMessageStartEvent only on the first delta for a given message_id - Emits only ReasoningMessageContentEvent for subsequent deltas - Defers ReasoningMessageEndEvent/ReasoningEndEvent until _close_reasoning_block is called (on content type switch or end-of-run) This produces the correct protocol pattern: 
ReasoningStartEvent ReasoningMessageStartEvent ReasoningMessageContentEvent(delta1) ReasoningMessageContentEvent(delta2) ReasoningMessageEndEvent ReasoningEndEvent Instead of wrapping every delta in a full Start→End sequence. Backward compatibility is preserved: calling _emit_text_reasoning without a flow argument still produces the full sequence per call. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix import ordering lint error in AG-UI test file (#5157) Move inline import of TextMessageContentEvent to the top-level import block and ensure alphabetical ordering to satisfy ruff I001 rule. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix mypy error: rename loop variable to avoid type conflict with WorkflowEvent The 'event' variable was already typed as WorkflowEvent[Any] from the async for loop at line 590. Reusing it in the _close_reasoning_block loop (which returns list[BaseEvent]) caused an incompatible assignment error. Renamed to 'reasoning_evt' to avoid the conflict. Fixes #5162 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #5157: review comment fixes * narrow test result reporting to explicit pytest JUnit XML * Fix test args * Fix pytest-results-action in merge workflow and remove committed test artifacts Apply the same JUnit XML fix from python-tests.yml to python-merge-tests.yml: add --junitxml=pytest.xml to all test commands and narrow the results action path from ./python/**.xml to ./python/pytest.xml. Also remove accidentally committed pytest.xml and python-coverage.xml and add them to .gitignore. --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1643 lines
62 KiB
Python
1643 lines
62 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Tests for _agent_run.py helper functions and FlowState."""
|
|
|
|
import pytest
|
|
from ag_ui.core import (
|
|
CustomEvent,
|
|
ReasoningEncryptedValueEvent,
|
|
ReasoningEndEvent,
|
|
ReasoningMessageContentEvent,
|
|
ReasoningMessageEndEvent,
|
|
ReasoningMessageStartEvent,
|
|
ReasoningStartEvent,
|
|
TextMessageContentEvent,
|
|
TextMessageEndEvent,
|
|
TextMessageStartEvent,
|
|
ToolCallArgsEvent,
|
|
)
|
|
from agent_framework import AgentResponseUpdate, Content, Message, ResponseStream
|
|
from agent_framework.exceptions import AgentInvalidResponseException
|
|
|
|
from agent_framework_ag_ui._agent_run import (
|
|
_build_safe_metadata,
|
|
_create_state_context_message,
|
|
_inject_state_context,
|
|
_normalize_response_stream,
|
|
_resume_to_tool_messages,
|
|
_should_suppress_intermediate_snapshot,
|
|
)
|
|
from agent_framework_ag_ui._run_common import (
|
|
FlowState,
|
|
_build_run_finished_event,
|
|
_close_reasoning_block,
|
|
_emit_approval_request,
|
|
_emit_content,
|
|
_emit_mcp_tool_call,
|
|
_emit_mcp_tool_result,
|
|
_emit_text,
|
|
_emit_text_reasoning,
|
|
_emit_tool_call,
|
|
_emit_tool_result,
|
|
_extract_resume_payload,
|
|
_has_only_tool_calls,
|
|
)
|
|
|
|
|
|
class TestBuildSafeMetadata:
    """Exercise _build_safe_metadata with empty, short, long and non-string input."""

    def test_none_metadata(self):
        """A None input produces an empty dict."""
        assert _build_safe_metadata(None) == {}

    def test_empty_metadata(self):
        """An empty dict input produces an empty dict."""
        assert _build_safe_metadata({}) == {}

    def test_short_string_values(self):
        """Short string values come back equal to the input mapping."""
        source = {"key1": "short", "key2": "value"}
        assert _build_safe_metadata(source) == source

    def test_truncates_long_strings(self):
        """A 1000-character string value is shortened to 512 characters."""
        safe = _build_safe_metadata({"key": "x" * 1000})
        assert len(safe["key"]) == 512

    def test_serializes_non_strings(self):
        """An int and a list are returned as their JSON text."""
        safe = _build_safe_metadata({"count": 42, "items": [1, 2, 3]})
        assert safe["count"] == "42"
        assert safe["items"] == "[1, 2, 3]"

    def test_truncates_serialized_values(self):
        """A serialized list longer than 512 characters is shortened to 512."""
        safe = _build_safe_metadata({"data": list(range(200))})
        assert len(safe["data"]) == 512
|
|
|
|
|
|
class TestHasOnlyToolCalls:
    """Check _has_only_tool_calls against different mixes of content items."""

    def test_only_tool_calls(self):
        """A single function_call content is reported as tool-calls-only."""
        call = Content.from_function_call(call_id="call_1", name="tool1", arguments="{}")
        assert _has_only_tool_calls([call]) is True

    def test_tool_call_with_text(self):
        """Non-empty text next to a tool call makes the result False."""
        text = Content.from_text("Some text")
        call = Content.from_function_call(call_id="call_1", name="tool1", arguments="{}")
        assert _has_only_tool_calls([text, call]) is False

    def test_only_text(self):
        """Text on its own makes the result False."""
        assert _has_only_tool_calls([Content.from_text("Just text")]) is False

    def test_empty_contents(self):
        """An empty content list makes the result False."""
        assert _has_only_tool_calls([]) is False

    def test_tool_call_with_empty_text(self):
        """Empty-string text next to a tool call still gives True."""
        blank = Content.from_text("")
        call = Content.from_function_call(call_id="call_1", name="tool1", arguments="{}")
        assert _has_only_tool_calls([blank, call]) is True
|
|
|
|
|
|
class TestShouldSuppressIntermediateSnapshot:
    """Check when _should_suppress_intermediate_snapshot returns True or False."""

    def test_no_tool_name(self):
        """A None tool name gives False."""
        config = {"key": {"tool": "write_doc", "tool_argument": "content"}}
        assert _should_suppress_intermediate_snapshot(None, config, False) is False

    def test_no_config(self):
        """A None config gives False."""
        assert _should_suppress_intermediate_snapshot("write_doc", None, False) is False

    def test_confirmation_required(self):
        """A configured tool with the confirmation flag set gives False."""
        config = {"key": {"tool": "write_doc", "tool_argument": "content"}}
        assert _should_suppress_intermediate_snapshot("write_doc", config, True) is False

    def test_tool_not_in_config(self):
        """A tool that the config does not name gives False."""
        config = {"key": {"tool": "other_tool", "tool_argument": "content"}}
        assert _should_suppress_intermediate_snapshot("write_doc", config, False) is False

    def test_suppresses_predictive_tool(self):
        """A configured tool with the confirmation flag off gives True."""
        config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
        assert _should_suppress_intermediate_snapshot("write_doc", config, False) is True
|
|
|
|
|
|
class TestFlowState:
    """Cover FlowState defaults and its tool-call lookup helpers."""

    def test_default_values(self):
        """A freshly built FlowState has unset ids and empty collections."""
        state = FlowState()

        # Scalar fields start out unset / false.
        assert state.message_id is None
        assert state.tool_call_id is None
        assert state.tool_call_name is None
        assert state.waiting_for_approval is False

        # Text and container fields start out empty.
        assert state.accumulated_text == ""
        assert state.current_state == {}
        assert state.pending_tool_calls == []
        assert state.tool_calls_by_id == {}
        assert state.tool_results == []
        assert state.tool_calls_ended == set()
        assert state.interrupts == []

    def test_get_tool_name(self):
        """get_tool_name returns the stored name, or None for unknown / None ids."""
        state = FlowState()
        state.tool_calls_by_id = {"call_123": {"function": {"name": "get_weather", "arguments": "{}"}}}

        assert state.get_tool_name("call_123") == "get_weather"
        for missing_id in ("nonexistent", None):
            assert state.get_tool_name(missing_id) is None

    def test_get_tool_name_empty_name(self):
        """get_tool_name returns None when the stored name is an empty string."""
        state = FlowState()
        state.tool_calls_by_id = {"call_123": {"function": {"name": "", "arguments": "{}"}}}

        assert state.get_tool_name("call_123") is None

    def test_get_pending_without_end(self):
        """get_pending_without_end returns only calls missing from tool_calls_ended."""
        state = FlowState()
        state.pending_tool_calls = [
            {"id": f"call_{n}", "function": {"name": f"tool{n}"}} for n in (1, 2, 3)
        ]
        state.tool_calls_ended = {"call_1", "call_3"}

        remaining = state.get_pending_without_end()
        assert len(remaining) == 1
        assert remaining[0]["id"] == "call_2"
|
|
|
|
|
|
class TestNormalizeResponseStream:
    """Check which stream-like inputs _normalize_response_stream accepts."""

    async def test_accepts_response_stream(self):
        """A ResponseStream wrapping an async generator is accepted."""

        async def _one_update():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        normalized = await _normalize_response_stream(ResponseStream(_one_update()))
        collected = []
        async for item in normalized:
            collected.append(item)

        assert len(collected) == 1
        assert collected[0].contents[0].text == "hello"

    async def test_accepts_async_iterable(self):
        """A bare async generator is accepted."""

        async def _one_update():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        normalized = await _normalize_response_stream(_one_update())
        collected = []
        async for item in normalized:
            collected.append(item)

        assert len(collected) == 1
        assert collected[0].contents[0].text == "hello"

    async def test_accepts_awaitable_resolving_to_async_iterable(self):
        """A coroutine that returns an async generator is accepted."""

        async def _one_update():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        async def _deferred():
            return _one_update()

        normalized = await _normalize_response_stream(_deferred())
        collected = []
        async for item in normalized:
            collected.append(item)

        assert len(collected) == 1
        assert collected[0].contents[0].text == "hello"

    async def test_rejects_non_stream_values(self):
        """A plain string raises AgentInvalidResponseException."""
        with pytest.raises(AgentInvalidResponseException):
            await _normalize_response_stream("not-a-stream")
|
|
|
|
|
|
class TestCreateStateContextMessage:
    """Check _create_state_context_message with and without state / schema."""

    def test_no_state(self):
        """An empty state gives None."""
        assert _create_state_context_message({}, {"properties": {}}) is None

    def test_no_schema(self):
        """An empty schema gives None."""
        assert _create_state_context_message({"key": "value"}, {}) is None

    def test_creates_message(self):
        """State plus schema gives a single-content system message containing the state."""
        message = _create_state_context_message(
            {"document": "Hello world"},
            {"properties": {"document": {"type": "string"}}},
        )

        assert message is not None
        assert message.role == "system"
        assert len(message.contents) == 1

        body = message.contents[0].text
        assert "Hello world" in body
        assert "Current state" in body
|
|
|
|
|
|
class TestInjectStateContext:
    """Check where, and whether, _inject_state_context adds a state message."""

    def test_no_state_message(self):
        """With empty state and schema the messages come back unchanged."""
        conversation = [Message(role="user", contents=[Content.from_text("Hello")])]
        assert _inject_state_context(conversation, {}, {}) == conversation

    def test_empty_messages(self):
        """An empty message list stays empty."""
        assert _inject_state_context([], {"key": "value"}, {"properties": {}}) == []

    def test_last_message_not_user(self):
        """Nothing is injected when the final message is from the assistant."""
        conversation = [
            Message(role="user", contents=[Content.from_text("Hello")]),
            Message(role="assistant", contents=[Content.from_text("Hi")]),
        ]

        injected = _inject_state_context(
            conversation,
            {"key": "value"},
            {"properties": {"key": {"type": "string"}}},
        )
        assert injected == conversation

    def test_injects_before_last_user_message(self):
        """The state message is placed between the system and the user message."""
        conversation = [
            Message(role="system", contents=[Content.from_text("You are helpful")]),
            Message(role="user", contents=[Content.from_text("Hello")]),
        ]

        injected = _inject_state_context(
            conversation,
            {"document": "content"},
            {"properties": {"document": {"type": "string"}}},
        )

        assert len(injected) == 3

        # Position 0: the original system prompt.
        assert injected[0].role == "system"
        assert "helpful" in injected[0].contents[0].text

        # Position 1: the injected state context.
        assert injected[1].role == "system"
        assert "Current state" in injected[1].contents[0].text

        # Position 2: the original user message.
        assert injected[2].role == "user"
        assert "Hello" in injected[2].contents[0].text
|
|
|
|
|
|
# Additional tests for _agent_run.py functions
|
|
|
|
|
|
def test_emit_text_basic():
    """_emit_text on a fresh flow yields two events and records the text."""
    state = FlowState()

    emitted = _emit_text(Content.from_text("Hello world"), state)

    # Expected pair: TextMessageStartEvent + TextMessageContentEvent
    assert len(emitted) == 2
    assert state.message_id is not None
    assert state.accumulated_text == "Hello world"
|
|
|
|
|
|
def test_emit_text_skip_empty():
    """_emit_text yields nothing for empty-string text."""
    state = FlowState()

    emitted = _emit_text(Content.from_text(""), state)

    assert len(emitted) == 0
|
|
|
|
|
|
def test_emit_text_continues_existing_message():
    """_emit_text reuses an already-set message_id and yields a single event."""
    state = FlowState()
    state.message_id = "existing-id"

    emitted = _emit_text(Content.from_text("more text"), state)

    # Only a TextMessageContentEvent is expected; no new start event.
    assert len(emitted) == 1
    assert state.message_id == "existing-id"
|
|
|
|
|
|
def test_emit_text_skips_duplicate_full_message_delta():
    """_emit_text ignores a chunk equal to the text already accumulated on an open message."""
    state = FlowState()
    state.message_id = "existing-id"
    state.accumulated_text = "Case complete."

    emitted = _emit_text(Content.from_text("Case complete."), state)

    assert emitted == []
    assert state.accumulated_text == "Case complete."
|
|
|
|
|
|
def test_emit_text_skips_when_waiting_for_approval():
    """_emit_text yields nothing while the flow is waiting for approval."""
    state = FlowState()
    state.waiting_for_approval = True

    emitted = _emit_text(Content.from_text("should skip"), state)

    assert len(emitted) == 0
|
|
|
|
|
|
def test_emit_text_skips_when_skip_text_flag():
    """_emit_text yields nothing when called with skip_text=True."""
    state = FlowState()

    emitted = _emit_text(Content.from_text("should skip"), state, skip_text=True)

    assert len(emitted) == 0
|
|
|
|
|
|
def test_emit_tool_call_basic():
    """_emit_tool_call yields events and records the call id and name on the flow."""
    state = FlowState()
    call = Content.from_function_call(call_id="call_123", name="get_weather", arguments='{"city": "NYC"}')

    emitted = _emit_tool_call(call, state)

    # At minimum a ToolCallStartEvent is expected.
    assert len(emitted) >= 1
    assert state.tool_call_id == "call_123"
    assert state.tool_call_name == "get_weather"
|
|
|
|
|
|
def test_emit_tool_call_generates_id():
    """_emit_tool_call sets a tool_call_id even when the content carries no call_id."""
    state = FlowState()
    # Built directly so that no call_id is supplied.
    call_without_id = Content(type="function_call", name="test_tool", arguments="{}")

    emitted = _emit_tool_call(call_without_id, state)

    assert len(emitted) >= 1
    # The id must have been generated by the emitter.
    assert state.tool_call_id is not None
|
|
|
|
|
|
def test_emit_tool_call_skips_duplicate_full_arguments_replay():
    """_emit_tool_call ignores a replay of the full arguments for a known tool call.

    Regression test for issue #4194: some streaming providers resend the
    complete arguments string after streaming deltas, which doubled the
    arguments in MESSAGES_SNAPSHOT events.

    Kept parallel to test_emit_text_skips_duplicate_full_message_delta.
    """
    state = FlowState()
    complete_args = '{"city": "Seattle"}'

    # Phase 1: the normal opening call carrying name + arguments.
    first_events = _emit_tool_call(
        Content.from_function_call(call_id="call_dup", name="get_weather", arguments=complete_args),
        state,
    )

    # A ToolCallArgsEvent must be among the emitted events.
    assert any(isinstance(evt, ToolCallArgsEvent) for evt in first_events)
    assert state.tool_calls_by_id["call_dup"]["function"]["arguments"] == complete_args

    # Phase 2: the provider sends the same full arguments again.
    replay_events = _emit_tool_call(
        Content(type="function_call", call_id="call_dup", arguments=complete_args),
        state,
    )

    # The replay must not produce any ToolCallArgsEvent.
    replayed_args = [evt for evt in replay_events if isinstance(evt, ToolCallArgsEvent)]
    assert replayed_args == [], "Duplicate full-arguments replay should not emit ToolCallArgsEvent"

    # The stored arguments must be untouched.
    assert state.tool_calls_by_id["call_dup"]["function"]["arguments"] == complete_args
|
|
|
|
|
|
def test_emit_tool_result_closes_open_message():
    """_emit_tool_result emits a TextMessageEndEvent for a still-open text message.

    Regression test for a case where TEXT_MESSAGE_END was not emitted when
    using MCP tools, because message_id was reset without first closing the
    message.
    """
    state = FlowState()
    # Pretend a text message is open (e.g., from Feature #4 tool-only detection).
    state.message_id = "open-msg-123"
    state.tool_call_id = "call_456"

    tool_output = Content.from_function_result(call_id="call_456", result="tool result")
    emitted = _emit_tool_result(tool_output, state, predictive_handler=None)

    # Expected: ToolCallEndEvent, ToolCallResultEvent, TextMessageEndEvent
    assert len(emitted) == 3

    # Exactly one end event, pointing at the message that was open.
    closing = [evt for evt in emitted if isinstance(evt, TextMessageEndEvent)]
    assert len(closing) == 1
    assert closing[0].message_id == "open-msg-123"

    # The flow no longer tracks an open message.
    assert state.message_id is None
|
|
|
|
|
|
def test_emit_tool_result_no_open_message():
    """_emit_tool_result emits no TextMessageEndEvent when no text message is open."""
    state = FlowState()
    state.message_id = None  # explicitly no open message
    state.tool_call_id = "call_456"

    tool_output = Content.from_function_result(call_id="call_456", result="tool result")
    emitted = _emit_tool_result(tool_output, state, predictive_handler=None)

    # Expected: ToolCallEndEvent, ToolCallResultEvent (no TextMessageEndEvent)
    closing = [evt for evt in emitted if isinstance(evt, TextMessageEndEvent)]
    assert len(closing) == 0
|
|
|
|
|
|
def test_emit_tool_result_serializes_non_string_result():
    """A dict tool result is turned into a string before TOOL_CALL_RESULT is emitted."""
    state = FlowState()
    tool_output = Content.from_function_result(call_id="call_789", result={"ok": True, "items": [1, 2]})

    emitted = _emit_tool_result(tool_output, state, predictive_handler=None)
    tool_result_event = next(evt for evt in emitted if getattr(evt, "type", None) == "TOOL_CALL_RESULT")

    serialized = tool_result_event.content
    assert isinstance(serialized, str)
    assert '"ok": true' in serialized
    # The flow stores the same serialized text.
    assert state.tool_results[0]["content"] == tool_result_event.content
|
|
|
|
|
|
def test_emit_content_usage_emits_custom_usage_event():
    """Usage content comes out of _emit_content as one CUSTOM event named "usage"."""
    state = FlowState()
    usage = Content.from_usage({"input_token_count": 3, "output_token_count": 2, "total_token_count": 5})

    emitted = _emit_content(usage, state)

    assert len(emitted) == 1
    usage_event = emitted[0]
    assert usage_event.type == "CUSTOM"
    assert usage_event.name == "usage"
    assert usage_event.value["total_token_count"] == 5
|
|
|
|
|
|
def test_emit_approval_request_populates_interrupt_metadata():
    """An approval request marks the flow as waiting and records one interrupt."""
    state = FlowState(message_id="msg-1")
    pending_call = Content.from_function_call(call_id="call_123", name="write_doc", arguments={"content": "x"})
    request = Content.from_function_approval_request(id="approval_1", function_call=pending_call)

    _emit_approval_request(request, state)

    assert state.waiting_for_approval is True
    assert len(state.interrupts) == 1

    recorded = state.interrupts[0]
    assert recorded["id"] == "call_123"
    assert recorded["value"]["type"] == "function_approval_request"
|
|
|
|
|
|
def test_emit_approval_request_accumulates_multiple_interrupts():
    """Three approval requests on one flow leave three interrupts behind."""
    state = FlowState(message_id="msg-1")

    for index in (1, 2, 3):
        pending_call = Content.from_function_call(
            call_id=f"call_{index}",
            name=f"tool_{index}",
            arguments={"arg": f"value_{index}"},
        )
        request = Content.from_function_approval_request(
            id=f"approval_{index}",
            function_call=pending_call,
        )
        _emit_approval_request(request, state)

    assert len(state.interrupts) == 3
    recorded_ids = {entry["id"] for entry in state.interrupts}
    assert recorded_ids == {"call_1", "call_2", "call_3"}
|
|
|
|
|
|
def test_resume_to_tool_messages_from_interrupts_payload():
    """Each interrupt response in a resume payload becomes one tool message."""
    payload = {
        "interrupts": [
            {"id": "req_1", "value": {"accepted": True, "steps": []}},
            {"id": "req_2", "value": "plain value"},
        ]
    }

    tool_messages = _resume_to_tool_messages(payload)

    assert len(tool_messages) == 2

    first = tool_messages[0]
    assert first["role"] == "tool"
    assert first["toolCallId"] == "req_1"
    assert '"accepted": true' in first["content"]

    assert tool_messages[1]["content"] == "plain value"
|
|
|
|
|
|
def test_extract_resume_payload_prefers_top_level_resume():
    """A top-level "resume" wins over forwarded_props.command.resume."""
    top_level = {"interrupts": [{"id": "req_1", "value": "approved"}]}
    request = {
        "resume": top_level,
        "forwarded_props": {"command": {"resume": "ignored"}},
    }

    extracted = _extract_resume_payload(request)

    assert extracted == {"interrupts": [{"id": "req_1", "value": "approved"}]}
|
|
|
|
|
|
def test_extract_resume_payload_reads_forwarded_command_resume():
    """forwarded_props.command.resume is returned as the resume payload."""
    forwarded_resume = '{"airline":"KLM","departure":"Amsterdam (AMS)","arrival":"San Francisco (SFO)"}'
    request = {"forwarded_props": {"command": {"resume": forwarded_resume}}}

    extracted = _extract_resume_payload(request)

    assert isinstance(extracted, str)
    assert "KLM" in extracted
|
|
|
|
|
|
def test_build_run_finished_event_with_interrupt():
    """_build_run_finished_event keeps run id, thread id and interrupt payload."""
    finished = _build_run_finished_event(
        "run-1",
        "thread-1",
        interrupts=[{"id": "req_1", "value": {"x": 1}}],
    )
    as_dict = finished.model_dump()

    assert as_dict["run_id"] == "run-1"
    assert as_dict["thread_id"] == "thread-1"
    assert as_dict["interrupt"] == [{"id": "req_1", "value": {"x": 1}}]
|
|
|
|
|
|
def test_extract_approved_state_updates_no_handler():
    """_extract_approved_state_updates gives an empty dict when the handler is None."""
    from agent_framework_ag_ui._agent_run import _extract_approved_state_updates

    conversation = [Message(role="user", contents=[Content.from_text("Hello")])]

    assert _extract_approved_state_updates(conversation, None) == {}
|
|
|
|
|
|
def test_extract_approved_state_updates_no_approval():
    """_extract_approved_state_updates gives an empty dict when no message carries approval content."""
    from agent_framework_ag_ui._agent_run import _extract_approved_state_updates
    from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler

    predictive = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "content"}})
    conversation = [Message(role="user", contents=[Content.from_text("Hello")])]

    assert _extract_approved_state_updates(conversation, predictive) == {}
|
|
|
|
|
|
class TestBuildMessagesSnapshot:
    """Check the message list produced by _build_messages_snapshot."""

    def test_tool_calls_and_text_are_separate_messages(self):
        """Tool calls and text end up in separate assistant messages.

        Regression test for issue #3619, where tool calls and content were
        wrongly merged into a single assistant message.
        """
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        state = FlowState()
        state.message_id = "msg-123"
        state.pending_tool_calls = [
            {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "NYC"}'}},
        ]
        state.accumulated_text = "Here is the weather information."
        state.tool_results = [{"id": "result-1", "role": "tool", "content": '{"temp": 72}', "toolCallId": "call_1"}]

        snapshot = _build_messages_snapshot(state, [])

        # Expected order: tool-call message, tool result, text message.
        assert len(snapshot.messages) == 3

        # Message 0: assistant carrying tool calls and no content.
        tool_call_message = snapshot.messages[0]
        assert tool_call_message.role == "assistant"
        assert tool_call_message.tool_calls is not None
        assert len(tool_call_message.tool_calls) == 1
        assert tool_call_message.content is None

        # Message 1: the tool result.
        assert snapshot.messages[1].role == "tool"

        # Message 2: assistant carrying content and no tool calls.
        text_message = snapshot.messages[2]
        assert text_message.role == "assistant"
        assert text_message.content == "Here is the weather information."
        assert text_message.tool_calls is None

        # The two assistant messages must not share an id.
        assert text_message.id != tool_call_message.id

    def test_only_tool_calls_no_text(self):
        """With tool calls but no accumulated text, only the tool-call message appears."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        state = FlowState()
        state.message_id = "msg-123"
        state.pending_tool_calls = [
            {"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
        ]
        state.accumulated_text = ""
        state.tool_results = []

        snapshot = _build_messages_snapshot(state, [])

        # Exactly one message: the tool-call message.
        assert len(snapshot.messages) == 1
        only_message = snapshot.messages[0]
        assert only_message.role == "assistant"
        assert only_message.tool_calls is not None
        assert only_message.content is None

    def test_only_text_no_tool_calls(self):
        """With text but no tool calls, only the text message appears and it keeps message_id."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        state = FlowState()
        state.message_id = "msg-123"
        state.pending_tool_calls = []
        state.accumulated_text = "Hello world"
        state.tool_results = []

        snapshot = _build_messages_snapshot(state, [])

        # Exactly one message: the text message.
        assert len(snapshot.messages) == 1
        only_message = snapshot.messages[0]
        assert only_message.role == "assistant"
        assert only_message.content == "Hello world"
        assert only_message.tool_calls is None
        # The flow's existing message_id is reused.
        assert only_message.id == "msg-123"

    def test_preserves_snapshot_messages(self):
        """Messages passed in as the existing snapshot are kept in order."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        state = FlowState()
        state.pending_tool_calls = []
        state.accumulated_text = ""

        prior_messages = [
            {"id": "user-1", "role": "user", "content": "Hello"},
            {"id": "assist-1", "role": "assistant", "content": "Hi there"},
        ]

        snapshot = _build_messages_snapshot(state, prior_messages)

        assert len(snapshot.messages) == 2
        assert snapshot.messages[0].id == "user-1"
        assert snapshot.messages[1].id == "assist-1"
|
|
|
|
|
|
def test_malformed_json_in_confirm_args_skips_confirmation():
    """Malformed JSON in tool arguments must skip the confirm_changes flow.

    Regression guard: when tool arguments contain malformed JSON, the
    confirmation flow should be skipped entirely rather than crashing or
    showing incomplete data to the user.

    NOTE(review): this test only runs ``json.loads`` on hand-built tool-call
    dicts; it does not call any production helper, so it pins the expected
    parsing contract rather than the real code path.
    """
    import json

    def _parse_arguments(tool_call):
        """Return ``(parsed_arguments, should_skip_confirmation)`` for a tool call.

        ``parsed_arguments`` is None when the arguments string is not valid
        JSON, so both values are always bound for the assertions below.
        """
        raw_arguments = tool_call.get("function", {}).get("arguments", "{}")
        try:
            return json.loads(raw_arguments), False
        except json.JSONDecodeError:
            return None, True

    # Malformed JSON: parsing fails, so confirmation must be skipped.
    malformed_call = {"function": {"name": "write_doc", "arguments": "{ invalid json }"}}
    function_arguments, should_skip_confirmation = _parse_arguments(malformed_call)
    assert should_skip_confirmation is True
    assert function_arguments is None

    # Valid JSON: parsing succeeds, so confirmation proceeds with the parsed args.
    valid_call = {"function": {"name": "write_doc", "arguments": '{"content": "hello"}'}}
    function_arguments, should_skip_confirmation = _parse_arguments(valid_call)
    assert should_skip_confirmation is False
    assert function_arguments == {"content": "hello"}
|
|
|
|
|
|
class TestTextMessageEventBalancing:
    """Checks that TEXT_MESSAGE_START and TEXT_MESSAGE_END events stay paired.

    Each scenario drives the emit helpers by hand and then counts the
    TextMessageStartEvent / TextMessageEndEvent instances that were produced,
    with a focus on flows that involve tool execution.
    """

    def test_tool_only_flow_produces_balanced_events(self):
        """A response that opens with a tool call yields balanced TEXT_MESSAGE events.

        Scenario: the LLM calls a tool straight away with no leading text and
        only produces text once the tool result is available.
        """
        flow = FlowState()
        emitted: list = []

        # Step 1: the model emits nothing but a function call.
        weather_call = Content.from_function_call(
            call_id="call_weather",
            name="get_weather",
            arguments='{"city": "Seattle"}',
        )

        # Feature #4 check: a tool-only update opens a text message first.
        if not flow.message_id and _has_only_tool_calls([weather_call]):
            flow.message_id = "tool-msg-1"
            emitted.append(TextMessageStartEvent(message_id=flow.message_id, role="assistant"))

        # Tool call events follow the synthetic start event.
        emitted += _emit_content(weather_call, flow)

        # Step 2: the tool runs and hands back its result, which should close
        # the open text message.
        weather_result = Content.from_function_result(
            call_id="call_weather",
            result='{"temp": 55, "conditions": "rainy"}',
        )
        emitted += _emit_tool_result(weather_result, flow)

        assert flow.message_id is None, "message_id should be reset after tool result"

        # Step 3: the model answers in text; with message_id cleared a new
        # message is expected to be started by the emit helper.
        answer = Content.from_text("The weather in Seattle is 55°F and rainy.")
        emitted += _emit_content(answer, flow)

        # Step 4: end of stream closes whatever message is still open.
        if flow.message_id:
            emitted.append(TextMessageEndEvent(message_id=flow.message_id))

        # Two start events and two end events are expected overall.
        starts = [ev for ev in emitted if isinstance(ev, TextMessageStartEvent)]
        ends = [ev for ev in emitted if isinstance(ev, TextMessageEndEvent)]
        assert len(starts) == 2, f"Expected 2 start events, got {len(starts)}"
        assert len(ends) == 2, f"Expected 2 end events, got {len(ends)}"

        # Ordering: the first message must be closed before the second opens.
        start_positions = [pos for pos, ev in enumerate(emitted) if isinstance(ev, TextMessageStartEvent)]
        end_positions = [pos for pos, ev in enumerate(emitted) if isinstance(ev, TextMessageEndEvent)]
        assert end_positions[0] < start_positions[1], (
            f"First TextMessageEndEvent (index {end_positions[0]}) should come "
            f"before second TextMessageStartEvent (index {start_positions[1]})"
        )

    def test_text_then_tool_flow(self):
        """Text first, then a tool call, still yields balanced TEXT_MESSAGE events.

        Scenario: "Let me check the weather..." -> tool call -> tool result -> "The weather is..."
        """
        flow = FlowState()
        emitted: list = []

        # Step 1: leading text from the model.
        emitted += _emit_content(Content.from_text("Let me check the weather for you."), flow)

        assert flow.message_id is not None, "message_id should be set after text"
        opening_message_id = flow.message_id

        # Step 2: the model requests a tool.
        tool_call = Content.from_function_call(
            call_id="call_1",
            name="get_weather",
            arguments="{}",
        )
        emitted += _emit_content(tool_call, flow)

        # Step 3: the tool result arrives.
        tool_result = Content.from_function_result(call_id="call_1", result="sunny")
        emitted += _emit_tool_result(tool_result, flow)

        # The first message must now be closed and the id cleared.
        assert flow.message_id is None
        closed_so_far = [ev for ev in emitted if isinstance(ev, TextMessageEndEvent)]
        assert len(closed_so_far) == 1
        assert closed_so_far[0].message_id == opening_message_id

        # Step 4: follow-up text from the model.
        emitted += _emit_content(Content.from_text("The weather is sunny!"), flow)

        # Step 5: end of stream closes the remaining open message.
        if flow.message_id:
            emitted.append(TextMessageEndEvent(message_id=flow.message_id))

        # Balance check.
        starts = [ev for ev in emitted if isinstance(ev, TextMessageStartEvent)]
        ends = [ev for ev in emitted if isinstance(ev, TextMessageEndEvent)]

        assert len(starts) == 2
        assert len(ends) == 2
|
|
|
|
|
|
async def test_run_agent_stream_accumulates_multiple_confirm_interrupts():
    """Two predictive tool calls in one streaming run must both surface as interrupts.

    A stub agent returns a single update carrying two function calls, each
    mapped to a predictive-state entry, with confirmation required. The
    final RUN_FINISHED event is expected to carry one interrupt per tool call.
    """
    import json

    from conftest import StubAgent

    from agent_framework_ag_ui import AgentFrameworkAgent

    def _is_run_finished(event) -> bool:
        # The event type may be a plain string or an enum-like object with ``.value``.
        event_type = getattr(event, "type", None)
        return event_type == "RUN_FINISHED" or getattr(event_type, "value", None) == "RUN_FINISHED"

    predict_config = {
        "tasks": {"tool": "generate_tasks", "tool_argument": "steps"},
        "notes": {"tool": "generate_notes", "tool_argument": "items"},
    }
    state_schema = {
        "tasks": {"type": "array", "items": {"type": "object"}},
        "notes": {"type": "array", "items": {"type": "object"}},
    }

    tasks_call = Content.from_function_call(
        name="generate_tasks",
        call_id="call-tasks",
        arguments=json.dumps({"steps": [{"description": "Task 1"}]}),
    )
    notes_call = Content.from_function_call(
        name="generate_notes",
        call_id="call-notes",
        arguments=json.dumps({"items": [{"description": "Note 1"}]}),
    )
    updates = [AgentResponseUpdate(contents=[tasks_call, notes_call], role="assistant")]

    agent = AgentFrameworkAgent(
        agent=StubAgent(updates=updates),
        state_schema=state_schema,
        predict_state_config=predict_config,
        require_confirmation=True,
    )

    payload = {
        "thread_id": "thread-multi",
        "run_id": "run-multi",
        "messages": [{"role": "user", "content": "Generate tasks and notes"}],
        "state": {"tasks": [], "notes": []},
    }

    events = [event async for event in agent.run(payload)]

    # Locate the RUN_FINISHED event and inspect its interrupt metadata.
    finished_events = [event for event in events if _is_run_finished(event)]
    assert finished_events, f"Expected RUN_FINISHED event. Types: {[getattr(e, 'type', None) for e in events]}"

    interrupt = getattr(finished_events[-1], "interrupt", None)
    assert interrupt is not None, "Expected interrupt metadata in RUN_FINISHED"
    assert len(interrupt) == 2, f"Expected 2 interrupts (one per tool), got {len(interrupt)}"

    # Each tool call must be represented exactly once by name.
    interrupt_tool_names = {entry["value"]["function_call"]["name"] for entry in interrupt}
    assert interrupt_tool_names == {"generate_tasks", "generate_notes"}
|
|
|
|
|
|
def test_emit_oauth_consent_request():
    """An oauth_consent_request carrying a link is surfaced as a single CustomEvent."""
    consent_request = Content.from_oauth_consent_request(
        consent_link="https://login.microsoftonline.com/consent",
    )

    emitted = _emit_content(consent_request, FlowState())

    assert len(emitted) == 1
    consent_event = emitted[0]
    assert isinstance(consent_event, CustomEvent)
    assert consent_event.name == "oauth_consent_request"
    assert consent_event.value == {"consent_link": "https://login.microsoftonline.com/consent"}
|
|
|
|
|
|
def test_emit_oauth_consent_request_no_link():
    """An oauth_consent_request built without a consent_link produces no events."""
    linkless_request = Content("oauth_consent_request")

    emitted = _emit_content(linkless_request, FlowState())

    assert len(emitted) == 0
|
|
|
|
|
|
# ============================================================================
|
|
# Tests for MCP tool call, MCP tool result, and text reasoning event emission
|
|
# ============================================================================
|
|
|
|
|
|
class TestEmitMcpToolCall:
    """Behavioural checks for the _emit_mcp_tool_call helper."""

    def test_produces_start_and_args_events(self):
        """A call carrying arguments yields ToolCallStart followed by ToolCallArgs."""
        flow = FlowState()
        mcp_call = Content.from_mcp_server_tool_call(
            call_id="mcp_call_1",
            tool_name="search",
            server_name="brave",
            arguments={"query": "weather"},
        )

        emitted = _emit_mcp_tool_call(mcp_call, flow)

        assert len(emitted) == 2
        start_event, args_event = emitted
        assert start_event.type == "TOOL_CALL_START"
        assert start_event.tool_call_id == "mcp_call_1"
        assert start_event.tool_call_name == "search"
        assert args_event.type == "TOOL_CALL_ARGS"
        assert args_event.tool_call_id == "mcp_call_1"
        assert "weather" in args_event.delta

    def test_tracks_in_flow_state(self):
        """The call is recorded in flow.pending_tool_calls and flow.tool_calls_by_id."""
        flow = FlowState()
        mcp_call = Content.from_mcp_server_tool_call(
            call_id="mcp_call_2",
            tool_name="get_file",
            arguments='{"path": "/tmp/test.txt"}',
        )

        _emit_mcp_tool_call(mcp_call, flow)

        assert len(flow.pending_tool_calls) == 1
        assert flow.pending_tool_calls[0]["id"] == "mcp_call_2"
        assert "mcp_call_2" in flow.tool_calls_by_id
        tracked_function = flow.tool_calls_by_id["mcp_call_2"]["function"]
        assert tracked_function["name"] == "get_file"
        assert tracked_function["arguments"] == '{"path": "/tmp/test.txt"}'

    def test_no_server_name_uses_tool_name_only(self):
        """With no server_name supplied, the emitted name is the bare tool_name."""
        flow = FlowState()
        mcp_call = Content.from_mcp_server_tool_call(
            call_id="mcp_call_3",
            tool_name="list_files",
        )

        emitted = _emit_mcp_tool_call(mcp_call, flow)

        assert emitted[0].tool_call_name == "list_files"

    def test_no_arguments_skips_args_event(self):
        """A call without arguments yields only ToolCallStart (no ToolCallArgs)."""
        flow = FlowState()
        mcp_call = Content.from_mcp_server_tool_call(
            call_id="mcp_call_4",
            tool_name="ping",
        )

        emitted = _emit_mcp_tool_call(mcp_call, flow)

        assert len(emitted) == 1
        assert emitted[0].type == "TOOL_CALL_START"

    def test_generates_id_when_missing(self):
        """When call_id is None a non-empty tool_call_id is generated."""
        flow = FlowState()
        mcp_call = Content(type="mcp_server_tool_call", tool_name="test_tool")

        emitted = _emit_mcp_tool_call(mcp_call, flow)

        assert len(emitted) >= 1
        first_event = emitted[0]
        assert first_event.tool_call_id is not None
        assert first_event.tool_call_id != ""
        assert first_event.tool_call_name == "test_tool"

    def test_missing_tool_name_falls_back_to_mcp_tool(self):
        """When tool_name is None the fallback name 'mcp_tool' is emitted."""
        flow = FlowState()
        mcp_call = Content(type="mcp_server_tool_call")

        emitted = _emit_mcp_tool_call(mcp_call, flow)

        assert len(emitted) >= 1
        assert emitted[0].tool_call_name == "mcp_tool"
|
|
|
|
|
|
class TestEmitMcpToolResult:
    """Behavioural checks for the _emit_mcp_tool_result helper."""

    def test_produces_end_and_result_events(self):
        """A result with a call_id yields ToolCallEnd followed by ToolCallResult."""
        flow = FlowState()
        mcp_result = Content.from_mcp_server_tool_result(
            call_id="mcp_call_1",
            output={"results": [{"title": "Weather", "url": "https://example.com"}]},
        )

        emitted = _emit_mcp_tool_result(mcp_result, flow)

        assert len(emitted) == 2
        end_event, result_event = emitted
        assert end_event.type == "TOOL_CALL_END"
        assert end_event.tool_call_id == "mcp_call_1"
        assert result_event.type == "TOOL_CALL_RESULT"
        assert result_event.tool_call_id == "mcp_call_1"
        assert "Weather" in result_event.content

    def test_tracks_in_flow_state(self):
        """The result is recorded in flow.tool_results and flow.tool_calls_ended."""
        flow = FlowState()
        mcp_result = Content.from_mcp_server_tool_result(
            call_id="mcp_call_5",
            output="Success",
        )

        _emit_mcp_tool_result(mcp_result, flow)

        assert "mcp_call_5" in flow.tool_calls_ended
        assert len(flow.tool_results) == 1
        recorded = flow.tool_results[0]
        assert recorded["toolCallId"] == "mcp_call_5"
        assert recorded["content"] == "Success"

    def test_no_call_id_returns_empty(self):
        """A result lacking call_id produces an empty event list."""
        flow = FlowState()
        mcp_result = Content(type="mcp_server_tool_result", output="data")

        emitted = _emit_mcp_tool_result(mcp_result, flow)

        assert emitted == []

    def test_serializes_non_string_output(self):
        """A dict output reaches the result event as a JSON string."""
        flow = FlowState()
        mcp_result = Content.from_mcp_server_tool_result(
            call_id="mcp_call_6",
            output={"key": "value", "count": 42},
        )

        emitted = _emit_mcp_tool_result(mcp_result, flow)

        serialized = emitted[1].content
        assert isinstance(serialized, str)
        assert '"key": "value"' in serialized

    def test_output_none_falls_back_to_empty_string(self):
        """With output left at its default (None) the result content is ''."""
        flow = FlowState()
        mcp_result = Content(type="mcp_server_tool_result", call_id="mcp_call_none")

        emitted = _emit_mcp_tool_result(mcp_result, flow)

        assert len(emitted) == 2
        result_event = emitted[1]
        assert result_event.type == "TOOL_CALL_RESULT"
        assert result_event.content == ""

    def test_resets_flow_state_like_emit_tool_result(self):
        """The MCP result clears tool/message state and closes the open text message."""
        flow = FlowState()
        # Simulate an in-flight tool call with an open, partially streamed message.
        flow.tool_call_id = "mcp_call_7"
        flow.tool_call_name = "brave/search"
        flow.message_id = "open-msg-456"
        flow.accumulated_text = "Let me search for that..."

        mcp_result = Content.from_mcp_server_tool_result(
            call_id="mcp_call_7",
            output="search results",
        )

        emitted = _emit_mcp_tool_result(mcp_result, flow)

        assert flow.tool_call_id is None
        assert flow.tool_call_name is None
        assert flow.message_id is None
        assert flow.accumulated_text == ""

        closing_events = [ev for ev in emitted if isinstance(ev, TextMessageEndEvent)]
        assert len(closing_events) == 1
        assert closing_events[0].message_id == "open-msg-456"

    def test_no_open_message_skips_text_end(self):
        """With no open text message, no TextMessageEndEvent is emitted."""
        flow = FlowState()
        flow.message_id = None

        mcp_result = Content.from_mcp_server_tool_result(
            call_id="mcp_call_8",
            output="result",
        )

        emitted = _emit_mcp_tool_result(mcp_result, flow)

        closing_events = [ev for ev in emitted if isinstance(ev, TextMessageEndEvent)]
        assert len(closing_events) == 0

    def test_predictive_handler_emits_state_snapshot(self):
        """With a predictive_handler, pending updates are applied and a StateSnapshotEvent is emitted."""
        from unittest.mock import MagicMock

        from ag_ui.core import StateSnapshotEvent

        flow = FlowState()
        flow.current_state = {"doc": "hello"}
        mcp_result = Content.from_mcp_server_tool_result(
            call_id="mcp_call_9",
            output="done",
        )

        handler = MagicMock()
        emitted = _emit_mcp_tool_result(mcp_result, flow, predictive_handler=handler)

        handler.apply_pending_updates.assert_called_once()
        snapshots = [ev for ev in emitted if isinstance(ev, StateSnapshotEvent)]
        assert len(snapshots) == 1
        assert snapshots[0].snapshot == {"doc": "hello"}
|
|
|
|
|
|
class TestEmitTextReasoning:
    """Behavioural checks for _emit_text_reasoning when called without a flow."""

    def test_produces_reasoning_events(self):
        """Reasoning text yields Start, MessageStart, Content, MessageEnd, End in order."""
        reasoning_text = "The user is asking about weather, so I should call the weather tool."
        reasoning = Content.from_text_reasoning(
            id="reason_1",
            text=reasoning_text,
        )

        emitted = _emit_text_reasoning(reasoning)

        assert len(emitted) == 5
        expected_types = (
            ReasoningStartEvent,
            ReasoningMessageStartEvent,
            ReasoningMessageContentEvent,
            ReasoningMessageEndEvent,
            ReasoningEndEvent,
        )
        for event, expected_type in zip(emitted, expected_types):
            assert isinstance(event, expected_type)
            assert event.message_id == "reason_1"
        assert emitted[1].role == "assistant"
        assert emitted[2].delta == reasoning_text

    def test_protected_data_emits_encrypted_value_event(self):
        """protected_data surfaces as exactly one ReasoningEncryptedValueEvent."""
        reasoning = Content.from_text_reasoning(
            id="reason_2",
            text="visible reasoning",
            protected_data="encrypted metadata",
        )

        emitted = _emit_text_reasoning(reasoning)

        encrypted = [ev for ev in emitted if isinstance(ev, ReasoningEncryptedValueEvent)]
        assert len(encrypted) == 1
        encrypted_event = encrypted[0]
        assert encrypted_event.subtype == "message"
        assert encrypted_event.entity_id == "reason_2"
        assert encrypted_event.encrypted_value == "encrypted metadata"

    def test_protected_data_only_emits_event(self):
        """Reasoning with protected_data but no text still emits the reasoning sequence."""
        reasoning = Content.from_text_reasoning(
            protected_data="encrypted reasoning content",
        )

        emitted = _emit_text_reasoning(reasoning)

        # Expected: start, msg_start, msg_end, encrypted_value, end (no content event).
        assert len(emitted) == 5
        expected_types = (
            ReasoningStartEvent,
            ReasoningMessageStartEvent,
            ReasoningMessageEndEvent,
            ReasoningEncryptedValueEvent,
            ReasoningEndEvent,
        )
        for event, expected_type in zip(emitted, expected_types):
            assert isinstance(event, expected_type)
        assert emitted[3].encrypted_value == "encrypted reasoning content"

    def test_empty_text_and_no_protected_data_returns_empty(self):
        """Reasoning with neither text nor protected_data emits nothing."""
        reasoning = Content.from_text_reasoning()

        emitted = _emit_text_reasoning(reasoning)

        assert emitted == []

    def test_generates_message_id_when_missing(self):
        """When no id is supplied a non-empty message_id is generated."""
        reasoning = Content.from_text_reasoning(text="thinking...")

        emitted = _emit_text_reasoning(reasoning)

        assert len(emitted) == 5
        generated_id = emitted[0].message_id
        assert generated_id is not None
        assert generated_id != ""
        # The second event must carry the same generated id as the first.
        assert emitted[1].message_id == generated_id
|
|
|
|
|
|
class TestEmitContentMcpRouting:
    """Checks that _emit_content handles MCP and reasoning content types."""

    def test_routes_mcp_server_tool_call(self):
        """mcp_server_tool_call content produces a TOOL_CALL_START event first."""
        flow = FlowState()
        mcp_call = Content.from_mcp_server_tool_call(
            call_id="route_test_1",
            tool_name="test_tool",
            server_name="test_server",
        )

        emitted = _emit_content(mcp_call, flow)

        assert len(emitted) >= 1
        first_event = emitted[0]
        assert first_event.type == "TOOL_CALL_START"
        assert first_event.tool_call_name == "test_tool"

    def test_routes_mcp_server_tool_result(self):
        """mcp_server_tool_result content produces TOOL_CALL_END then TOOL_CALL_RESULT."""
        flow = FlowState()
        mcp_result = Content.from_mcp_server_tool_result(
            call_id="route_test_2",
            output="result data",
        )

        emitted = _emit_content(mcp_result, flow)

        assert len(emitted) == 2
        assert [ev.type for ev in emitted] == ["TOOL_CALL_END", "TOOL_CALL_RESULT"]

    def test_routes_text_reasoning(self):
        """text_reasoning content produces the opening reasoning events only."""
        flow = FlowState()
        reasoning = Content.from_text_reasoning(text="I need to think about this...")

        emitted = _emit_content(reasoning, flow)

        # Streaming pattern: Start + MessageStart + Content (no End events yet).
        assert len(emitted) == 3
        expected_types = (
            ReasoningStartEvent,
            ReasoningMessageStartEvent,
            ReasoningMessageContentEvent,
        )
        for event, expected_type in zip(emitted, expected_types):
            assert isinstance(event, expected_type)
|
|
|
|
|
|
class TestReasoningInSnapshot:
|
|
"""Tests for reasoning message inclusion in MESSAGES_SNAPSHOT."""
|
|
|
|
def test_reasoning_persisted_to_flow_state(self):
    """Passing a flow makes _emit_text_reasoning record the reasoning in flow.reasoning_messages."""
    flow = FlowState()
    reasoning = Content.from_text_reasoning(
        id="reason_persist",
        text="Let me think step by step.",
    )

    _emit_text_reasoning(reasoning, flow)

    assert len(flow.reasoning_messages) == 1
    stored = flow.reasoning_messages[0]
    assert stored["id"] == "reason_persist"
    assert stored["role"] == "reasoning"
    assert stored["content"] == "Let me think step by step."
    # No protected_data was supplied, so no encrypted value should be stored.
    assert "encryptedValue" not in stored
|
|
|
|
def test_reasoning_with_encrypted_value_persisted(self):
    """protected_data on the reasoning content is stored as encryptedValue in flow state."""
    flow = FlowState()
    reasoning = Content.from_text_reasoning(
        id="reason_enc",
        text="visible reasoning",
        protected_data="encrypted-data-123",
    )

    _emit_text_reasoning(reasoning, flow)

    assert len(flow.reasoning_messages) == 1
    stored = flow.reasoning_messages[0]
    assert stored["encryptedValue"] == "encrypted-data-123"
|
|
|
|
def test_snapshot_includes_reasoning(self):
    """Reasoning messages held in flow state show up in the messages snapshot."""
    from agent_framework_ag_ui._agent_run import _build_messages_snapshot

    flow = FlowState()
    flow.accumulated_text = "Here is my answer."
    flow.reasoning_messages = [
        {"id": "r1", "role": "reasoning", "content": "Thinking..."},
    ]

    snapshot = _build_messages_snapshot(flow, [])

    # Snapshot entries may be dicts or objects; read the role from either shape.
    observed_roles = []
    for message in snapshot.messages:
        if isinstance(message, dict):
            observed_roles.append(message.get("role"))
        else:
            observed_roles.append(getattr(message, "role", None))
    assert "reasoning" in observed_roles
|
|
|
|
def test_snapshot_preserves_reasoning_encrypted_value(self):
    """A reasoning message with an encrypted value survives into the snapshot."""
    from agent_framework_ag_ui._agent_run import _build_messages_snapshot

    def _role_of(message):
        # Snapshot entries may be dicts or objects; read the role from either shape.
        if isinstance(message, dict):
            return message.get("role")
        return getattr(message, "role", None)

    flow = FlowState()
    reasoning = Content.from_text_reasoning(
        id="reason_e2e",
        text="visible",
        protected_data="secret-data",
    )
    _emit_text_reasoning(reasoning, flow)

    answer = Content.from_text("Final answer.")
    _emit_text(answer, flow)

    snapshot = _build_messages_snapshot(flow, [])

    reasoning_entries = [message for message in snapshot.messages if _role_of(message) == "reasoning"]
    assert len(reasoning_entries) == 1
    entry = reasoning_entries[0]
    # NOTE(review): the content checks only run for dict entries; if the snapshot
    # ever returns non-dict message objects these assertions are skipped.
    if isinstance(entry, dict):
        assert entry["content"] == "visible"
        assert entry["encryptedValue"] == "secret-data"
|
|
|
|
def test_emit_content_routes_reasoning_with_flow(self):
    """Reasoning emitted through _emit_content is persisted on the flow."""
    flow = FlowState()
    reasoning = Content.from_text_reasoning(text="routed reasoning")

    _emit_content(reasoning, flow)

    assert len(flow.reasoning_messages) == 1
    stored = flow.reasoning_messages[0]
    assert stored["content"] == "routed reasoning"
|
|
|
|
def test_reasoning_without_flow_does_not_error(self):
    """_emit_text_reasoning can be called with no flow argument (backward compat)."""
    reasoning = Content.from_text_reasoning(text="no flow")

    emitted = _emit_text_reasoning(reasoning)

    assert len(emitted) == 5
    opening_event = emitted[0]
    assert isinstance(opening_event, ReasoningStartEvent)
|
|
|
|
def test_snapshot_reasoning_ordering(self):
    """In the snapshot, reasoning messages come after the assistant text."""
    from agent_framework_ag_ui._agent_run import _build_messages_snapshot

    flow = FlowState()
    _emit_text_reasoning(Content.from_text_reasoning(id="r1", text="Thinking..."), flow)
    _emit_text(Content.from_text("Answer"), flow)

    prior_messages = [{"id": "u1", "role": "user", "content": "Hi"}]
    snapshot = _build_messages_snapshot(flow, prior_messages)

    # Expected order: user -> assistant text -> reasoning.
    assert len(snapshot.messages) == 3
    observed_roles = []
    for message in snapshot.messages:
        if isinstance(message, dict):
            observed_roles.append(message.get("role"))
        else:
            observed_roles.append(getattr(message, "role", None))
    assert observed_roles == ["user", "assistant", "reasoning"]
|
|
|
|
def test_reasoning_accumulates_incremental_deltas(self):
    """Successive reasoning deltas sharing one id are merged into a single entry."""
    flow = FlowState()
    deltas = [
        Content.from_text_reasoning(id="reason_inc", text="First "),
        Content.from_text_reasoning(id="reason_inc", text="second "),
        Content.from_text_reasoning(id="reason_inc", text="third."),
    ]

    for delta in deltas:
        _emit_text_reasoning(delta, flow)

    assert len(flow.reasoning_messages) == 1
    stored = flow.reasoning_messages[0]
    assert stored["id"] == "reason_inc"
    assert stored["content"] == "First second third."
|
|
|
|
def test_reasoning_accumulates_distinct_message_ids(self):
    """Reasoning contents with different ids produce separate stored entries."""
    flow = FlowState()
    first = Content.from_text_reasoning(id="a", text="alpha")
    second = Content.from_text_reasoning(id="b", text="beta")

    _emit_text_reasoning(first, flow)
    _emit_text_reasoning(second, flow)

    assert len(flow.reasoning_messages) == 2
    stored_contents = [flow.reasoning_messages[0]["content"], flow.reasoning_messages[1]["content"]]
    assert stored_contents == ["alpha", "beta"]
|
|
|
|
def test_reasoning_encrypted_value_updated_on_later_delta(self):
    """An encryptedValue arriving on a later delta is still stored on the entry."""
    flow = FlowState()
    plain_delta = Content.from_text_reasoning(id="enc_late", text="part1 ")
    encrypted_delta = Content.from_text_reasoning(id="enc_late", text="part2", protected_data="encrypted-payload")

    _emit_text_reasoning(plain_delta, flow)
    _emit_text_reasoning(encrypted_delta, flow)

    assert len(flow.reasoning_messages) == 1
    stored = flow.reasoning_messages[0]
    assert stored["content"] == "part1 part2"
    assert stored["encryptedValue"] == "encrypted-payload"
|
|
|
|
def test_reasoning_done_after_deltas_does_not_duplicate(self):
    """Accumulated reasoning text equals the concatenation of the streamed deltas.

    Two deltas with the same id are emitted and the stored entry is checked
    to contain exactly their concatenation, once.

    NOTE(review): despite the test name, no done-style content (carrying the
    full text) is emitted here; only the delta path is exercised.
    """
    flow = FlowState()
    msg_id = "reason_dedup"

    for chunk in ("Hello ", "world"):
        _emit_text_reasoning(Content.from_text_reasoning(id=msg_id, text=chunk), flow)

    # Only the delta-produced text should be stored, in a single entry.
    assert len(flow.reasoning_messages) == 1
    stored = flow.reasoning_messages[0]
    assert stored["content"] == "Hello world"
    assert stored["id"] == msg_id
|
|
|
|
def test_reasoning_deltas_emit_one_content_event_each(self):
    """Every delta yields one ReasoningMessageContentEvent inside one Start/End pair."""

    def _of_type(events, event_type):
        return [ev for ev in events if isinstance(ev, event_type)]

    flow = FlowState()
    msg_id = "reason_evt"

    first_delta = Content.from_text_reasoning(id=msg_id, text="Think ")
    second_delta = Content.from_text_reasoning(id=msg_id, text="hard")

    emitted = []
    emitted += _emit_text_reasoning(first_delta, flow)
    emitted += _emit_text_reasoning(second_delta, flow)
    emitted += _close_reasoning_block(flow)

    content_events = _of_type(emitted, ReasoningMessageContentEvent)
    assert len(content_events) == 2
    assert content_events[0].delta == "Think "
    assert content_events[1].delta == "hard"

    # Streaming pattern: a single Start/End sequence wraps both content events.
    assert len(_of_type(emitted, ReasoningStartEvent)) == 1
    assert len(_of_type(emitted, ReasoningEndEvent)) == 1
    assert len(_of_type(emitted, ReasoningMessageStartEvent)) == 1
    assert len(_of_type(emitted, ReasoningMessageEndEvent)) == 1
|
|
|
|
def test_reasoning_streaming_event_order(self):
    """Start events come once, then one Content per delta, then End events on close."""
    flow = FlowState()
    msg_id = "reason_order"

    chunks = ("A ", "B ", "C")
    deltas = [Content.from_text_reasoning(id=msg_id, text=chunk) for chunk in chunks]

    emitted = []
    for delta in deltas:
        emitted.extend(_emit_text_reasoning(delta, flow))
    emitted.extend(_close_reasoning_block(flow))

    assert isinstance(emitted[0], ReasoningStartEvent)
    assert isinstance(emitted[1], ReasoningMessageStartEvent)
    # Events 2..4 are the per-delta content events, in emission order.
    for offset, chunk in enumerate(chunks):
        content_event = emitted[2 + offset]
        assert isinstance(content_event, ReasoningMessageContentEvent)
        assert content_event.delta == chunk
    assert isinstance(emitted[5], ReasoningMessageEndEvent)
    assert isinstance(emitted[6], ReasoningEndEvent)
    assert len(emitted) == 7
|
|
|
|
def test_close_reasoning_block_noop_when_not_open(self):
    """With no reasoning block open, _close_reasoning_block returns an empty list."""
    fresh_flow = FlowState()
    closing_events = _close_reasoning_block(fresh_flow)
    assert closing_events == []
|
|
|
|
def test_close_reasoning_block_resets_state(self):
    """Closing the reasoning block sets flow.reasoning_message_id back to None."""
    flow = FlowState()
    opening_delta = Content.from_text_reasoning(id="r1", text="x")
    _emit_text_reasoning(opening_delta, flow)
    # Emitting reasoning records the id of the open block on the flow.
    assert flow.reasoning_message_id == "r1"

    _close_reasoning_block(flow)
    assert flow.reasoning_message_id is None
|
|
|
|
def test_emit_content_closes_reasoning_on_text(self):
    """Emitting text after reasoning first closes the open reasoning block."""
    flow = FlowState()

    reasoning_events = _emit_content(Content.from_text_reasoning(id="r1", text="thinking"), flow)
    text_events = _emit_content(Content.from_text("answer"), flow)

    # The reasoning emission opens with the reasoning start event.
    assert isinstance(reasoning_events[0], ReasoningStartEvent)

    # The text emission leads with the reasoning End events, then the text
    # start and content events.
    expected_text_sequence = (
        ReasoningMessageEndEvent,
        ReasoningEndEvent,
        TextMessageStartEvent,
        TextMessageContentEvent,
    )
    for position, expected_type in enumerate(expected_text_sequence):
        assert isinstance(text_events[position], expected_type)
|
|
|
|
def test_reasoning_distinct_ids_close_previous_block(self):
|
|
"""Emitting reasoning with a new message_id auto-closes the previous block."""
|
|
flow = FlowState()
|
|
c1 = Content.from_text_reasoning(id="block1", text="first")
|
|
c2 = Content.from_text_reasoning(id="block2", text="second")
|
|
|
|
events1 = _emit_text_reasoning(c1, flow)
|
|
events2 = _emit_text_reasoning(c2, flow)
|
|
close = _close_reasoning_block(flow)
|
|
|
|
# events1: Start(block1) + MsgStart(block1) + Content(block1)
|
|
assert events1[0].message_id == "block1"
|
|
# events2: MsgEnd(block1) + End(block1) + Start(block2) + MsgStart(block2) + Content(block2)
|
|
assert isinstance(events2[0], ReasoningMessageEndEvent)
|
|
assert events2[0].message_id == "block1"
|
|
assert isinstance(events2[1], ReasoningEndEvent)
|
|
assert events2[1].message_id == "block1"
|
|
assert isinstance(events2[2], ReasoningStartEvent)
|
|
assert events2[2].message_id == "block2"
|
|
# close: MsgEnd(block2) + End(block2)
|
|
assert isinstance(close[0], ReasoningMessageEndEvent)
|
|
assert close[0].message_id == "block2"
|