Files
agent-framework/python/packages/ag-ui/tests/test_structured_output.py
T
Evan Mattson 8cf8b0f995 Python: Refactor ag-ui to clean up some patterns (#2363)
* Refactor ag-ui to clean up some patterns

* Mypy fixes

* Fix imports, typing, tests, logging.

* Fix test import error

* Fix imports again

* Fix thread handling
2025-11-27 02:13:03 +00:00

270 lines
9.6 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for structured output handling in _agent.py."""
import json
import sys
from collections.abc import AsyncIterator, MutableSequence
from pathlib import Path
from typing import Any
from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent
from agent_framework._types import ChatResponseUpdate
from pydantic import BaseModel
sys.path.insert(0, str(Path(__file__).parent))
from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates
class RecipeOutput(BaseModel):
    """Test Pydantic model for recipe output.

    Used as ``response_format`` by the recipe-oriented tests: the stub chat
    client returns JSON with a ``recipe`` object and, optionally, a
    ``message`` string.
    """

    # Arbitrary recipe payload; the tests compare it against the ``recipe`` key in state snapshots.
    recipe: dict[str, Any]
    # Optional free text; the tests expect it to appear in TEXT_MESSAGE_CONTENT deltas.
    message: str | None = None
class StepsOutput(BaseModel):
    """Test Pydantic model for steps output.

    Used as ``response_format`` by the steps test, whose stub client returns a
    JSON object with a ``steps`` list (each item a dict such as
    ``{"id": ..., "description": ..., "status": ...}``).
    """

    # List of step dicts; the test checks it lands under the ``steps`` key in state.
    steps: list[dict[str, Any]]
    # Optional free-text message accompanying the steps.
    message: str | None = None
class GenericOutput(BaseModel):
    """Test Pydantic model for generic data.

    Its only field, ``data``, is deliberately absent from the ``state_schema``
    used in the no-schema-match test.
    """

    # Arbitrary payload whose key ("data") does not match any state_schema key in that test.
    data: dict[str, Any]
async def test_structured_output_with_recipe():
    """Test structured output processing with recipe state.

    The ``recipe`` field of the structured response is declared in
    ``state_schema`` and must therefore end up in shared state. The ``message``
    field must be streamed to the client as text.
    """
    from agent_framework.ag_ui import AgentFrameworkAgent

    async def stream_fn(
        messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        yield ChatResponseUpdate(
            contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')]
        )

    agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
    agent.chat_options = ChatOptions(response_format=RecipeOutput)
    wrapper = AgentFrameworkAgent(
        agent=agent,
        state_schema={"recipe": {"type": "object"}},
    )

    input_data = {"messages": [{"role": "user", "content": "Make pasta"}]}
    events: list[Any] = []
    async for event in wrapper.run_agent(input_data):
        events.append(event)

    # Should emit StateSnapshotEvent with recipe
    snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
    assert len(snapshot_events) >= 1

    # With a state_schema, an initial snapshot seeded from the schema may precede
    # the structured-output snapshot and can also carry a "recipe" key. Check the
    # most recent recipe snapshot rather than the first one.
    recipe_snapshots = [e for e in snapshot_events if "recipe" in e.snapshot]
    assert len(recipe_snapshots) >= 1
    assert recipe_snapshots[-1].snapshot["recipe"] == {"name": "Pasta"}

    # Should also emit message as text
    text_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
    assert any("Here is your recipe" in e.delta for e in text_events)
async def test_structured_output_with_steps():
    """Test structured output processing with steps state.

    The ``steps`` list of the structured response is declared in
    ``state_schema`` and must therefore end up in shared state unchanged.
    """
    from agent_framework.ag_ui import AgentFrameworkAgent

    async def stream_fn(
        messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        steps_data = {
            "steps": [
                {"id": "1", "description": "Step 1", "status": "pending"},
                {"id": "2", "description": "Step 2", "status": "pending"},
            ]
        }
        yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))])

    agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
    agent.chat_options = ChatOptions(response_format=StepsOutput)
    wrapper = AgentFrameworkAgent(
        agent=agent,
        state_schema={"steps": {"type": "array"}},
    )

    input_data = {"messages": [{"role": "user", "content": "Do steps"}]}
    events: list[Any] = []
    async for event in wrapper.run_agent(input_data):
        events.append(event)

    # Should emit StateSnapshotEvent with steps
    snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
    assert len(snapshot_events) >= 1

    # An initial schema-seeded snapshot may also contain a "steps" key, so assert
    # on the most recent one, which reflects the structured response.
    steps_snapshots = [e for e in snapshot_events if "steps" in e.snapshot]
    assert len(steps_snapshots) >= 1
    final_steps = steps_snapshots[-1].snapshot["steps"]
    assert len(final_steps) == 2
    assert final_steps[0]["id"] == "1"
async def test_structured_output_with_no_schema_match():
    """Test structured output when response fields don't match state_schema keys.

    The response only contains ``data``, but the schema only declares
    ``result``. A snapshot is still expected (from schema initialization), but
    the unmatched ``data`` field must not be written into state.
    """
    from agent_framework.ag_ui import AgentFrameworkAgent

    updates = [
        ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')]),
    ]
    agent = ChatAgent(
        name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_from_updates(updates))
    )
    agent.chat_options = ChatOptions(response_format=GenericOutput)
    wrapper = AgentFrameworkAgent(
        agent=agent,
        state_schema={"result": {"type": "object"}},  # Schema expects "result", not "data"
    )

    input_data = {"messages": [{"role": "user", "content": "Generate data"}]}
    events: list[Any] = []
    async for event in wrapper.run_agent(input_data):
        events.append(event)

    snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
    # Initial state snapshot from state_schema initialization
    assert len(snapshot_events) >= 1
    # No schema field matches, so the response's "data" field must never reach state.
    assert all("data" not in e.snapshot for e in snapshot_events)
async def test_structured_output_without_schema():
    """Test structured output without state_schema treats all fields as state.

    Both ``data`` and ``info`` from the structured response are expected in
    the resulting state snapshot.
    """
    from agent_framework.ag_ui import AgentFrameworkAgent

    class DataOutput(BaseModel):
        """Output with data and info fields."""

        data: dict[str, Any]
        info: str

    async def stream_fn(
        messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')])

    agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
    agent.chat_options = ChatOptions(response_format=DataOutput)
    wrapper = AgentFrameworkAgent(
        agent=agent,
        # No state_schema - all non-message fields treated as state
    )

    input_data = {"messages": [{"role": "user", "content": "Generate data"}]}
    events: list[Any] = []
    async for event in wrapper.run_agent(input_data):
        events.append(event)

    # Should emit StateSnapshotEvent with both data and info fields
    snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
    assert len(snapshot_events) >= 1
    # Inspect the final snapshot so the test does not depend on whether any
    # earlier snapshot was emitted before the structured output was applied.
    final_state = snapshot_events[-1].snapshot
    assert "data" in final_state
    assert "info" in final_state
    assert final_state["data"] == {"key": "value"}
    assert final_state["info"] == "processed"
async def test_no_structured_output_when_no_response_format():
    """Plain text flows through untouched when the agent has no response_format configured."""
    from agent_framework.ag_ui import AgentFrameworkAgent

    stub_client = StreamingChatClientStub(
        stream_from_updates([ChatResponseUpdate(contents=[TextContent(text="Regular text")])])
    )
    # Deliberately leave agent.chat_options.response_format unset.
    agent = ChatAgent(name="test", instructions="Test", chat_client=stub_client)
    wrapper = AgentFrameworkAgent(agent=agent)

    request = {"messages": [{"role": "user", "content": "Hi"}]}
    collected: list[Any] = [event async for event in wrapper.run_agent(request)]

    # The raw text must surface as ordinary TEXT_MESSAGE_CONTENT deltas.
    deltas = [event.delta for event in collected if event.type == "TEXT_MESSAGE_CONTENT"]
    assert deltas
    assert deltas[0] == "Regular text"
async def test_structured_output_with_message_field():
    """The ``message`` field of a structured response is streamed to the client as chat text."""
    from agent_framework.ag_ui import AgentFrameworkAgent

    async def stream_fn(
        messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        payload = json.dumps({"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"})
        yield ChatResponseUpdate(contents=[TextContent(text=payload)])

    agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
    agent.chat_options = ChatOptions(response_format=RecipeOutput)
    wrapper = AgentFrameworkAgent(agent=agent, state_schema={"recipe": {"type": "object"}})

    events: list[Any] = []
    async for event in wrapper.run_agent({"messages": [{"role": "user", "content": "Make salad"}]}):
        events.append(event)

    def of_type(kind: str) -> list[Any]:
        # Filter the captured events down to one AG-UI event type.
        return [e for e in events if e.type == kind]

    # The human-readable message is emitted as text content...
    assert any("Fresh salad recipe ready" in e.delta for e in of_type("TEXT_MESSAGE_CONTENT"))
    # ...and framed by at least one start/end pair.
    assert of_type("TEXT_MESSAGE_START")
    assert of_type("TEXT_MESSAGE_END")
async def test_empty_updates_no_structured_processing():
    """Test that empty updates don't trigger structured output processing.

    The stub chat client streams no updates at all. Even with a
    ``response_format`` configured, the wrapper should emit only the run
    lifecycle events: no text events and no state snapshots.
    """
    from agent_framework.ag_ui import AgentFrameworkAgent

    # Use the shared helper with an empty list (as other tests do with non-empty
    # lists) instead of a hand-written `if False: yield` generator.
    agent = ChatAgent(
        name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_from_updates([]))
    )
    agent.chat_options = ChatOptions(response_format=RecipeOutput)
    wrapper = AgentFrameworkAgent(agent=agent)

    input_data = {"messages": [{"role": "user", "content": "Test"}]}
    events: list[Any] = []
    async for event in wrapper.run_agent(input_data):
        events.append(event)

    # Should only have start and end events, in that order. Checking the types,
    # not just the count, catches a stray snapshot/text event that replaced one of them.
    assert [e.type for e in events] == ["RUN_STARTED", "RUN_FINISHED"]