mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
94af83680e
* Fix flow.interrupts overwrite when multiple tools need approval (#4590) Change flow.interrupts assignment to append so that all interrupt entries accumulate when multiple tools require approval in a single turn. Both _run_common.py and _agent_run.py used assignment (=) which caused each new interrupt to overwrite the previous one. Switching to append() ensures RUN_FINISHED.interrupt contains all pending approvals. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add test for streaming path with multiple confirm_changes interrupts (#4590) Add integration test exercising run_agent_stream with multiple predictive tool calls requiring confirmation. Verifies that flow.interrupts.append() correctly accumulates all interrupt entries and they appear in the RUN_FINISHED event. Also confirms FlowState already declares interrupts field with default_factory=list, addressing the AttributeError concern from review. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
994 lines
36 KiB
Python
994 lines
36 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Tests for _agent_run.py helper functions and FlowState."""
|
|
|
|
import pytest
|
|
from ag_ui.core import (
|
|
CustomEvent,
|
|
TextMessageEndEvent,
|
|
TextMessageStartEvent,
|
|
ToolCallArgsEvent,
|
|
)
|
|
from agent_framework import AgentResponseUpdate, Content, Message, ResponseStream
|
|
from agent_framework.exceptions import AgentInvalidResponseException
|
|
|
|
from agent_framework_ag_ui._agent_run import (
|
|
_build_safe_metadata,
|
|
_create_state_context_message,
|
|
_inject_state_context,
|
|
_normalize_response_stream,
|
|
_resume_to_tool_messages,
|
|
_should_suppress_intermediate_snapshot,
|
|
)
|
|
from agent_framework_ag_ui._run_common import (
|
|
FlowState,
|
|
_build_run_finished_event,
|
|
_emit_approval_request,
|
|
_emit_content,
|
|
_emit_text,
|
|
_emit_tool_call,
|
|
_emit_tool_result,
|
|
_extract_resume_payload,
|
|
_has_only_tool_calls,
|
|
)
|
|
|
|
|
|
class TestBuildSafeMetadata:
|
|
"""Tests for _build_safe_metadata function."""
|
|
|
|
def test_none_metadata(self):
|
|
"""Returns empty dict for None."""
|
|
result = _build_safe_metadata(None)
|
|
assert result == {}
|
|
|
|
def test_empty_metadata(self):
|
|
"""Returns empty dict for empty dict."""
|
|
result = _build_safe_metadata({})
|
|
assert result == {}
|
|
|
|
def test_short_string_values(self):
|
|
"""Preserves short string values."""
|
|
metadata = {"key1": "short", "key2": "value"}
|
|
result = _build_safe_metadata(metadata)
|
|
assert result == metadata
|
|
|
|
def test_truncates_long_strings(self):
|
|
"""Truncates strings over 512 chars."""
|
|
long_value = "x" * 1000
|
|
metadata = {"key": long_value}
|
|
result = _build_safe_metadata(metadata)
|
|
assert len(result["key"]) == 512
|
|
|
|
def test_serializes_non_strings(self):
|
|
"""Serializes non-string values to JSON."""
|
|
metadata = {"count": 42, "items": [1, 2, 3]}
|
|
result = _build_safe_metadata(metadata)
|
|
assert result["count"] == "42"
|
|
assert result["items"] == "[1, 2, 3]"
|
|
|
|
def test_truncates_serialized_values(self):
|
|
"""Truncates serialized values over 512 chars."""
|
|
long_list = list(range(200))
|
|
metadata = {"data": long_list}
|
|
result = _build_safe_metadata(metadata)
|
|
assert len(result["data"]) == 512
|
|
|
|
|
|
class TestHasOnlyToolCalls:
|
|
"""Tests for _has_only_tool_calls function."""
|
|
|
|
def test_only_tool_calls(self):
|
|
"""Returns True when only function_call content."""
|
|
contents = [
|
|
Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"),
|
|
]
|
|
assert _has_only_tool_calls(contents) is True
|
|
|
|
def test_tool_call_with_text(self):
|
|
"""Returns False when both tool call and text."""
|
|
contents = [
|
|
Content.from_text("Some text"),
|
|
Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"),
|
|
]
|
|
assert _has_only_tool_calls(contents) is False
|
|
|
|
def test_only_text(self):
|
|
"""Returns False when only text."""
|
|
contents = [Content.from_text("Just text")]
|
|
assert _has_only_tool_calls(contents) is False
|
|
|
|
def test_empty_contents(self):
|
|
"""Returns False for empty contents."""
|
|
assert _has_only_tool_calls([]) is False
|
|
|
|
def test_tool_call_with_empty_text(self):
|
|
"""Returns True when text content has empty text."""
|
|
contents = [
|
|
Content.from_text(""),
|
|
Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"),
|
|
]
|
|
assert _has_only_tool_calls(contents) is True
|
|
|
|
|
|
class TestShouldSuppressIntermediateSnapshot:
|
|
"""Tests for _should_suppress_intermediate_snapshot function."""
|
|
|
|
def test_no_tool_name(self):
|
|
"""Returns False when no tool name."""
|
|
result = _should_suppress_intermediate_snapshot(
|
|
None, {"key": {"tool": "write_doc", "tool_argument": "content"}}, False
|
|
)
|
|
assert result is False
|
|
|
|
def test_no_config(self):
|
|
"""Returns False when no config."""
|
|
result = _should_suppress_intermediate_snapshot("write_doc", None, False)
|
|
assert result is False
|
|
|
|
def test_confirmation_required(self):
|
|
"""Returns False when confirmation is required."""
|
|
config = {"key": {"tool": "write_doc", "tool_argument": "content"}}
|
|
result = _should_suppress_intermediate_snapshot("write_doc", config, True)
|
|
assert result is False
|
|
|
|
def test_tool_not_in_config(self):
|
|
"""Returns False when tool not in config."""
|
|
config = {"key": {"tool": "other_tool", "tool_argument": "content"}}
|
|
result = _should_suppress_intermediate_snapshot("write_doc", config, False)
|
|
assert result is False
|
|
|
|
def test_suppresses_predictive_tool(self):
|
|
"""Returns True for predictive tool without confirmation."""
|
|
config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
|
|
result = _should_suppress_intermediate_snapshot("write_doc", config, False)
|
|
assert result is True
|
|
|
|
|
|
class TestFlowState:
|
|
"""Tests for FlowState dataclass."""
|
|
|
|
def test_default_values(self):
|
|
"""Tests default initialization."""
|
|
flow = FlowState()
|
|
assert flow.message_id is None
|
|
assert flow.tool_call_id is None
|
|
assert flow.tool_call_name is None
|
|
assert flow.waiting_for_approval is False
|
|
assert flow.current_state == {}
|
|
assert flow.accumulated_text == ""
|
|
assert flow.pending_tool_calls == []
|
|
assert flow.tool_calls_by_id == {}
|
|
assert flow.tool_results == []
|
|
assert flow.tool_calls_ended == set()
|
|
assert flow.interrupts == []
|
|
|
|
def test_get_tool_name(self):
|
|
"""Tests get_tool_name method."""
|
|
flow = FlowState()
|
|
flow.tool_calls_by_id = {"call_123": {"function": {"name": "get_weather", "arguments": "{}"}}}
|
|
|
|
assert flow.get_tool_name("call_123") == "get_weather"
|
|
assert flow.get_tool_name("nonexistent") is None
|
|
assert flow.get_tool_name(None) is None
|
|
|
|
def test_get_tool_name_empty_name(self):
|
|
"""Tests get_tool_name with empty name."""
|
|
flow = FlowState()
|
|
flow.tool_calls_by_id = {"call_123": {"function": {"name": "", "arguments": "{}"}}}
|
|
|
|
assert flow.get_tool_name("call_123") is None
|
|
|
|
def test_get_pending_without_end(self):
|
|
"""Tests get_pending_without_end method."""
|
|
flow = FlowState()
|
|
flow.pending_tool_calls = [
|
|
{"id": "call_1", "function": {"name": "tool1"}},
|
|
{"id": "call_2", "function": {"name": "tool2"}},
|
|
{"id": "call_3", "function": {"name": "tool3"}},
|
|
]
|
|
flow.tool_calls_ended = {"call_1", "call_3"}
|
|
|
|
result = flow.get_pending_without_end()
|
|
assert len(result) == 1
|
|
assert result[0]["id"] == "call_2"
|
|
|
|
|
|
class TestNormalizeResponseStream:
|
|
"""Tests for _normalize_response_stream helper."""
|
|
|
|
async def test_accepts_response_stream(self):
|
|
"""Accept standard ResponseStream values."""
|
|
|
|
async def _stream():
|
|
yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")
|
|
|
|
stream = await _normalize_response_stream(ResponseStream(_stream()))
|
|
updates = [update async for update in stream]
|
|
|
|
assert len(updates) == 1
|
|
assert updates[0].contents[0].text == "hello"
|
|
|
|
async def test_accepts_async_iterable(self):
|
|
"""Accept workflow-style async generator streams."""
|
|
|
|
async def _stream():
|
|
yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")
|
|
|
|
stream = await _normalize_response_stream(_stream())
|
|
updates = [update async for update in stream]
|
|
|
|
assert len(updates) == 1
|
|
assert updates[0].contents[0].text == "hello"
|
|
|
|
async def test_accepts_awaitable_resolving_to_async_iterable(self):
|
|
"""Accept awaitables that resolve to async iterable streams."""
|
|
|
|
async def _stream():
|
|
yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant")
|
|
|
|
async def _resolve():
|
|
return _stream()
|
|
|
|
stream = await _normalize_response_stream(_resolve())
|
|
updates = [update async for update in stream]
|
|
|
|
assert len(updates) == 1
|
|
assert updates[0].contents[0].text == "hello"
|
|
|
|
async def test_rejects_non_stream_values(self):
|
|
"""Reject unsupported stream return values."""
|
|
with pytest.raises(AgentInvalidResponseException):
|
|
await _normalize_response_stream("not-a-stream")
|
|
|
|
|
|
class TestCreateStateContextMessage:
|
|
"""Tests for _create_state_context_message function."""
|
|
|
|
def test_no_state(self):
|
|
"""Returns None when no state."""
|
|
result = _create_state_context_message({}, {"properties": {}})
|
|
assert result is None
|
|
|
|
def test_no_schema(self):
|
|
"""Returns None when no schema."""
|
|
result = _create_state_context_message({"key": "value"}, {})
|
|
assert result is None
|
|
|
|
def test_creates_message(self):
|
|
"""Creates state context message."""
|
|
|
|
state = {"document": "Hello world"}
|
|
schema = {"properties": {"document": {"type": "string"}}}
|
|
|
|
result = _create_state_context_message(state, schema)
|
|
|
|
assert result is not None
|
|
assert result.role == "system"
|
|
assert len(result.contents) == 1
|
|
assert "Hello world" in result.contents[0].text
|
|
assert "Current state" in result.contents[0].text
|
|
|
|
|
|
class TestInjectStateContext:
|
|
"""Tests for _inject_state_context function."""
|
|
|
|
def test_no_state_message(self):
|
|
"""Returns original messages when no state context needed."""
|
|
messages = [Message(role="user", contents=[Content.from_text("Hello")])]
|
|
result = _inject_state_context(messages, {}, {})
|
|
assert result == messages
|
|
|
|
def test_empty_messages(self):
|
|
"""Returns empty list for empty messages."""
|
|
result = _inject_state_context([], {"key": "value"}, {"properties": {}})
|
|
assert result == []
|
|
|
|
def test_last_message_not_user(self):
|
|
"""Returns original messages when last message is not from user."""
|
|
messages = [
|
|
Message(role="user", contents=[Content.from_text("Hello")]),
|
|
Message(role="assistant", contents=[Content.from_text("Hi")]),
|
|
]
|
|
state = {"key": "value"}
|
|
schema = {"properties": {"key": {"type": "string"}}}
|
|
|
|
result = _inject_state_context(messages, state, schema)
|
|
assert result == messages
|
|
|
|
def test_injects_before_last_user_message(self):
|
|
"""Injects state context before last user message."""
|
|
|
|
messages = [
|
|
Message(role="system", contents=[Content.from_text("You are helpful")]),
|
|
Message(role="user", contents=[Content.from_text("Hello")]),
|
|
]
|
|
state = {"document": "content"}
|
|
schema = {"properties": {"document": {"type": "string"}}}
|
|
|
|
result = _inject_state_context(messages, state, schema)
|
|
|
|
assert len(result) == 3
|
|
# System message first
|
|
assert result[0].role == "system"
|
|
assert "helpful" in result[0].contents[0].text
|
|
# State context second
|
|
assert result[1].role == "system"
|
|
assert "Current state" in result[1].contents[0].text
|
|
# User message last
|
|
assert result[2].role == "user"
|
|
assert "Hello" in result[2].contents[0].text
|
|
|
|
|
|
# Additional tests for _agent_run.py functions
|
|
|
|
|
|
def test_emit_text_basic():
|
|
"""Test _emit_text emits correct events."""
|
|
flow = FlowState()
|
|
content = Content.from_text("Hello world")
|
|
|
|
events = _emit_text(content, flow)
|
|
|
|
assert len(events) == 2 # TextMessageStartEvent + TextMessageContentEvent
|
|
assert flow.message_id is not None
|
|
assert flow.accumulated_text == "Hello world"
|
|
|
|
|
|
def test_emit_text_skip_empty():
|
|
"""Test _emit_text skips empty text."""
|
|
flow = FlowState()
|
|
content = Content.from_text("")
|
|
|
|
events = _emit_text(content, flow)
|
|
|
|
assert len(events) == 0
|
|
|
|
|
|
def test_emit_text_continues_existing_message():
|
|
"""Test _emit_text continues existing message."""
|
|
flow = FlowState()
|
|
flow.message_id = "existing-id"
|
|
content = Content.from_text("more text")
|
|
|
|
events = _emit_text(content, flow)
|
|
|
|
assert len(events) == 1 # Only TextMessageContentEvent, no new start
|
|
assert flow.message_id == "existing-id"
|
|
|
|
|
|
def test_emit_text_skips_duplicate_full_message_delta():
|
|
"""Test _emit_text skips replayed full-message chunks on an open message."""
|
|
flow = FlowState()
|
|
flow.message_id = "existing-id"
|
|
flow.accumulated_text = "Case complete."
|
|
content = Content.from_text("Case complete.")
|
|
|
|
events = _emit_text(content, flow)
|
|
|
|
assert events == []
|
|
assert flow.accumulated_text == "Case complete."
|
|
|
|
|
|
def test_emit_text_skips_when_waiting_for_approval():
|
|
"""Test _emit_text skips when waiting for approval."""
|
|
flow = FlowState()
|
|
flow.waiting_for_approval = True
|
|
content = Content.from_text("should skip")
|
|
|
|
events = _emit_text(content, flow)
|
|
|
|
assert len(events) == 0
|
|
|
|
|
|
def test_emit_text_skips_when_skip_text_flag():
|
|
"""Test _emit_text skips with skip_text flag."""
|
|
flow = FlowState()
|
|
content = Content.from_text("should skip")
|
|
|
|
events = _emit_text(content, flow, skip_text=True)
|
|
|
|
assert len(events) == 0
|
|
|
|
|
|
def test_emit_tool_call_basic():
|
|
"""Test _emit_tool_call emits correct events."""
|
|
flow = FlowState()
|
|
content = Content.from_function_call(
|
|
call_id="call_123",
|
|
name="get_weather",
|
|
arguments='{"city": "NYC"}',
|
|
)
|
|
|
|
events = _emit_tool_call(content, flow)
|
|
|
|
assert len(events) >= 1 # At least ToolCallStartEvent
|
|
assert flow.tool_call_id == "call_123"
|
|
assert flow.tool_call_name == "get_weather"
|
|
|
|
|
|
def test_emit_tool_call_generates_id():
|
|
"""Test _emit_tool_call generates ID when not provided."""
|
|
flow = FlowState()
|
|
# Create content without call_id
|
|
content = Content(type="function_call", name="test_tool", arguments="{}")
|
|
|
|
events = _emit_tool_call(content, flow)
|
|
|
|
assert len(events) >= 1
|
|
assert flow.tool_call_id is not None # ID should be generated
|
|
|
|
|
|
def test_emit_tool_call_skips_duplicate_full_arguments_replay():
|
|
"""Test _emit_tool_call skips replayed full-arguments on an existing tool call.
|
|
|
|
This is a regression test for issue #4194 where some streaming providers
|
|
send the full arguments string again after streaming deltas, causing the
|
|
arguments to be doubled in MESSAGES_SNAPSHOT events.
|
|
|
|
Mirrors test_emit_text_skips_duplicate_full_message_delta for consistency.
|
|
"""
|
|
flow = FlowState()
|
|
full_args = '{"city": "Seattle"}'
|
|
|
|
# Step 1: Initial tool call with name + arguments (normal start)
|
|
content_start = Content.from_function_call(
|
|
call_id="call_dup",
|
|
name="get_weather",
|
|
arguments=full_args,
|
|
)
|
|
events_start = _emit_tool_call(content_start, flow)
|
|
|
|
# Should emit ToolCallStartEvent + ToolCallArgsEvent
|
|
assert any(isinstance(e, ToolCallArgsEvent) for e in events_start)
|
|
assert flow.tool_calls_by_id["call_dup"]["function"]["arguments"] == full_args
|
|
|
|
# Step 2: Provider replays the full arguments (duplicate)
|
|
content_replay = Content(type="function_call", call_id="call_dup", arguments=full_args)
|
|
events_replay = _emit_tool_call(content_replay, flow)
|
|
|
|
# Should NOT emit any ToolCallArgsEvent (early return on replay)
|
|
args_events = [e for e in events_replay if isinstance(e, ToolCallArgsEvent)]
|
|
assert args_events == [], "Duplicate full-arguments replay should not emit ToolCallArgsEvent"
|
|
|
|
# Accumulated arguments should remain unchanged
|
|
assert flow.tool_calls_by_id["call_dup"]["function"]["arguments"] == full_args
|
|
|
|
|
|
def test_emit_tool_result_closes_open_message():
|
|
"""Test _emit_tool_result emits TextMessageEndEvent for open text message.
|
|
|
|
This is a regression test for where TEXT_MESSAGE_END was not
|
|
emitted when using MCP tools because the message_id was reset without
|
|
closing the message first.
|
|
"""
|
|
flow = FlowState()
|
|
# Simulate an open text message (e.g., from Feature #4 tool-only detection)
|
|
flow.message_id = "open-msg-123"
|
|
flow.tool_call_id = "call_456"
|
|
|
|
content = Content.from_function_result(call_id="call_456", result="tool result")
|
|
|
|
events = _emit_tool_result(content, flow, predictive_handler=None)
|
|
|
|
# Should have: ToolCallEndEvent, ToolCallResultEvent, TextMessageEndEvent
|
|
assert len(events) == 3
|
|
|
|
# Verify TextMessageEndEvent is emitted for the open message
|
|
text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)]
|
|
assert len(text_end_events) == 1
|
|
assert text_end_events[0].message_id == "open-msg-123"
|
|
|
|
# Verify message_id is reset after
|
|
assert flow.message_id is None
|
|
|
|
|
|
def test_emit_tool_result_no_open_message():
|
|
"""Test _emit_tool_result works when there's no open text message."""
|
|
flow = FlowState()
|
|
# No open message
|
|
flow.message_id = None
|
|
flow.tool_call_id = "call_456"
|
|
|
|
content = Content.from_function_result(call_id="call_456", result="tool result")
|
|
|
|
events = _emit_tool_result(content, flow, predictive_handler=None)
|
|
|
|
# Should have: ToolCallEndEvent, ToolCallResultEvent (no TextMessageEndEvent)
|
|
text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)]
|
|
assert len(text_end_events) == 0
|
|
|
|
|
|
def test_emit_tool_result_serializes_non_string_result():
|
|
"""Non-string tool results should be serialized before emitting TOOL_CALL_RESULT."""
|
|
flow = FlowState()
|
|
content = Content.from_function_result(call_id="call_789", result={"ok": True, "items": [1, 2]})
|
|
|
|
events = _emit_tool_result(content, flow, predictive_handler=None)
|
|
result_event = next(event for event in events if getattr(event, "type", None) == "TOOL_CALL_RESULT")
|
|
|
|
assert isinstance(result_event.content, str)
|
|
assert '"ok": true' in result_event.content
|
|
assert flow.tool_results[0]["content"] == result_event.content
|
|
|
|
|
|
def test_emit_content_usage_emits_custom_usage_event():
|
|
"""Usage content should be emitted as a custom usage event."""
|
|
flow = FlowState()
|
|
content = Content.from_usage({"input_token_count": 3, "output_token_count": 2, "total_token_count": 5})
|
|
|
|
events = _emit_content(content, flow)
|
|
|
|
assert len(events) == 1
|
|
assert events[0].type == "CUSTOM"
|
|
assert events[0].name == "usage"
|
|
assert events[0].value["total_token_count"] == 5
|
|
|
|
|
|
def test_emit_approval_request_populates_interrupt_metadata():
|
|
"""Approval requests should populate FlowState interrupts for RUN_FINISHED metadata."""
|
|
flow = FlowState(message_id="msg-1")
|
|
function_call = Content.from_function_call(call_id="call_123", name="write_doc", arguments={"content": "x"})
|
|
approval_content = Content.from_function_approval_request(id="approval_1", function_call=function_call)
|
|
|
|
_emit_approval_request(approval_content, flow)
|
|
|
|
assert flow.waiting_for_approval is True
|
|
assert len(flow.interrupts) == 1
|
|
assert flow.interrupts[0]["id"] == "call_123"
|
|
assert flow.interrupts[0]["value"]["type"] == "function_approval_request"
|
|
|
|
|
|
def test_emit_approval_request_accumulates_multiple_interrupts():
|
|
"""Multiple approval requests in the same turn should accumulate in flow.interrupts."""
|
|
flow = FlowState(message_id="msg-1")
|
|
|
|
for i in range(1, 4):
|
|
function_call = Content.from_function_call(
|
|
call_id=f"call_{i}",
|
|
name=f"tool_{i}",
|
|
arguments={"arg": f"value_{i}"},
|
|
)
|
|
approval_content = Content.from_function_approval_request(
|
|
id=f"approval_{i}",
|
|
function_call=function_call,
|
|
)
|
|
_emit_approval_request(approval_content, flow)
|
|
|
|
assert len(flow.interrupts) == 3
|
|
interrupt_ids = {intr["id"] for intr in flow.interrupts}
|
|
assert interrupt_ids == {"call_1", "call_2", "call_3"}
|
|
|
|
|
|
def test_resume_to_tool_messages_from_interrupts_payload():
|
|
"""Resume payload interrupt responses map to tool messages."""
|
|
resume = {
|
|
"interrupts": [
|
|
{"id": "req_1", "value": {"accepted": True, "steps": []}},
|
|
{"id": "req_2", "value": "plain value"},
|
|
]
|
|
}
|
|
|
|
messages = _resume_to_tool_messages(resume)
|
|
assert len(messages) == 2
|
|
assert messages[0]["role"] == "tool"
|
|
assert messages[0]["toolCallId"] == "req_1"
|
|
assert '"accepted": true' in messages[0]["content"]
|
|
assert messages[1]["content"] == "plain value"
|
|
|
|
|
|
def test_extract_resume_payload_prefers_top_level_resume():
|
|
"""Top-level resume should take precedence over forwarded props."""
|
|
payload = {
|
|
"resume": {"interrupts": [{"id": "req_1", "value": "approved"}]},
|
|
"forwarded_props": {"command": {"resume": "ignored"}},
|
|
}
|
|
|
|
result = _extract_resume_payload(payload)
|
|
assert result == {"interrupts": [{"id": "req_1", "value": "approved"}]}
|
|
|
|
|
|
def test_extract_resume_payload_reads_forwarded_command_resume():
|
|
"""Forwarded command.resume should be treated as a resume payload."""
|
|
payload = {
|
|
"forwarded_props": {
|
|
"command": {"resume": '{"airline":"KLM","departure":"Amsterdam (AMS)","arrival":"San Francisco (SFO)"}'}
|
|
}
|
|
}
|
|
|
|
result = _extract_resume_payload(payload)
|
|
assert isinstance(result, str)
|
|
assert "KLM" in result
|
|
|
|
|
|
def test_build_run_finished_event_with_interrupt():
|
|
"""RUN_FINISHED helper should preserve interrupt payloads."""
|
|
event = _build_run_finished_event("run-1", "thread-1", interrupts=[{"id": "req_1", "value": {"x": 1}}])
|
|
dumped = event.model_dump()
|
|
|
|
assert dumped["run_id"] == "run-1"
|
|
assert dumped["thread_id"] == "thread-1"
|
|
assert dumped["interrupt"] == [{"id": "req_1", "value": {"x": 1}}]
|
|
|
|
|
|
def test_extract_approved_state_updates_no_handler():
|
|
"""Test _extract_approved_state_updates returns empty with no handler."""
|
|
from agent_framework_ag_ui._agent_run import _extract_approved_state_updates
|
|
|
|
messages = [Message(role="user", contents=[Content.from_text("Hello")])]
|
|
result = _extract_approved_state_updates(messages, None)
|
|
assert result == {}
|
|
|
|
|
|
def test_extract_approved_state_updates_no_approval():
|
|
"""Test _extract_approved_state_updates returns empty when no approval content."""
|
|
from agent_framework_ag_ui._agent_run import _extract_approved_state_updates
|
|
from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler
|
|
|
|
handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "content"}})
|
|
messages = [Message(role="user", contents=[Content.from_text("Hello")])]
|
|
result = _extract_approved_state_updates(messages, handler)
|
|
assert result == {}
|
|
|
|
|
|
class TestBuildMessagesSnapshot:
|
|
"""Tests for _build_messages_snapshot function."""
|
|
|
|
def test_tool_calls_and_text_are_separate_messages(self):
|
|
"""Test that tool calls and text content are emitted as separate messages.
|
|
|
|
This is a regression test for issue #3619 where tool calls and content
|
|
were incorrectly merged into a single assistant message.
|
|
"""
|
|
from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot
|
|
|
|
flow = FlowState()
|
|
flow.message_id = "msg-123"
|
|
flow.pending_tool_calls = [
|
|
{"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "NYC"}'}},
|
|
]
|
|
flow.accumulated_text = "Here is the weather information."
|
|
flow.tool_results = [{"id": "result-1", "role": "tool", "content": '{"temp": 72}', "toolCallId": "call_1"}]
|
|
|
|
result = _build_messages_snapshot(flow, [])
|
|
|
|
# Should have 3 messages: tool call msg, tool result, text content msg
|
|
assert len(result.messages) == 3
|
|
|
|
# First message: assistant with tool calls only (no content)
|
|
assistant_tool_msg = result.messages[0]
|
|
assert assistant_tool_msg.role == "assistant"
|
|
assert assistant_tool_msg.tool_calls is not None
|
|
assert len(assistant_tool_msg.tool_calls) == 1
|
|
assert assistant_tool_msg.content is None
|
|
|
|
# Second message: tool result
|
|
tool_result_msg = result.messages[1]
|
|
assert tool_result_msg.role == "tool"
|
|
|
|
# Third message: assistant with content only (no tool calls)
|
|
assistant_text_msg = result.messages[2]
|
|
assert assistant_text_msg.role == "assistant"
|
|
assert assistant_text_msg.content == "Here is the weather information."
|
|
assert assistant_text_msg.tool_calls is None
|
|
|
|
# The text message should have a different ID than the tool call message
|
|
assert assistant_text_msg.id != assistant_tool_msg.id
|
|
|
|
def test_only_tool_calls_no_text(self):
|
|
"""Test snapshot with only tool calls and no accumulated text."""
|
|
from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot
|
|
|
|
flow = FlowState()
|
|
flow.message_id = "msg-123"
|
|
flow.pending_tool_calls = [
|
|
{"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
|
|
]
|
|
flow.accumulated_text = ""
|
|
flow.tool_results = []
|
|
|
|
result = _build_messages_snapshot(flow, [])
|
|
|
|
# Should have 1 message: tool call msg only
|
|
assert len(result.messages) == 1
|
|
assert result.messages[0].role == "assistant"
|
|
assert result.messages[0].tool_calls is not None
|
|
assert result.messages[0].content is None
|
|
|
|
def test_only_text_no_tool_calls(self):
|
|
"""Test snapshot with only text and no tool calls."""
|
|
from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot
|
|
|
|
flow = FlowState()
|
|
flow.message_id = "msg-123"
|
|
flow.pending_tool_calls = []
|
|
flow.accumulated_text = "Hello world"
|
|
flow.tool_results = []
|
|
|
|
result = _build_messages_snapshot(flow, [])
|
|
|
|
# Should have 1 message: text content msg only
|
|
assert len(result.messages) == 1
|
|
assert result.messages[0].role == "assistant"
|
|
assert result.messages[0].content == "Hello world"
|
|
assert result.messages[0].tool_calls is None
|
|
# Should use the existing message_id
|
|
assert result.messages[0].id == "msg-123"
|
|
|
|
def test_preserves_snapshot_messages(self):
|
|
"""Test that existing snapshot messages are preserved."""
|
|
from agent_framework_ag_ui._agent_run import FlowState, _build_messages_snapshot
|
|
|
|
flow = FlowState()
|
|
flow.pending_tool_calls = []
|
|
flow.accumulated_text = ""
|
|
|
|
existing_messages = [
|
|
{"id": "user-1", "role": "user", "content": "Hello"},
|
|
{"id": "assist-1", "role": "assistant", "content": "Hi there"},
|
|
]
|
|
|
|
result = _build_messages_snapshot(flow, existing_messages)
|
|
|
|
assert len(result.messages) == 2
|
|
assert result.messages[0].id == "user-1"
|
|
assert result.messages[1].id == "assist-1"
|
|
|
|
|
|
def test_malformed_json_in_confirm_args_skips_confirmation():
|
|
"""Test that malformed JSON in tool arguments skips confirm_changes flow.
|
|
|
|
This is a regression test to ensure that when tool arguments contain malformed
|
|
JSON, the code skips the confirmation flow entirely rather than crashing or
|
|
showing incomplete data to the user.
|
|
"""
|
|
import json
|
|
|
|
# Simulate the parsing logic - malformed JSON should trigger skip
|
|
malformed_arguments = "{ invalid json }"
|
|
tool_call = {"function": {"name": "write_doc", "arguments": malformed_arguments}}
|
|
|
|
# This is what the code should do - detect parsing failure and skip
|
|
should_skip_confirmation = False
|
|
try:
|
|
json.loads(tool_call.get("function", {}).get("arguments", "{}"))
|
|
except json.JSONDecodeError:
|
|
should_skip_confirmation = True
|
|
|
|
# Should skip confirmation when JSON is malformed
|
|
assert should_skip_confirmation is True
|
|
|
|
# Valid JSON should proceed with confirmation
|
|
valid_arguments = '{"content": "hello"}'
|
|
tool_call_valid = {"function": {"name": "write_doc", "arguments": valid_arguments}}
|
|
should_skip_confirmation = False
|
|
try:
|
|
function_arguments = json.loads(tool_call_valid.get("function", {}).get("arguments", "{}"))
|
|
except json.JSONDecodeError:
|
|
should_skip_confirmation = True
|
|
|
|
assert should_skip_confirmation is False
|
|
assert function_arguments == {"content": "hello"}
|
|
|
|
|
|
class TestTextMessageEventBalancing:
|
|
"""Tests for proper TEXT_MESSAGE_START/END event balancing.
|
|
|
|
These tests verify that the streaming flow produces balanced pairs of
|
|
TextMessageStartEvent and TextMessageEndEvent, especially when tool
|
|
execution is involved.
|
|
"""
|
|
|
|
def test_tool_only_flow_produces_balanced_events(self):
|
|
"""Test that a tool-only response produces balanced TEXT_MESSAGE events.
|
|
|
|
This simulates the scenario where the LLM immediately calls a tool
|
|
without any initial text, then returns text after the tool result.
|
|
"""
|
|
flow = FlowState()
|
|
all_events: list = []
|
|
|
|
# Step 1: LLM outputs function_call only (no text)
|
|
func_call_content = Content.from_function_call(
|
|
call_id="call_weather",
|
|
name="get_weather",
|
|
arguments='{"city": "Seattle"}',
|
|
)
|
|
|
|
# Feature #4 check: this should trigger TextMessageStartEvent
|
|
contents = [func_call_content]
|
|
if not flow.message_id and _has_only_tool_calls(contents):
|
|
flow.message_id = "tool-msg-1"
|
|
all_events.append(TextMessageStartEvent(message_id=flow.message_id, role="assistant"))
|
|
|
|
# Emit tool call events
|
|
all_events.extend(_emit_content(func_call_content, flow))
|
|
|
|
# Step 2: Tool executes and returns result
|
|
func_result_content = Content.from_function_result(
|
|
call_id="call_weather",
|
|
result='{"temp": 55, "conditions": "rainy"}',
|
|
)
|
|
|
|
# This should close the text message
|
|
all_events.extend(_emit_tool_result(func_result_content, flow))
|
|
|
|
# Verify message_id was reset
|
|
assert flow.message_id is None, "message_id should be reset after tool result"
|
|
|
|
# Step 3: LLM outputs text response
|
|
text_content = Content.from_text("The weather in Seattle is 55°F and rainy.")
|
|
|
|
# Since message_id is None, _emit_text should create a new one
|
|
for event in _emit_content(text_content, flow):
|
|
all_events.append(event)
|
|
|
|
# Step 4: End of stream - emit final TextMessageEndEvent
|
|
if flow.message_id:
|
|
all_events.append(TextMessageEndEvent(message_id=flow.message_id))
|
|
|
|
# Verify event counts
|
|
start_events = [e for e in all_events if isinstance(e, TextMessageStartEvent)]
|
|
end_events = [e for e in all_events if isinstance(e, TextMessageEndEvent)]
|
|
|
|
# Should have 2 TextMessageStartEvent and 2 TextMessageEndEvent
|
|
assert len(start_events) == 2, f"Expected 2 start events, got {len(start_events)}"
|
|
assert len(end_events) == 2, f"Expected 2 end events, got {len(end_events)}"
|
|
|
|
# Verify order: first message should start and end before second starts
|
|
# Find indices
|
|
start_indices = [i for i, e in enumerate(all_events) if isinstance(e, TextMessageStartEvent)]
|
|
end_indices = [i for i, e in enumerate(all_events) if isinstance(e, TextMessageEndEvent)]
|
|
|
|
# First end should come before second start
|
|
assert end_indices[0] < start_indices[1], (
|
|
f"First TextMessageEndEvent (index {end_indices[0]}) should come "
|
|
f"before second TextMessageStartEvent (index {start_indices[1]})"
|
|
)
|
|
|
|
def test_text_then_tool_flow(self):
|
|
"""Test flow where LLM outputs text first, then calls a tool.
|
|
|
|
This simulates: "Let me check the weather..." -> tool call -> tool result -> "The weather is..."
|
|
"""
|
|
flow = FlowState()
|
|
all_events: list = []
|
|
|
|
# Step 1: LLM outputs text first
|
|
text1 = Content.from_text("Let me check the weather for you.")
|
|
all_events.extend(_emit_content(text1, flow))
|
|
|
|
# Verify message_id is set
|
|
assert flow.message_id is not None, "message_id should be set after text"
|
|
first_msg_id = flow.message_id
|
|
|
|
# Step 2: LLM outputs function_call
|
|
func_call = Content.from_function_call(
|
|
call_id="call_1",
|
|
name="get_weather",
|
|
arguments="{}",
|
|
)
|
|
all_events.extend(_emit_content(func_call, flow))
|
|
|
|
# Step 3: Tool result comes back
|
|
func_result = Content.from_function_result(call_id="call_1", result="sunny")
|
|
all_events.extend(_emit_tool_result(func_result, flow))
|
|
|
|
# Verify message_id was reset and first message was closed
|
|
assert flow.message_id is None
|
|
end_events_so_far = [e for e in all_events if isinstance(e, TextMessageEndEvent)]
|
|
assert len(end_events_so_far) == 1
|
|
assert end_events_so_far[0].message_id == first_msg_id
|
|
|
|
# Step 4: LLM outputs follow-up text
|
|
text2 = Content.from_text("The weather is sunny!")
|
|
all_events.extend(_emit_content(text2, flow))
|
|
|
|
# Step 5: End of stream
|
|
if flow.message_id:
|
|
all_events.append(TextMessageEndEvent(message_id=flow.message_id))
|
|
|
|
# Verify balance
|
|
start_events = [e for e in all_events if isinstance(e, TextMessageStartEvent)]
|
|
end_events = [e for e in all_events if isinstance(e, TextMessageEndEvent)]
|
|
|
|
assert len(start_events) == 2
|
|
assert len(end_events) == 2
|
|
|
|
|
|
async def test_run_agent_stream_accumulates_multiple_confirm_interrupts():
|
|
"""Multiple predictive tool calls in a single streaming run should accumulate interrupts.
|
|
|
|
This exercises the confirm_changes path in run_agent_stream (_agent_run.py),
|
|
ensuring that flow.interrupts.append() works correctly for multiple tool calls
|
|
and all interrupts appear in the RUN_FINISHED event.
|
|
"""
|
|
import json
|
|
|
|
from conftest import StubAgent
|
|
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
predict_config = {
|
|
"tasks": {"tool": "generate_tasks", "tool_argument": "steps"},
|
|
"notes": {"tool": "generate_notes", "tool_argument": "items"},
|
|
}
|
|
state_schema = {
|
|
"tasks": {"type": "array", "items": {"type": "object"}},
|
|
"notes": {"type": "array", "items": {"type": "object"}},
|
|
}
|
|
|
|
updates = [
|
|
AgentResponseUpdate(
|
|
contents=[
|
|
Content.from_function_call(
|
|
name="generate_tasks",
|
|
call_id="call-tasks",
|
|
arguments=json.dumps({"steps": [{"description": "Task 1"}]}),
|
|
),
|
|
Content.from_function_call(
|
|
name="generate_notes",
|
|
call_id="call-notes",
|
|
arguments=json.dumps({"items": [{"description": "Note 1"}]}),
|
|
),
|
|
],
|
|
role="assistant",
|
|
),
|
|
]
|
|
|
|
stub = StubAgent(updates=updates)
|
|
agent = AgentFrameworkAgent(
|
|
agent=stub,
|
|
state_schema=state_schema,
|
|
predict_state_config=predict_config,
|
|
require_confirmation=True,
|
|
)
|
|
|
|
payload = {
|
|
"thread_id": "thread-multi",
|
|
"run_id": "run-multi",
|
|
"messages": [{"role": "user", "content": "Generate tasks and notes"}],
|
|
"state": {"tasks": [], "notes": []},
|
|
}
|
|
|
|
events = [event async for event in agent.run(payload)]
|
|
|
|
# Find RUN_FINISHED event and verify multiple interrupts
|
|
finished_events = [
|
|
e
|
|
for e in events
|
|
if getattr(e, "type", None) == "RUN_FINISHED"
|
|
or getattr(getattr(e, "type", None), "value", None) == "RUN_FINISHED"
|
|
]
|
|
assert finished_events, f"Expected RUN_FINISHED event. Types: {[getattr(e, 'type', None) for e in events]}"
|
|
finished = finished_events[-1]
|
|
interrupt = getattr(finished, "interrupt", None)
|
|
assert interrupt is not None, "Expected interrupt metadata in RUN_FINISHED"
|
|
assert len(interrupt) == 2, f"Expected 2 interrupts (one per tool), got {len(interrupt)}"
|
|
|
|
# Verify both tool calls are represented in interrupt metadata
|
|
interrupt_tool_names = {i["value"]["function_call"]["name"] for i in interrupt}
|
|
assert interrupt_tool_names == {"generate_tasks", "generate_notes"}
|
|
|
|
|
|
def test_emit_oauth_consent_request():
|
|
"""Test that oauth_consent_request content emits a CustomEvent."""
|
|
content = Content.from_oauth_consent_request(
|
|
consent_link="https://login.microsoftonline.com/consent",
|
|
)
|
|
flow = FlowState()
|
|
events = _emit_content(content, flow)
|
|
|
|
assert len(events) == 1
|
|
assert isinstance(events[0], CustomEvent)
|
|
assert events[0].name == "oauth_consent_request"
|
|
assert events[0].value == {"consent_link": "https://login.microsoftonline.com/consent"}
|
|
|
|
|
|
def test_emit_oauth_consent_request_no_link():
|
|
"""Test that oauth_consent_request without a consent_link emits no events."""
|
|
content = Content("oauth_consent_request")
|
|
flow = FlowState()
|
|
events = _emit_content(content, flow)
|
|
|
|
assert len(events) == 0
|