mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
35a8565495
* Add AG-UI integration * Fix tests. PR feedback * Cleanup * PR Feedback * Improve README and getting started experience * Fix links
578 lines
21 KiB
Python
578 lines
21 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Comprehensive tests for AgentFrameworkAgent (_agent.py)."""
|
|
|
|
import json
|
|
|
|
import pytest
|
|
from agent_framework import ChatAgent, TextContent
|
|
from agent_framework._types import ChatResponseUpdate
|
|
|
|
|
|
async def test_agent_initialization_basic():
|
|
"""Test basic agent initialization without state schema."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
assert wrapper.name == "test_agent"
|
|
assert wrapper.agent == agent
|
|
assert wrapper.config.state_schema == {}
|
|
assert wrapper.config.predict_state_config == {}
|
|
|
|
|
|
async def test_agent_initialization_with_state_schema():
|
|
"""Test agent initialization with state_schema."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
state_schema = {"document": {"type": "string"}}
|
|
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
|
|
|
|
assert wrapper.config.state_schema == state_schema
|
|
|
|
|
|
async def test_agent_initialization_with_predict_state_config():
|
|
"""Test agent initialization with predict_state_config."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
|
|
wrapper = AgentFrameworkAgent(agent=agent, predict_state_config=predict_config)
|
|
|
|
assert wrapper.config.predict_state_config == predict_config
|
|
|
|
|
|
async def test_run_started_event_emission():
|
|
"""Test RunStartedEvent is emitted at start of run."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# First event should be RunStartedEvent
|
|
assert events[0].type == "RUN_STARTED"
|
|
assert events[0].run_id is not None
|
|
assert events[0].thread_id is not None
|
|
|
|
|
|
async def test_predict_state_custom_event_emission():
|
|
"""Test PredictState CustomEvent is emitted when predict_state_config is present."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
predict_config = {
|
|
"document": {"tool": "write_doc", "tool_argument": "content"},
|
|
"summary": {"tool": "summarize", "tool_argument": "text"},
|
|
}
|
|
wrapper = AgentFrameworkAgent(agent=agent, predict_state_config=predict_config)
|
|
|
|
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Find PredictState event
|
|
predict_events = [e for e in events if e.type == "CUSTOM" and e.name == "PredictState"]
|
|
assert len(predict_events) == 1
|
|
|
|
predict_value = predict_events[0].value
|
|
assert len(predict_value) == 2
|
|
assert {"state_key": "document", "tool": "write_doc", "tool_argument": "content"} in predict_value
|
|
assert {"state_key": "summary", "tool": "summarize", "tool_argument": "text"} in predict_value
|
|
|
|
|
|
async def test_initial_state_snapshot_with_schema():
|
|
"""Test initial StateSnapshotEvent emission when state_schema present."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
state_schema = {"document": {"type": "string"}}
|
|
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
|
|
|
|
input_data = {
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"state": {"document": "Initial content"},
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Find StateSnapshotEvent
|
|
snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
|
|
assert len(snapshot_events) >= 1
|
|
|
|
# First snapshot should have initial state
|
|
assert snapshot_events[0].snapshot == {"document": "Initial content"}
|
|
|
|
|
|
async def test_state_initialization_object_type():
|
|
"""Test state initialization with object type in schema."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
state_schema = {"recipe": {"type": "object", "properties": {}}}
|
|
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
|
|
|
|
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Find StateSnapshotEvent
|
|
snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
|
|
assert len(snapshot_events) >= 1
|
|
|
|
# Should initialize as empty object
|
|
assert snapshot_events[0].snapshot == {"recipe": {}}
|
|
|
|
|
|
async def test_state_initialization_array_type():
|
|
"""Test state initialization with array type in schema."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
state_schema = {"steps": {"type": "array", "items": {}}}
|
|
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
|
|
|
|
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Find StateSnapshotEvent
|
|
snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
|
|
assert len(snapshot_events) >= 1
|
|
|
|
# Should initialize as empty array
|
|
assert snapshot_events[0].snapshot == {"steps": []}
|
|
|
|
|
|
async def test_run_finished_event_emission():
|
|
"""Test RunFinishedEvent is emitted at end of run."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Last event should be RunFinishedEvent
|
|
assert events[-1].type == "RUN_FINISHED"
|
|
|
|
|
|
async def test_tool_result_confirm_changes_accepted():
|
|
"""Test confirm_changes tool result handling when accepted."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Document updated")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(
|
|
agent=agent,
|
|
state_schema={"document": {"type": "string"}},
|
|
predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}},
|
|
)
|
|
|
|
# Simulate tool result message with acceptance
|
|
tool_result = {"accepted": True, "steps": []}
|
|
input_data = {
|
|
"messages": [
|
|
{
|
|
"role": "tool", # Tool result from UI
|
|
"content": json.dumps(tool_result),
|
|
"toolCallId": "confirm_call_123",
|
|
}
|
|
],
|
|
"state": {"document": "Updated content"},
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Should emit text message confirming acceptance
|
|
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
|
|
assert len(text_content_events) > 0
|
|
# Should contain confirmation message mentioning the state key or generic confirmation
|
|
confirmation_found = any(
|
|
"document" in e.delta.lower()
|
|
or "confirm" in e.delta.lower()
|
|
or "applied" in e.delta.lower()
|
|
or "changes" in e.delta.lower()
|
|
for e in text_content_events
|
|
)
|
|
assert confirmation_found, f"No confirmation in deltas: {[e.delta for e in text_content_events]}"
|
|
|
|
|
|
async def test_tool_result_confirm_changes_rejected():
|
|
"""Test confirm_changes tool result handling when rejected."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
# Simulate tool result message with rejection
|
|
tool_result = {"accepted": False, "steps": []}
|
|
input_data = {
|
|
"messages": [
|
|
{
|
|
"role": "tool",
|
|
"content": json.dumps(tool_result),
|
|
"toolCallId": "confirm_call_123",
|
|
}
|
|
],
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Should emit text message asking what to change
|
|
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
|
|
assert len(text_content_events) > 0
|
|
assert any("what would you like me to change" in e.delta.lower() for e in text_content_events)
|
|
|
|
|
|
async def test_tool_result_function_approval_accepted():
|
|
"""Test function approval tool result when steps are accepted."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
# Simulate tool result with multiple steps
|
|
tool_result = {
|
|
"accepted": True,
|
|
"steps": [
|
|
{"id": "step1", "description": "Send email", "status": "enabled"},
|
|
{"id": "step2", "description": "Create calendar event", "status": "enabled"},
|
|
],
|
|
}
|
|
input_data = {
|
|
"messages": [
|
|
{
|
|
"role": "tool",
|
|
"content": json.dumps(tool_result),
|
|
"toolCallId": "approval_call_123",
|
|
}
|
|
],
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Should list enabled steps
|
|
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
|
|
assert len(text_content_events) > 0
|
|
|
|
# Concatenate all text content
|
|
full_text = "".join(e.delta for e in text_content_events)
|
|
assert "executing" in full_text.lower()
|
|
assert "2 approved steps" in full_text.lower()
|
|
assert "send email" in full_text.lower()
|
|
assert "create calendar event" in full_text.lower()
|
|
|
|
|
|
async def test_tool_result_function_approval_rejected():
|
|
"""Test function approval tool result when rejected."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
# Simulate tool result rejection with steps
|
|
tool_result = {
|
|
"accepted": False,
|
|
"steps": [{"id": "step1", "description": "Send email", "status": "disabled"}],
|
|
}
|
|
input_data = {
|
|
"messages": [
|
|
{
|
|
"role": "tool",
|
|
"content": json.dumps(tool_result),
|
|
"toolCallId": "approval_call_123",
|
|
}
|
|
],
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Should ask what to change about the plan
|
|
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
|
|
assert len(text_content_events) > 0
|
|
assert any("what would you like me to change about the plan" in e.delta.lower() for e in text_content_events)
|
|
|
|
|
|
async def test_thread_metadata_tracking():
|
|
"""Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
thread_metadata = {}
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
# Capture thread metadata from kwargs
|
|
nonlocal thread_metadata
|
|
if "thread" in kwargs:
|
|
thread_metadata = kwargs["thread"].metadata
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
input_data = {
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"thread_id": "test_thread_123",
|
|
"run_id": "test_run_456",
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Check thread metadata was set
|
|
# Note: This test may need adjustment based on actual thread passing mechanism
|
|
|
|
|
|
async def test_state_context_injection():
|
|
"""Test that current state is injected into thread metadata."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
thread_metadata = {}
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
# Track if state context message was added
|
|
nonlocal thread_metadata
|
|
# In actual implementation, thread is passed and state is in metadata
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(
|
|
agent=agent,
|
|
state_schema={"document": {"type": "string"}},
|
|
)
|
|
|
|
input_data = {
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"state": {"document": "Test content"},
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# State should be injected - this is validated by agent execution flow
|
|
|
|
|
|
async def test_no_messages_provided():
|
|
"""Test handling when no messages are provided."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
input_data = {"messages": []}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Should emit RunStartedEvent and RunFinishedEvent only
|
|
assert len(events) == 2
|
|
assert events[0].type == "RUN_STARTED"
|
|
assert events[-1].type == "RUN_FINISHED"
|
|
|
|
|
|
async def test_message_end_event_emission():
|
|
"""Test TextMessageEndEvent is emitted for assistant messages."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Hello world")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Should have TextMessageEndEvent before RunFinishedEvent
|
|
end_events = [e for e in events if e.type == "TEXT_MESSAGE_END"]
|
|
assert len(end_events) == 1
|
|
|
|
# EndEvent should come before FinishedEvent
|
|
end_index = events.index(end_events[0])
|
|
finished_index = events.index([e for e in events if e.type == "RUN_FINISHED"][0])
|
|
assert end_index < finished_index
|
|
|
|
|
|
async def test_error_handling_with_exception():
|
|
"""Test that exceptions during agent execution are re-raised."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class FailingChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
if False:
|
|
yield
|
|
raise RuntimeError("Simulated failure")
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=FailingChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
|
|
|
with pytest.raises(RuntimeError, match="Simulated failure"):
|
|
async for event in wrapper.run_agent(input_data):
|
|
pass
|
|
|
|
|
|
async def test_json_decode_error_in_tool_result():
|
|
"""Test handling of JSONDecodeError when parsing tool result."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Fallback response")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(agent=agent)
|
|
|
|
# Send invalid JSON as tool result
|
|
input_data = {
|
|
"messages": [
|
|
{
|
|
"role": "tool",
|
|
"content": "invalid json {not valid}",
|
|
"toolCallId": "call_123",
|
|
}
|
|
],
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Should fall through to normal agent processing
|
|
text_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
|
|
assert len(text_events) > 0
|
|
assert text_events[0].delta == "Fallback response"
|
|
|
|
|
|
async def test_suppressed_summary_with_document_state():
|
|
"""Test suppressed summary uses document state for confirmation message."""
|
|
from agent_framework_ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy
|
|
|
|
class MockChatClient:
|
|
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
|
yield ChatResponseUpdate(contents=[TextContent(text="Response")])
|
|
|
|
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
|
wrapper = AgentFrameworkAgent(
|
|
agent=agent,
|
|
state_schema={"document": {"type": "string"}},
|
|
predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}},
|
|
confirmation_strategy=DocumentWriterConfirmationStrategy(),
|
|
)
|
|
|
|
# Simulate confirmation with document state
|
|
tool_result = {"accepted": True, "steps": []}
|
|
input_data = {
|
|
"messages": [
|
|
{
|
|
"role": "tool",
|
|
"content": json.dumps(tool_result),
|
|
"toolCallId": "confirm_123",
|
|
}
|
|
],
|
|
"state": {"document": "This is the beginning of a document. It contains important information."},
|
|
}
|
|
|
|
events = []
|
|
async for event in wrapper.run_agent(input_data):
|
|
events.append(event)
|
|
|
|
# Should generate fallback summary from document state
|
|
text_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
|
|
assert len(text_events) > 0
|
|
# Should contain some reference to the document
|
|
full_text = "".join(e.delta for e in text_events)
|
|
assert "written" in full_text.lower() or "document" in full_text.lower()
|