Files
agent-framework/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py
T
Evan Mattson 35a8565495 Python: AG-UI protocol support (#1826)
* Add AG-UI integration

* Fix tests. PR feedback

* Cleanup

* PR Feedback

* Improve README and getting started experience

* Fix links
2025-11-05 05:25:24 +00:00

578 lines
21 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Comprehensive tests for AgentFrameworkAgent (_agent.py)."""
import json
import pytest
from agent_framework import ChatAgent, TextContent
from agent_framework._types import ChatResponseUpdate
async def test_agent_initialization_basic():
"""Test basic agent initialization without state schema."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
assert wrapper.name == "test_agent"
assert wrapper.agent == agent
assert wrapper.config.state_schema == {}
assert wrapper.config.predict_state_config == {}
async def test_agent_initialization_with_state_schema():
"""Test agent initialization with state_schema."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
state_schema = {"document": {"type": "string"}}
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
assert wrapper.config.state_schema == state_schema
async def test_agent_initialization_with_predict_state_config():
"""Test agent initialization with predict_state_config."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
wrapper = AgentFrameworkAgent(agent=agent, predict_state_config=predict_config)
assert wrapper.config.predict_state_config == predict_config
async def test_run_started_event_emission():
"""Test RunStartedEvent is emitted at start of run."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# First event should be RunStartedEvent
assert events[0].type == "RUN_STARTED"
assert events[0].run_id is not None
assert events[0].thread_id is not None
async def test_predict_state_custom_event_emission():
"""Test PredictState CustomEvent is emitted when predict_state_config is present."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
predict_config = {
"document": {"tool": "write_doc", "tool_argument": "content"},
"summary": {"tool": "summarize", "tool_argument": "text"},
}
wrapper = AgentFrameworkAgent(agent=agent, predict_state_config=predict_config)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Find PredictState event
predict_events = [e for e in events if e.type == "CUSTOM" and e.name == "PredictState"]
assert len(predict_events) == 1
predict_value = predict_events[0].value
assert len(predict_value) == 2
assert {"state_key": "document", "tool": "write_doc", "tool_argument": "content"} in predict_value
assert {"state_key": "summary", "tool": "summarize", "tool_argument": "text"} in predict_value
async def test_initial_state_snapshot_with_schema():
"""Test initial StateSnapshotEvent emission when state_schema present."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
state_schema = {"document": {"type": "string"}}
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
input_data = {
"messages": [{"role": "user", "content": "Hi"}],
"state": {"document": "Initial content"},
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Find StateSnapshotEvent
snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
assert len(snapshot_events) >= 1
# First snapshot should have initial state
assert snapshot_events[0].snapshot == {"document": "Initial content"}
async def test_state_initialization_object_type():
"""Test state initialization with object type in schema."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
state_schema = {"recipe": {"type": "object", "properties": {}}}
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Find StateSnapshotEvent
snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
assert len(snapshot_events) >= 1
# Should initialize as empty object
assert snapshot_events[0].snapshot == {"recipe": {}}
async def test_state_initialization_array_type():
"""Test state initialization with array type in schema."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
state_schema = {"steps": {"type": "array", "items": {}}}
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Find StateSnapshotEvent
snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
assert len(snapshot_events) >= 1
# Should initialize as empty array
assert snapshot_events[0].snapshot == {"steps": []}
async def test_run_finished_event_emission():
"""Test RunFinishedEvent is emitted at end of run."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Last event should be RunFinishedEvent
assert events[-1].type == "RUN_FINISHED"
async def test_tool_result_confirm_changes_accepted():
"""Test confirm_changes tool result handling when accepted."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Document updated")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(
agent=agent,
state_schema={"document": {"type": "string"}},
predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}},
)
# Simulate tool result message with acceptance
tool_result = {"accepted": True, "steps": []}
input_data = {
"messages": [
{
"role": "tool", # Tool result from UI
"content": json.dumps(tool_result),
"toolCallId": "confirm_call_123",
}
],
"state": {"document": "Updated content"},
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Should emit text message confirming acceptance
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
assert len(text_content_events) > 0
# Should contain confirmation message mentioning the state key or generic confirmation
confirmation_found = any(
"document" in e.delta.lower()
or "confirm" in e.delta.lower()
or "applied" in e.delta.lower()
or "changes" in e.delta.lower()
for e in text_content_events
)
assert confirmation_found, f"No confirmation in deltas: {[e.delta for e in text_content_events]}"
async def test_tool_result_confirm_changes_rejected():
"""Test confirm_changes tool result handling when rejected."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
# Simulate tool result message with rejection
tool_result = {"accepted": False, "steps": []}
input_data = {
"messages": [
{
"role": "tool",
"content": json.dumps(tool_result),
"toolCallId": "confirm_call_123",
}
],
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Should emit text message asking what to change
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
assert len(text_content_events) > 0
assert any("what would you like me to change" in e.delta.lower() for e in text_content_events)
async def test_tool_result_function_approval_accepted():
"""Test function approval tool result when steps are accepted."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
# Simulate tool result with multiple steps
tool_result = {
"accepted": True,
"steps": [
{"id": "step1", "description": "Send email", "status": "enabled"},
{"id": "step2", "description": "Create calendar event", "status": "enabled"},
],
}
input_data = {
"messages": [
{
"role": "tool",
"content": json.dumps(tool_result),
"toolCallId": "approval_call_123",
}
],
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Should list enabled steps
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
assert len(text_content_events) > 0
# Concatenate all text content
full_text = "".join(e.delta for e in text_content_events)
assert "executing" in full_text.lower()
assert "2 approved steps" in full_text.lower()
assert "send email" in full_text.lower()
assert "create calendar event" in full_text.lower()
async def test_tool_result_function_approval_rejected():
"""Test function approval tool result when rejected."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
# Simulate tool result rejection with steps
tool_result = {
"accepted": False,
"steps": [{"id": "step1", "description": "Send email", "status": "disabled"}],
}
input_data = {
"messages": [
{
"role": "tool",
"content": json.dumps(tool_result),
"toolCallId": "approval_call_123",
}
],
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Should ask what to change about the plan
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
assert len(text_content_events) > 0
assert any("what would you like me to change about the plan" in e.delta.lower() for e in text_content_events)
async def test_thread_metadata_tracking():
"""Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id."""
from agent_framework_ag_ui import AgentFrameworkAgent
thread_metadata = {}
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
# Capture thread metadata from kwargs
nonlocal thread_metadata
if "thread" in kwargs:
thread_metadata = kwargs["thread"].metadata
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {
"messages": [{"role": "user", "content": "Hi"}],
"thread_id": "test_thread_123",
"run_id": "test_run_456",
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Check thread metadata was set
# Note: This test may need adjustment based on actual thread passing mechanism
async def test_state_context_injection():
"""Test that current state is injected into thread metadata."""
from agent_framework_ag_ui import AgentFrameworkAgent
thread_metadata = {}
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
# Track if state context message was added
nonlocal thread_metadata
# In actual implementation, thread is passed and state is in metadata
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(
agent=agent,
state_schema={"document": {"type": "string"}},
)
input_data = {
"messages": [{"role": "user", "content": "Hi"}],
"state": {"document": "Test content"},
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# State should be injected - this is validated by agent execution flow
async def test_no_messages_provided():
"""Test handling when no messages are provided."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": []}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Should emit RunStartedEvent and RunFinishedEvent only
assert len(events) == 2
assert events[0].type == "RUN_STARTED"
assert events[-1].type == "RUN_FINISHED"
async def test_message_end_event_emission():
"""Test TextMessageEndEvent is emitted for assistant messages."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello world")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Should have TextMessageEndEvent before RunFinishedEvent
end_events = [e for e in events if e.type == "TEXT_MESSAGE_END"]
assert len(end_events) == 1
# EndEvent should come before FinishedEvent
end_index = events.index(end_events[0])
finished_index = events.index([e for e in events if e.type == "RUN_FINISHED"][0])
assert end_index < finished_index
async def test_error_handling_with_exception():
"""Test that exceptions during agent execution are re-raised."""
from agent_framework_ag_ui import AgentFrameworkAgent
class FailingChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
if False:
yield
raise RuntimeError("Simulated failure")
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=FailingChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
with pytest.raises(RuntimeError, match="Simulated failure"):
async for event in wrapper.run_agent(input_data):
pass
async def test_json_decode_error_in_tool_result():
"""Test handling of JSONDecodeError when parsing tool result."""
from agent_framework_ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Fallback response")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(agent=agent)
# Send invalid JSON as tool result
input_data = {
"messages": [
{
"role": "tool",
"content": "invalid json {not valid}",
"toolCallId": "call_123",
}
],
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Should fall through to normal agent processing
text_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
assert len(text_events) > 0
assert text_events[0].delta == "Fallback response"
async def test_suppressed_summary_with_document_state():
"""Test suppressed summary uses document state for confirmation message."""
from agent_framework_ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Response")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
wrapper = AgentFrameworkAgent(
agent=agent,
state_schema={"document": {"type": "string"}},
predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}},
confirmation_strategy=DocumentWriterConfirmationStrategy(),
)
# Simulate confirmation with document state
tool_result = {"accepted": True, "steps": []}
input_data = {
"messages": [
{
"role": "tool",
"content": json.dumps(tool_result),
"toolCallId": "confirm_123",
}
],
"state": {"document": "This is the beginning of a document. It contains important information."},
}
events = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Should generate fallback summary from document state
text_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
assert len(text_events) > 0
# Should contain some reference to the document
full_text = "".join(e.delta for e in text_events)
assert "written" in full_text.lower() or "document" in full_text.lower()