Files
Evan Mattson 94af83680e Python: Fix RUN_FINISHED.interrupt to accumulate all interrupts when multiple tools need approval (#4717)
* Fix flow.interrupts overwrite when multiple tools need approval (#4590)

Change flow.interrupts assignment to append so that all interrupt entries
accumulate when multiple tools require approval in a single turn.

Both _run_common.py and _agent_run.py used assignment (=) which caused
each new interrupt to overwrite the previous one. Switching to append()
ensures RUN_FINISHED.interrupt contains all pending approvals.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Add test for streaming path with multiple confirm_changes interrupts (#4590)

Add integration test exercising run_agent_stream with multiple predictive
tool calls requiring confirmation. Verifies that flow.interrupts.append()
correctly accumulates all interrupt entries and they appear in the
RUN_FINISHED event.

Also confirms FlowState already declares interrupts field with
default_factory=list, addressing the AttributeError concern from review.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Apply pre-commit auto-fixes

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-17 12:44:44 +00:00

994 lines
36 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for _agent_run.py helper functions and FlowState."""
import pytest
from ag_ui.core import (
CustomEvent,
TextMessageEndEvent,
TextMessageStartEvent,
ToolCallArgsEvent,
)
from agent_framework import AgentResponseUpdate, Content, Message, ResponseStream
from agent_framework.exceptions import AgentInvalidResponseException
from agent_framework_ag_ui._agent_run import (
_build_safe_metadata,
_create_state_context_message,
_inject_state_context,
_normalize_response_stream,
_resume_to_tool_messages,
_should_suppress_intermediate_snapshot,
)
from agent_framework_ag_ui._run_common import (
FlowState,
_build_run_finished_event,
_emit_approval_request,
_emit_content,
_emit_text,
_emit_tool_call,
_emit_tool_result,
_extract_resume_payload,
_has_only_tool_calls,
)
class TestBuildSafeMetadata:
    """Exercise _build_safe_metadata with empty, short, long and non-string values."""

    def test_none_metadata(self):
        """None input produces an empty dict."""
        assert _build_safe_metadata(None) == {}

    def test_empty_metadata(self):
        """An empty dict input produces an empty dict."""
        assert _build_safe_metadata({}) == {}

    def test_short_string_values(self):
        """Short string values come back unchanged."""
        source = {"key1": "short", "key2": "value"}
        assert _build_safe_metadata(source) == source

    def test_truncates_long_strings(self):
        """A 1000-character string value is cut down to 512 characters."""
        safe = _build_safe_metadata({"key": "x" * 1000})
        assert len(safe["key"]) == 512

    def test_serializes_non_strings(self):
        """Non-string values are returned in their JSON text form."""
        safe = _build_safe_metadata({"count": 42, "items": [1, 2, 3]})
        assert safe["count"] == "42"
        assert safe["items"] == "[1, 2, 3]"

    def test_truncates_serialized_values(self):
        """A serialized value longer than 512 characters is cut down to 512."""
        safe = _build_safe_metadata({"data": list(range(200))})
        assert len(safe["data"]) == 512
class TestHasOnlyToolCalls:
    """Exercise _has_only_tool_calls with different mixes of text and function_call content."""

    def test_only_tool_calls(self):
        """A single function_call content is reported as tool-calls-only."""
        call = Content.from_function_call(call_id="call_1", name="tool1", arguments="{}")
        assert _has_only_tool_calls([call]) is True

    def test_tool_call_with_text(self):
        """Text alongside a function_call is not tool-calls-only."""
        text = Content.from_text("Some text")
        call = Content.from_function_call(call_id="call_1", name="tool1", arguments="{}")
        assert _has_only_tool_calls([text, call]) is False

    def test_only_text(self):
        """Text content alone is not tool-calls-only."""
        assert _has_only_tool_calls([Content.from_text("Just text")]) is False

    def test_empty_contents(self):
        """An empty content list is not tool-calls-only."""
        assert _has_only_tool_calls([]) is False

    def test_tool_call_with_empty_text(self):
        """Text content whose text is empty does not prevent a True result."""
        blank = Content.from_text("")
        call = Content.from_function_call(call_id="call_1", name="tool1", arguments="{}")
        assert _has_only_tool_calls([blank, call]) is True
class TestShouldSuppressIntermediateSnapshot:
    """Exercise _should_suppress_intermediate_snapshot for each early-out and the positive case."""

    def test_no_tool_name(self):
        """A missing tool name yields False."""
        config = {"key": {"tool": "write_doc", "tool_argument": "content"}}
        assert _should_suppress_intermediate_snapshot(None, config, False) is False

    def test_no_config(self):
        """A missing config yields False."""
        assert _should_suppress_intermediate_snapshot("write_doc", None, False) is False

    def test_confirmation_required(self):
        """A matching tool still yields False when confirmation is required."""
        predict_config = {"key": {"tool": "write_doc", "tool_argument": "content"}}
        assert _should_suppress_intermediate_snapshot("write_doc", predict_config, True) is False

    def test_tool_not_in_config(self):
        """A tool that the config does not mention yields False."""
        predict_config = {"key": {"tool": "other_tool", "tool_argument": "content"}}
        assert _should_suppress_intermediate_snapshot("write_doc", predict_config, False) is False

    def test_suppresses_predictive_tool(self):
        """A configured predictive tool without confirmation yields True."""
        predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
        assert _should_suppress_intermediate_snapshot("write_doc", predict_config, False) is True
class TestFlowState:
    """Exercise FlowState defaults and its lookup helpers."""

    def test_default_values(self):
        """A freshly constructed FlowState carries the expected defaults."""
        state = FlowState()
        assert state.message_id is None
        assert state.tool_call_id is None
        assert state.tool_call_name is None
        assert state.waiting_for_approval is False
        # Containers and text start out empty.
        assert state.current_state == {}
        assert state.accumulated_text == ""
        assert state.pending_tool_calls == []
        assert state.tool_calls_by_id == {}
        assert state.tool_results == []
        assert state.tool_calls_ended == set()
        assert state.interrupts == []

    def test_get_tool_name(self):
        """get_tool_name resolves known ids and returns None for unknown or None ids."""
        state = FlowState()
        state.tool_calls_by_id = {
            "call_123": {"function": {"name": "get_weather", "arguments": "{}"}},
        }
        assert state.get_tool_name("call_123") == "get_weather"
        assert state.get_tool_name("nonexistent") is None
        assert state.get_tool_name(None) is None

    def test_get_tool_name_empty_name(self):
        """get_tool_name returns None when the stored name is an empty string."""
        state = FlowState()
        state.tool_calls_by_id = {
            "call_123": {"function": {"name": "", "arguments": "{}"}},
        }
        assert state.get_tool_name("call_123") is None

    def test_get_pending_without_end(self):
        """get_pending_without_end returns only pending calls whose id is not in tool_calls_ended."""
        state = FlowState()
        state.pending_tool_calls = [{"id": f"call_{n}", "function": {"name": f"tool{n}"}} for n in (1, 2, 3)]
        state.tool_calls_ended = {"call_1", "call_3"}
        still_open = state.get_pending_without_end()
        assert len(still_open) == 1
        assert still_open[0]["id"] == "call_2"
class TestNormalizeResponseStream:
    """Exercise _normalize_response_stream with each accepted input shape and one rejected one."""

    async def test_accepts_response_stream(self):
        """A ResponseStream wrapping an async generator is accepted."""

        async def _one_update():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        normalized = await _normalize_response_stream(ResponseStream(_one_update()))
        received = []
        async for item in normalized:
            received.append(item)
        assert len(received) == 1
        assert received[0].contents[0].text == "hello"

    async def test_accepts_async_iterable(self):
        """A bare async generator (workflow-style stream) is accepted."""

        async def _one_update():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        normalized = await _normalize_response_stream(_one_update())
        received = []
        async for item in normalized:
            received.append(item)
        assert len(received) == 1
        assert received[0].contents[0].text == "hello"

    async def test_accepts_awaitable_resolving_to_async_iterable(self):
        """An awaitable that resolves to an async iterable is accepted."""

        async def _one_update():
            yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")

        async def _deferred():
            return _one_update()

        normalized = await _normalize_response_stream(_deferred())
        received = []
        async for item in normalized:
            received.append(item)
        assert len(received) == 1
        assert received[0].contents[0].text == "hello"

    async def test_rejects_non_stream_values(self):
        """A plain string raises AgentInvalidResponseException."""
        with pytest.raises(AgentInvalidResponseException):
            await _normalize_response_stream("not-a-stream")
class TestCreateStateContextMessage:
    """Exercise _create_state_context_message with missing and present state/schema."""

    def test_no_state(self):
        """An empty state yields None."""
        assert _create_state_context_message({}, {"properties": {}}) is None

    def test_no_schema(self):
        """An empty schema yields None."""
        assert _create_state_context_message({"key": "value"}, {}) is None

    def test_creates_message(self):
        """State plus schema yields a single-content system message containing the state."""
        message = _create_state_context_message(
            {"document": "Hello world"},
            {"properties": {"document": {"type": "string"}}},
        )
        assert message is not None
        assert message.role == "system"
        assert len(message.contents) == 1
        body = message.contents[0].text
        assert "Hello world" in body
        assert "Current state" in body
class TestInjectStateContext:
    """Exercise _inject_state_context pass-through cases and the injection case."""

    def test_no_state_message(self):
        """With empty state and schema the messages come back unchanged."""
        history = [Message(role="user", contents=[Content.from_text("Hello")])]
        assert _inject_state_context(history, {}, {}) == history

    def test_empty_messages(self):
        """An empty message list yields an empty list."""
        assert _inject_state_context([], {"key": "value"}, {"properties": {}}) == []

    def test_last_message_not_user(self):
        """Nothing is injected when the last message is not from the user."""
        history = [
            Message(role="user", contents=[Content.from_text("Hello")]),
            Message(role="assistant", contents=[Content.from_text("Hi")]),
        ]
        injected = _inject_state_context(
            history,
            {"key": "value"},
            {"properties": {"key": {"type": "string"}}},
        )
        assert injected == history

    def test_injects_before_last_user_message(self):
        """The state context message is placed directly before the final user message."""
        history = [
            Message(role="system", contents=[Content.from_text("You are helpful")]),
            Message(role="user", contents=[Content.from_text("Hello")]),
        ]
        injected = _inject_state_context(
            history,
            {"document": "content"},
            {"properties": {"document": {"type": "string"}}},
        )
        assert len(injected) == 3
        original_system, state_context, user_turn = injected[0], injected[1], injected[2]
        # Original system message stays first.
        assert original_system.role == "system"
        assert "helpful" in original_system.contents[0].text
        # Injected state context sits in the middle.
        assert state_context.role == "system"
        assert "Current state" in state_context.contents[0].text
        # User message remains last.
        assert user_turn.role == "user"
        assert "Hello" in user_turn.contents[0].text
# Additional tests for _agent_run.py and _run_common.py helper functions
def test_emit_text_basic():
    """_emit_text on a fresh flow emits two events and records id and text on the flow."""
    state = FlowState()
    emitted = _emit_text(Content.from_text("Hello world"), state)
    # Expected pair: TextMessageStartEvent + TextMessageContentEvent.
    assert len(emitted) == 2
    assert state.message_id is not None
    assert state.accumulated_text == "Hello world"
def test_emit_text_skip_empty():
    """_emit_text emits nothing for empty text."""
    state = FlowState()
    emitted = _emit_text(Content.from_text(""), state)
    assert len(emitted) == 0
def test_emit_text_continues_existing_message():
    """_emit_text reuses an already-open message instead of starting a new one."""
    state = FlowState(message_id="existing-id")
    emitted = _emit_text(Content.from_text("more text"), state)
    # Only a TextMessageContentEvent is expected; no new start event.
    assert len(emitted) == 1
    assert state.message_id == "existing-id"
def test_emit_text_skips_duplicate_full_message_delta():
    """_emit_text ignores a chunk that replays the full text of an open message."""
    replayed = "Case complete."
    state = FlowState(message_id="existing-id")
    state.accumulated_text = replayed
    emitted = _emit_text(Content.from_text(replayed), state)
    assert emitted == []
    assert state.accumulated_text == "Case complete."
def test_emit_text_skips_when_waiting_for_approval():
    """_emit_text emits nothing while the flow is waiting for approval."""
    state = FlowState()
    state.waiting_for_approval = True
    emitted = _emit_text(Content.from_text("should skip"), state)
    assert len(emitted) == 0
def test_emit_text_skips_when_skip_text_flag():
    """_emit_text emits nothing when called with skip_text=True."""
    state = FlowState()
    emitted = _emit_text(Content.from_text("should skip"), state, skip_text=True)
    assert len(emitted) == 0
def test_emit_tool_call_basic():
    """_emit_tool_call emits events and records the call id and name on the flow."""
    state = FlowState()
    call = Content.from_function_call(call_id="call_123", name="get_weather", arguments='{"city": "NYC"}')
    emitted = _emit_tool_call(call, state)
    # At minimum a ToolCallStartEvent is expected.
    assert len(emitted) >= 1
    assert state.tool_call_id == "call_123"
    assert state.tool_call_name == "get_weather"
def test_emit_tool_call_generates_id():
    """_emit_tool_call fills in a tool call id when the content carries none."""
    state = FlowState()
    # Built directly so that no call_id is supplied.
    anonymous_call = Content(type="function_call", name="test_tool", arguments="{}")
    emitted = _emit_tool_call(anonymous_call, state)
    assert len(emitted) >= 1
    # An id must have been generated.
    assert state.tool_call_id is not None
def test_emit_tool_call_skips_duplicate_full_arguments_replay():
    """_emit_tool_call ignores a replay of the full arguments for an existing tool call.

    Regression test for issue #4194: some streaming providers send the full
    arguments string again after streaming deltas, which doubled the arguments
    in MESSAGES_SNAPSHOT events.
    Counterpart of test_emit_text_skips_duplicate_full_message_delta.
    """
    state = FlowState()
    full_args = '{"city": "Seattle"}'

    # First delivery: name plus arguments, i.e. a normal tool call start.
    first_events = _emit_tool_call(
        Content.from_function_call(call_id="call_dup", name="get_weather", arguments=full_args),
        state,
    )
    # The start must include a ToolCallArgsEvent and store the arguments.
    assert any(isinstance(evt, ToolCallArgsEvent) for evt in first_events)
    assert state.tool_calls_by_id["call_dup"]["function"]["arguments"] == full_args

    # Second delivery: the provider replays the same full arguments.
    replay_events = _emit_tool_call(
        Content(type="function_call", call_id="call_dup", arguments=full_args),
        state,
    )
    # The replay must not produce another ToolCallArgsEvent.
    replayed_args = [evt for evt in replay_events if isinstance(evt, ToolCallArgsEvent)]
    assert replayed_args == [], "Duplicate full-arguments replay should not emit ToolCallArgsEvent"
    # Stored arguments must not have been doubled.
    assert state.tool_calls_by_id["call_dup"]["function"]["arguments"] == full_args
def test_emit_tool_result_closes_open_message():
    """_emit_tool_result emits a TextMessageEndEvent for a text message that is still open.

    Regression test: TEXT_MESSAGE_END was not emitted when using MCP tools
    because the message_id was reset without closing the message first.
    """
    # An open text message, e.g. left by Feature #4 tool-only detection.
    state = FlowState(message_id="open-msg-123")
    state.tool_call_id = "call_456"

    emitted = _emit_tool_result(
        Content.from_function_result(call_id="call_456", result="tool result"),
        state,
        predictive_handler=None,
    )

    # Expected: ToolCallEndEvent, ToolCallResultEvent, TextMessageEndEvent.
    assert len(emitted) == 3
    # Exactly one end event, and it closes the open message.
    closers = [evt for evt in emitted if isinstance(evt, TextMessageEndEvent)]
    assert len(closers) == 1
    assert closers[0].message_id == "open-msg-123"
    # The flow no longer tracks an open message.
    assert state.message_id is None
def test_emit_tool_result_no_open_message():
    """_emit_tool_result emits no TextMessageEndEvent when no text message is open."""
    state = FlowState()
    state.message_id = None  # explicitly no open message
    state.tool_call_id = "call_456"

    emitted = _emit_tool_result(
        Content.from_function_result(call_id="call_456", result="tool result"),
        state,
        predictive_handler=None,
    )

    # Tool call end/result events only; nothing should close a text message.
    closers = [evt for evt in emitted if isinstance(evt, TextMessageEndEvent)]
    assert len(closers) == 0
def test_emit_tool_result_serializes_non_string_result():
    """A non-string tool result is turned into a string before TOOL_CALL_RESULT is emitted."""
    state = FlowState()
    structured = Content.from_function_result(call_id="call_789", result={"ok": True, "items": [1, 2]})
    emitted = _emit_tool_result(structured, state, predictive_handler=None)

    tool_result = next(evt for evt in emitted if getattr(evt, "type", None) == "TOOL_CALL_RESULT")
    payload = tool_result.content
    assert isinstance(payload, str)
    assert '"ok": true' in payload
    # The flow stores the same serialized text that was emitted.
    assert state.tool_results[0]["content"] == tool_result.content
def test_emit_content_usage_emits_custom_usage_event():
    """Usage content is emitted as a single CUSTOM event named "usage"."""
    state = FlowState()
    usage = Content.from_usage({"input_token_count": 3, "output_token_count": 2, "total_token_count": 5})
    emitted = _emit_content(usage, state)
    assert len(emitted) == 1
    usage_event = emitted[0]
    assert usage_event.type == "CUSTOM"
    assert usage_event.name == "usage"
    assert usage_event.value["total_token_count"] == 5
def test_emit_approval_request_populates_interrupt_metadata():
    """An approval request records an interrupt on the flow for RUN_FINISHED metadata."""
    state = FlowState(message_id="msg-1")
    pending_call = Content.from_function_call(call_id="call_123", name="write_doc", arguments={"content": "x"})
    request = Content.from_function_approval_request(id="approval_1", function_call=pending_call)

    _emit_approval_request(request, state)

    assert state.waiting_for_approval is True
    assert len(state.interrupts) == 1
    recorded = state.interrupts[0]
    assert recorded["id"] == "call_123"
    assert recorded["value"]["type"] == "function_approval_request"
def test_emit_approval_request_accumulates_multiple_interrupts():
    """Several approval requests in one turn all end up in flow.interrupts."""
    state = FlowState(message_id="msg-1")

    for n in (1, 2, 3):
        pending_call = Content.from_function_call(
            call_id=f"call_{n}",
            name=f"tool_{n}",
            arguments={"arg": f"value_{n}"},
        )
        request = Content.from_function_approval_request(id=f"approval_{n}", function_call=pending_call)
        _emit_approval_request(request, state)

    assert len(state.interrupts) == 3
    recorded_ids = set()
    for entry in state.interrupts:
        recorded_ids.add(entry["id"])
    assert recorded_ids == {"call_1", "call_2", "call_3"}
def test_resume_to_tool_messages_from_interrupts_payload():
    """Each interrupt response in a resume payload becomes one tool message."""
    structured_reply = {"id": "req_1", "value": {"accepted": True, "steps": []}}
    plain_reply = {"id": "req_2", "value": "plain value"}

    tool_messages = _resume_to_tool_messages({"interrupts": [structured_reply, plain_reply]})

    assert len(tool_messages) == 2
    first, second = tool_messages[0], tool_messages[1]
    assert first["role"] == "tool"
    assert first["toolCallId"] == "req_1"
    assert '"accepted": true' in first["content"]
    assert second["content"] == "plain value"
def test_extract_resume_payload_prefers_top_level_resume():
    """A top-level resume entry wins over a resume nested in forwarded props."""
    request = {
        "resume": {"interrupts": [{"id": "req_1", "value": "approved"}]},
        "forwarded_props": {"command": {"resume": "ignored"}},
    }
    assert _extract_resume_payload(request) == {"interrupts": [{"id": "req_1", "value": "approved"}]}
def test_extract_resume_payload_reads_forwarded_command_resume():
    """forwarded_props.command.resume is picked up as the resume payload."""
    forwarded_resume = '{"airline":"KLM","departure":"Amsterdam (AMS)","arrival":"San Francisco (SFO)"}'
    request = {"forwarded_props": {"command": {"resume": forwarded_resume}}}

    extracted = _extract_resume_payload(request)

    assert isinstance(extracted, str)
    assert "KLM" in extracted
def test_build_run_finished_event_with_interrupt():
    """The RUN_FINISHED helper carries run id, thread id and interrupts into the event dump."""
    finished = _build_run_finished_event(
        "run-1",
        "thread-1",
        interrupts=[{"id": "req_1", "value": {"x": 1}}],
    )
    fields = finished.model_dump()
    assert fields["run_id"] == "run-1"
    assert fields["thread_id"] == "thread-1"
    assert fields["interrupt"] == [{"id": "req_1", "value": {"x": 1}}]
def test_extract_approved_state_updates_no_handler():
    """_extract_approved_state_updates returns an empty dict when no handler is given."""
    from agent_framework_ag_ui._agent_run import _extract_approved_state_updates

    history = [Message(role="user", contents=[Content.from_text("Hello")])]
    assert _extract_approved_state_updates(history, None) == {}
def test_extract_approved_state_updates_no_approval():
    """_extract_approved_state_updates returns an empty dict when no message carries approval content."""
    from agent_framework_ag_ui._agent_run import _extract_approved_state_updates
    from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler

    predictive = PredictiveStateHandler(
        predict_state_config={"doc": {"tool": "write", "tool_argument": "content"}},
    )
    history = [Message(role="user", contents=[Content.from_text("Hello")])]
    assert _extract_approved_state_updates(history, predictive) == {}
class TestBuildMessagesSnapshot:
    """Exercise _build_messages_snapshot for tool-call, text, and pre-existing message mixes."""

    def test_tool_calls_and_text_are_separate_messages(self):
        """Tool calls and text content come out as separate assistant messages.

        Regression test for issue #3619, where tool calls and content were
        incorrectly merged into a single assistant message.
        """
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        state = FlowState()
        state.message_id = "msg-123"
        state.pending_tool_calls = [
            {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "NYC"}'}},
        ]
        state.accumulated_text = "Here is the weather information."
        state.tool_results = [
            {"id": "result-1", "role": "tool", "content": '{"temp": 72}', "toolCallId": "call_1"},
        ]

        snapshot = _build_messages_snapshot(state, [])

        # Expected layout: tool call message, tool result, text message.
        assert len(snapshot.messages) == 3

        # 1) Assistant message holding only the tool calls.
        tool_call_message = snapshot.messages[0]
        assert tool_call_message.role == "assistant"
        assert tool_call_message.tool_calls is not None
        assert len(tool_call_message.tool_calls) == 1
        assert tool_call_message.content is None

        # 2) The tool result.
        assert snapshot.messages[1].role == "tool"

        # 3) Assistant message holding only the text.
        text_message = snapshot.messages[2]
        assert text_message.role == "assistant"
        assert text_message.content == "Here is the weather information."
        assert text_message.tool_calls is None

        # The two assistant messages must not share an id.
        assert text_message.id != tool_call_message.id

    def test_only_tool_calls_no_text(self):
        """With tool calls and no accumulated text the snapshot holds one tool-call message."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        state = FlowState()
        state.message_id = "msg-123"
        state.pending_tool_calls = [
            {"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
        ]
        state.accumulated_text = ""
        state.tool_results = []

        snapshot = _build_messages_snapshot(state, [])

        # Only the tool call message is expected.
        assert len(snapshot.messages) == 1
        only_message = snapshot.messages[0]
        assert only_message.role == "assistant"
        assert only_message.tool_calls is not None
        assert only_message.content is None

    def test_only_text_no_tool_calls(self):
        """With text and no tool calls the snapshot holds one text message reusing message_id."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        state = FlowState()
        state.message_id = "msg-123"
        state.pending_tool_calls = []
        state.accumulated_text = "Hello world"
        state.tool_results = []

        snapshot = _build_messages_snapshot(state, [])

        # Only the text message is expected.
        assert len(snapshot.messages) == 1
        only_message = snapshot.messages[0]
        assert only_message.role == "assistant"
        assert only_message.content == "Hello world"
        assert only_message.tool_calls is None
        # The flow's existing message id is reused.
        assert only_message.id == "msg-123"

    def test_preserves_snapshot_messages(self):
        """Messages passed in as the existing snapshot are kept, in order."""
        from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot

        state = FlowState()
        state.pending_tool_calls = []
        state.accumulated_text = ""
        prior = [
            {"id": "user-1", "role": "user", "content": "Hello"},
            {"id": "assist-1", "role": "assistant", "content": "Hi there"},
        ]

        snapshot = _build_messages_snapshot(state, prior)

        assert len(snapshot.messages) == 2
        assert snapshot.messages[0].id == "user-1"
        assert snapshot.messages[1].id == "assist-1"
def test_malformed_json_in_confirm_args_skips_confirmation():
    """Malformed JSON in tool arguments should cause the confirm_changes flow to be skipped.

    Regression test for the rule that malformed tool arguments skip the
    confirmation flow entirely instead of crashing or showing incomplete data
    to the user. The parsing step is simulated locally; no production helper
    is called.
    """
    import json

    def _parse_arguments(call):
        """Return (should_skip, parsed_arguments) for a simulated tool call dict."""
        raw = call.get("function", {}).get("arguments", "{}")
        try:
            return False, json.loads(raw)
        except json.JSONDecodeError:
            # A parse failure is the signal to skip confirmation.
            return True, None

    # Malformed arguments must trigger the skip.
    should_skip_confirmation, _ = _parse_arguments(
        {"function": {"name": "write_doc", "arguments": "{ invalid json }"}},
    )
    assert should_skip_confirmation is True

    # Well-formed arguments must proceed and parse to the expected dict.
    should_skip_confirmation, function_arguments = _parse_arguments(
        {"function": {"name": "write_doc", "arguments": '{"content": "hello"}'}},
    )
    assert should_skip_confirmation is False
    assert function_arguments == {"content": "hello"}
class TestTextMessageEventBalancing:
    """Check that TEXT_MESSAGE_START/END events come in balanced pairs.

    The scenarios drive the emit helpers by hand and verify that every
    TextMessageStartEvent is matched by a TextMessageEndEvent, in particular
    around tool execution.
    """

    def test_tool_only_flow_produces_balanced_events(self):
        """A tool-only response followed by text yields two balanced start/end pairs.

        Scenario: the LLM calls a tool immediately with no initial text, then
        produces text once the tool result is back.
        """
        flow = FlowState()
        all_events: list = []

        # Step 1: the LLM emits only a function_call.
        weather_call = Content.from_function_call(
            call_id="call_weather",
            name="get_weather",
            arguments='{"city": "Seattle"}',
        )
        # Feature #4 check: a tool-only update opens a text message up front.
        if not flow.message_id and _has_only_tool_calls([weather_call]):
            flow.message_id = "tool-msg-1"
            all_events.append(TextMessageStartEvent(message_id=flow.message_id, role="assistant"))
        all_events.extend(_emit_content(weather_call, flow))

        # Step 2: the tool result arrives, which should close the text message.
        weather_result = Content.from_function_result(
            call_id="call_weather",
            result='{"temp": 55, "conditions": "rainy"}',
        )
        all_events.extend(_emit_tool_result(weather_result, flow))
        assert flow.message_id is None, "message_id should be reset after tool result"

        # Step 3: follow-up text; with no open message a new one must start.
        follow_up = Content.from_text("The weather in Seattle is 55°F and rainy.")
        all_events.extend(_emit_content(follow_up, flow))

        # Step 4: end of stream closes whatever message is still open.
        if flow.message_id:
            all_events.append(TextMessageEndEvent(message_id=flow.message_id))

        # Two starts and two ends are expected.
        start_events = [e for e in all_events if isinstance(e, TextMessageStartEvent)]
        end_events = [e for e in all_events if isinstance(e, TextMessageEndEvent)]
        assert len(start_events) == 2, f"Expected 2 start events, got {len(start_events)}"
        assert len(end_events) == 2, f"Expected 2 end events, got {len(end_events)}"

        # Ordering: the first message must be closed before the second one opens.
        start_indices = []
        end_indices = []
        for position, evt in enumerate(all_events):
            if isinstance(evt, TextMessageStartEvent):
                start_indices.append(position)
            if isinstance(evt, TextMessageEndEvent):
                end_indices.append(position)
        assert end_indices[0] < start_indices[1], (
            f"First TextMessageEndEvent (index {end_indices[0]}) should come "
            f"before second TextMessageStartEvent (index {start_indices[1]})"
        )

    def test_text_then_tool_flow(self):
        """Text, then a tool call and result, then more text yields two balanced pairs.

        Scenario: "Let me check the weather..." -> tool call -> tool result -> "The weather is..."
        """
        flow = FlowState()
        all_events: list = []

        # Step 1: leading text opens the first message.
        all_events.extend(_emit_content(Content.from_text("Let me check the weather for you."), flow))
        assert flow.message_id is not None, "message_id should be set after text"
        first_msg_id = flow.message_id

        # Step 2: the LLM calls a tool.
        tool_call = Content.from_function_call(call_id="call_1", name="get_weather", arguments="{}")
        all_events.extend(_emit_content(tool_call, flow))

        # Step 3: the tool result comes back.
        tool_result = Content.from_function_result(call_id="call_1", result="sunny")
        all_events.extend(_emit_tool_result(tool_result, flow))

        # The first message must now be closed and no message left open.
        assert flow.message_id is None
        closed_so_far = [evt for evt in all_events if isinstance(evt, TextMessageEndEvent)]
        assert len(closed_so_far) == 1
        assert closed_so_far[0].message_id == first_msg_id

        # Step 4: follow-up text opens the second message.
        all_events.extend(_emit_content(Content.from_text("The weather is sunny!"), flow))

        # Step 5: end of stream closes whatever message is still open.
        if flow.message_id:
            all_events.append(TextMessageEndEvent(message_id=flow.message_id))

        # Starts and ends must balance at two each.
        starts = [evt for evt in all_events if isinstance(evt, TextMessageStartEvent)]
        ends = [evt for evt in all_events if isinstance(evt, TextMessageEndEvent)]
        assert len(starts) == 2
        assert len(ends) == 2
async def test_run_agent_stream_accumulates_multiple_confirm_interrupts():
    """Several predictive tool calls in one streaming run must all surface as interrupts.

    Drives the confirm_changes path in run_agent_stream (_agent_run.py) through
    AgentFrameworkAgent.run and checks that every tool call needing
    confirmation is present in the RUN_FINISHED event's interrupt list.
    """
    import json

    from conftest import StubAgent

    from agent_framework_ag_ui import AgentFrameworkAgent

    def _predictive_call(tool_name, call_id, argument_key, description):
        """Build a function_call content whose arguments are JSON-encoded."""
        return Content.from_function_call(
            name=tool_name,
            call_id=call_id,
            arguments=json.dumps({argument_key: [{"description": description}]}),
        )

    def _is_run_finished(candidate):
        """True when the event type is "RUN_FINISHED", either directly or via its .value."""
        kind = getattr(candidate, "type", None)
        return kind == "RUN_FINISHED" or getattr(kind, "value", None) == "RUN_FINISHED"

    array_of_objects = {"type": "array", "items": {"type": "object"}}
    agent = AgentFrameworkAgent(
        agent=StubAgent(
            updates=[
                AgentResponseUpdate(
                    contents=[
                        _predictive_call("generate_tasks", "call-tasks", "steps", "Task 1"),
                        _predictive_call("generate_notes", "call-notes", "items", "Note 1"),
                    ],
                    role="assistant",
                ),
            ],
        ),
        state_schema={"tasks": dict(array_of_objects), "notes": dict(array_of_objects)},
        predict_state_config={
            "tasks": {"tool": "generate_tasks", "tool_argument": "steps"},
            "notes": {"tool": "generate_notes", "tool_argument": "items"},
        },
        require_confirmation=True,
    )

    request = {
        "thread_id": "thread-multi",
        "run_id": "run-multi",
        "messages": [{"role": "user", "content": "Generate tasks and notes"}],
        "state": {"tasks": [], "notes": []},
    }
    events = []
    async for emitted in agent.run(request):
        events.append(emitted)

    # Locate RUN_FINISHED and check that it carries both interrupts.
    finished_events = [evt for evt in events if _is_run_finished(evt)]
    assert finished_events, f"Expected RUN_FINISHED event. Types: {[getattr(e, 'type', None) for e in events]}"

    interrupt = getattr(finished_events[-1], "interrupt", None)
    assert interrupt is not None, "Expected interrupt metadata in RUN_FINISHED"
    assert len(interrupt) == 2, f"Expected 2 interrupts (one per tool), got {len(interrupt)}"

    # Both tool calls must be represented in the interrupt metadata.
    interrupt_tool_names = set()
    for entry in interrupt:
        interrupt_tool_names.add(entry["value"]["function_call"]["name"])
    assert interrupt_tool_names == {"generate_tasks", "generate_notes"}
def test_emit_oauth_consent_request():
    """oauth_consent_request content is emitted as one CustomEvent carrying the consent link."""
    consent_url = "https://login.microsoftonline.com/consent"
    emitted = _emit_content(Content.from_oauth_consent_request(consent_link=consent_url), FlowState())

    assert len(emitted) == 1
    consent_event = emitted[0]
    assert isinstance(consent_event, CustomEvent)
    assert consent_event.name == "oauth_consent_request"
    assert consent_event.value == {"consent_link": "https://login.microsoftonline.com/consent"}
def test_emit_oauth_consent_request_no_link():
    """oauth_consent_request content without a consent_link produces no events."""
    linkless = Content("oauth_consent_request")
    emitted = _emit_content(linkless, FlowState())
    assert len(emitted) == 0