Python: Refactor ag-ui to clean up some patterns (#2363)

* Refactor ag-ui to clean up some patterns

* Mypy fixes

* Fix imports, typing, tests, logging.

* Fix test import error

* Fix imports again

* Fix thread handling
This commit is contained in:
Evan Mattson
2025-11-27 11:13:03 +09:00
committed by GitHub
Unverified
parent 6c624319db
commit 8cf8b0f995
26 changed files with 1887 additions and 1415 deletions
@@ -3,7 +3,7 @@
"""AgentFrameworkAgent wrapper for AG-UI protocol - Clean Architecture."""
from collections.abc import AsyncGenerator
from typing import Any
from typing import Any, cast
from ag_ui.core import BaseEvent
from agent_framework import AgentProtocol
@@ -22,21 +22,48 @@ class AgentConfig:
def __init__(
self,
state_schema: dict[str, Any] | None = None,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
require_confirmation: bool = True,
):
"""Initialize agent configuration.
Args:
state_schema: Optional state schema for state management
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
predict_state_config: Configuration for predictive state updates
require_confirmation: Whether predictive updates require confirmation
"""
self.state_schema = state_schema or {}
self.state_schema = self._normalize_state_schema(state_schema)
self.predict_state_config = predict_state_config or {}
self.require_confirmation = require_confirmation
@staticmethod
def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]:
"""Accept dict or Pydantic model/class and return a properties dict."""
if state_schema is None:
return {}
if isinstance(state_schema, dict):
return cast(dict[str, Any], state_schema)
base_model_type: type[Any] | None
try:
from pydantic import BaseModel as ImportedBaseModel
base_model_type = ImportedBaseModel
except Exception: # pragma: no cover
base_model_type = None
if base_model_type is not None and isinstance(state_schema, base_model_type):
schema_dict = state_schema.__class__.model_json_schema()
return schema_dict.get("properties", {}) or {}
if base_model_type is not None and isinstance(state_schema, type) and issubclass(state_schema, base_model_type):
schema_dict = state_schema.model_json_schema()
return schema_dict.get("properties", {}) or {}
return {}
class AgentFrameworkAgent:
"""Wraps Agent Framework agents for AG-UI protocol compatibility.
@@ -55,7 +82,7 @@ class AgentFrameworkAgent:
agent: AgentProtocol,
name: str | None = None,
description: str | None = None,
state_schema: dict[str, Any] | None = None,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
require_confirmation: bool = True,
orchestrators: list[Orchestrator] | None = None,
@@ -67,7 +94,7 @@ class AgentFrameworkAgent:
agent: The Agent Framework agent to wrap
name: Optional name for the agent
description: Optional description
state_schema: Optional state schema for state management
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
predict_state_config: Configuration for predictive state updates.
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
require_confirmation: Whether predictive updates require confirmation.
@@ -2,6 +2,7 @@
"""FastAPI endpoint creation for AG-UI agents."""
import copy
import logging
from typing import Any
@@ -19,9 +20,10 @@ def add_agent_framework_fastapi_endpoint(
app: FastAPI,
agent: AgentProtocol | AgentFrameworkAgent,
path: str = "/",
state_schema: dict[str, Any] | None = None,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
allow_origins: list[str] | None = None,
default_state: dict[str, Any] | None = None,
) -> None:
"""Add an AG-UI endpoint to a FastAPI app.
@@ -29,10 +31,11 @@ def add_agent_framework_fastapi_endpoint(
app: The FastAPI application
agent: The agent to expose (can be raw AgentProtocol or wrapped)
path: The endpoint path
state_schema: Optional state schema for shared state management
state_schema: Optional state schema for shared state management; accepts dict or Pydantic model/class
predict_state_config: Optional predictive state update configuration.
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
allow_origins: CORS origins (not yet implemented)
default_state: Optional initial state to seed when the client does not provide state keys
"""
if isinstance(agent, AgentProtocol):
wrapped_agent = AgentFrameworkAgent(
@@ -52,6 +55,11 @@ def add_agent_framework_fastapi_endpoint(
"""
try:
input_data = await request.json()
if default_state:
state = input_data.setdefault("state", {})
for key, value in default_state.items():
if key not in state:
state[key] = copy.deepcopy(value)
logger.debug(
f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, "
f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, "
File diff suppressed because it is too large Load Diff
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
@@ -0,0 +1,176 @@
# Copyright (c) Microsoft. All rights reserved.
"""Message hygiene utilities for orchestrators."""
import json
import logging
from typing import Any
from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent
logger = logging.getLogger(__name__)
def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
"""Normalize tool ordering and inject synthetic results for AG-UI edge cases."""
sanitized: list[ChatMessage] = []
pending_tool_call_ids: set[str] | None = None
pending_confirm_changes_id: str | None = None
for msg in messages:
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
if role_value == "assistant":
tool_ids = {
str(content.call_id)
for content in msg.contents or []
if isinstance(content, FunctionCallContent) and content.call_id
}
confirm_changes_call = None
for content in msg.contents or []:
if isinstance(content, FunctionCallContent) and content.name == "confirm_changes":
confirm_changes_call = content
break
sanitized.append(msg)
pending_tool_call_ids = tool_ids if tool_ids else None
pending_confirm_changes_id = (
str(confirm_changes_call.call_id) if confirm_changes_call and confirm_changes_call.call_id else None
)
continue
if role_value == "user":
if pending_confirm_changes_id:
user_text = ""
for content in msg.contents or []:
if isinstance(content, TextContent):
user_text = content.text
break
try:
parsed = json.loads(user_text)
if "accepted" in parsed:
logger.info(
f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}"
)
synthetic_result = ChatMessage(
role="tool",
contents=[
FunctionResultContent(
call_id=pending_confirm_changes_id,
result="Confirmed" if parsed.get("accepted") else "Rejected",
)
],
)
sanitized.append(synthetic_result)
if pending_tool_call_ids:
pending_tool_call_ids.discard(pending_confirm_changes_id)
pending_confirm_changes_id = None
continue
except (json.JSONDecodeError, KeyError) as exc:
logger.debug("Could not parse user message as confirm_changes response: %s", type(exc).__name__)
if pending_tool_call_ids:
logger.info(
f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - injecting synthetic results"
)
for pending_call_id in pending_tool_call_ids:
logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}")
synthetic_result = ChatMessage(
role="tool",
contents=[
FunctionResultContent(
call_id=pending_call_id,
result="Tool execution skipped - user provided follow-up message",
)
],
)
sanitized.append(synthetic_result)
pending_tool_call_ids = None
pending_confirm_changes_id = None
sanitized.append(msg)
pending_confirm_changes_id = None
continue
if role_value == "tool":
if not pending_tool_call_ids:
continue
keep = False
for content in msg.contents or []:
if isinstance(content, FunctionResultContent):
call_id = str(content.call_id)
if call_id in pending_tool_call_ids:
keep = True
if call_id == pending_confirm_changes_id:
pending_confirm_changes_id = None
break
if keep:
sanitized.append(msg)
continue
sanitized.append(msg)
pending_tool_call_ids = None
pending_confirm_changes_id = None
return sanitized
def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
"""Remove duplicate messages while preserving order."""
seen_keys: dict[Any, int] = {}
unique_messages: list[ChatMessage] = []
for idx, msg in enumerate(messages):
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent):
call_id = str(msg.contents[0].call_id)
key: Any = (role_value, call_id)
if key in seen_keys:
existing_idx = seen_keys[key]
existing_msg = unique_messages[existing_idx]
existing_result = None
if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent):
existing_result = existing_msg.contents[0].result
new_result = msg.contents[0].result
if (not existing_result or existing_result == "") and new_result:
logger.info(f"Replacing empty tool result at index {existing_idx} with data from index {idx}")
unique_messages[existing_idx] = msg
else:
logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}")
continue
seen_keys[key] = len(unique_messages)
unique_messages.append(msg)
elif (
role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents)
):
tool_call_ids = tuple(
sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id)
)
key = (role_value, tool_call_ids)
if key in seen_keys:
logger.info(f"Skipping duplicate assistant tool call at index {idx}")
continue
seen_keys[key] = len(unique_messages)
unique_messages.append(msg)
else:
content_str = str([str(c) for c in msg.contents]) if msg.contents else ""
key = (role_value, hash(content_str))
if key in seen_keys:
logger.info(f"Skipping duplicate message at index {idx}: role={role_value}")
continue
seen_keys[key] = len(unique_messages)
unique_messages.append(msg)
return unique_messages
@@ -0,0 +1,102 @@
# Copyright (c) Microsoft. All rights reserved.
"""State orchestration utilities."""
import json
from typing import Any
from ag_ui.core import CustomEvent, EventType
from agent_framework import ChatMessage, TextContent
class StateManager:
"""Coordinates state defaults, snapshots, and structured updates."""
def __init__(
self,
state_schema: dict[str, Any] | None,
predict_state_config: dict[str, dict[str, str]] | None,
require_confirmation: bool,
) -> None:
self.state_schema = state_schema or {}
self.predict_state_config = predict_state_config or {}
self.require_confirmation = require_confirmation
self.current_state: dict[str, Any] = {}
def initialize(self, initial_state: dict[str, Any] | None) -> dict[str, Any]:
"""Initialize state with schema defaults."""
self.current_state = (initial_state or {}).copy()
self._apply_schema_defaults()
return self.current_state
def predict_state_event(self) -> CustomEvent | None:
"""Create predict-state custom event when configured."""
if not self.predict_state_config:
return None
predict_state_value = [
{
"state_key": state_key,
"tool": config["tool"],
"tool_argument": config["tool_argument"],
}
for state_key, config in self.predict_state_config.items()
]
return CustomEvent(
type=EventType.CUSTOM,
name="PredictState",
value=predict_state_value,
)
def initial_snapshot_event(self, event_bridge: Any) -> Any:
"""Emit initial snapshot when schema and state present."""
if not self.state_schema:
return None
self._apply_schema_defaults()
return event_bridge.create_state_snapshot_event(self.current_state)
def state_context_message(self, is_new_user_turn: bool, conversation_has_tool_calls: bool) -> ChatMessage | None:
"""Inject state context only when starting a new user turn."""
if not self.current_state or not self.state_schema:
return None
if not is_new_user_turn or conversation_has_tool_calls:
return None
state_json = json.dumps(self.current_state, indent=2)
return ChatMessage(
role="system",
contents=[
TextContent(
text=(
"Current state of the application:\n"
f"{state_json}\n\n"
"When modifying state, you MUST include ALL existing data plus your changes.\n"
"For example, if adding one new item to a list, include ALL existing items PLUS the one new item.\n"
"Never replace existing data - always preserve and append or merge."
)
)
],
)
def extract_state_updates(self, response_dict: dict[str, Any]) -> dict[str, Any]:
"""Extract state updates from structured response payloads."""
if self.state_schema:
return {key: response_dict[key] for key in self.state_schema.keys() if key in response_dict}
return {k: v for k, v in response_dict.items() if k != "message"}
def apply_state_updates(self, updates: dict[str, Any]) -> None:
"""Merge state updates into current state."""
if not updates:
return
self.current_state.update(updates)
def _apply_schema_defaults(self) -> None:
"""Fill missing state fields based on schema hints."""
for key, schema in self.state_schema.items():
if key in self.current_state:
continue
if isinstance(schema, dict) and schema.get("type") == "array": # type: ignore
self.current_state[key] = []
else:
self.current_state[key] = {}
@@ -0,0 +1,80 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tool handling helpers."""
import logging
from typing import Any
from agent_framework import BaseChatClient, ChatAgent
logger = logging.getLogger(__name__)
def collect_server_tools(agent: Any) -> list[Any]:
"""Collect server tools from ChatAgent or duck-typed agent."""
if isinstance(agent, ChatAgent):
tools_from_agent = agent.chat_options.tools
server_tools = list(tools_from_agent) if tools_from_agent else []
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
for tool in server_tools:
tool_name = getattr(tool, "name", "unknown")
approval_mode = getattr(tool, "approval_mode", None)
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
return server_tools
try:
chat_options_attr = getattr(agent, "chat_options", None)
if chat_options_attr is not None:
return getattr(chat_options_attr, "tools", None) or []
except AttributeError:
return []
return []
def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) -> None:
"""Register client tools as additional declaration-only tools to avoid server execution."""
if not client_tools:
return
if isinstance(agent, ChatAgent):
chat_client = agent.chat_client
if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None:
chat_client.function_invocation_configuration.additional_tools = client_tools
logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)")
return
try:
chat_client_attr = getattr(agent, "chat_client", None)
if chat_client_attr is not None:
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
if fic is not None:
fic.additional_tools = client_tools # type: ignore[attr-defined]
logger.debug(
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
)
except AttributeError:
return
def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None:
"""Combine server and client tools without overriding server metadata."""
if not client_tools:
logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)")
return None
server_tool_names = {getattr(tool, "name", None) for tool in server_tools}
unique_client_tools = [tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names]
if not unique_client_tools:
logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter")
return None
combined_tools: list[Any] = []
if server_tools:
combined_tools.extend(server_tools)
combined_tools.extend(unique_client_tools)
logger.info(
f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools "
f"({len(server_tools)} server + {len(unique_client_tools)} unique client)"
)
return combined_tools
@@ -21,7 +21,6 @@ from agent_framework import (
AgentProtocol,
AgentThread,
ChatAgent,
ChatMessage,
FunctionCallContent,
FunctionResultContent,
TextContent,
@@ -271,144 +270,29 @@ class DefaultOrchestrator(Orchestrator):
AG-UI events
"""
from ._events import AgentFrameworkEventBridge
from ._message_adapters import agui_messages_to_snapshot_format
from ._orchestration._message_hygiene import deduplicate_messages, sanitize_tool_history
from ._orchestration._state_manager import StateManager
from ._orchestration._tooling import (
collect_server_tools,
merge_tools,
register_additional_client_tools,
)
logger.info(f"Starting default agent run for thread_id={context.thread_id}, run_id={context.run_id}")
# Initialize state tracking
initial_state = context.input_data.get("state", {})
current_state: dict[str, Any] = initial_state.copy() if initial_state else {}
# Check if agent uses structured outputs (response_format)
# Use isinstance to narrow type for proper attribute access
response_format = None
if isinstance(context.agent, ChatAgent):
response_format = context.agent.chat_options.response_format
skip_text_content = response_format is not None
# Sanitizer: ensure tool results only follow assistant tool calls
# Also inject synthetic tool results for confirm_changes
def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
sanitized: list[ChatMessage] = []
pending_tool_call_ids: set[str] | None = None
pending_confirm_changes_id: str | None = None
state_manager = StateManager(
state_schema=context.config.state_schema,
predict_state_config=context.config.predict_state_config,
require_confirmation=context.config.require_confirmation,
)
current_state = state_manager.initialize(context.input_data.get("state", {}))
for msg in messages:
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
if role_value == "assistant":
tool_ids = {
str(content.call_id)
for content in msg.contents or []
if isinstance(content, FunctionCallContent) and content.call_id
}
# Check for confirm_changes tool call
confirm_changes_call = None
for content in msg.contents or []:
if isinstance(content, FunctionCallContent) and content.name == "confirm_changes":
confirm_changes_call = content
break
sanitized.append(msg)
pending_tool_call_ids = tool_ids if tool_ids else None
pending_confirm_changes_id = (
str(confirm_changes_call.call_id)
if confirm_changes_call and confirm_changes_call.call_id
else None
)
continue
if role_value == "user":
# Check if this user message is a confirm_changes response (JSON with "accepted" field)
# This must be checked BEFORE injecting synthetic results for pending tool calls
if pending_confirm_changes_id:
user_text = ""
for content in msg.contents or []:
if isinstance(content, TextContent):
user_text = content.text
break
try:
parsed = json.loads(user_text)
if "accepted" in parsed:
# This is a confirm_changes response - inject synthetic tool result
logger.info(
f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}"
)
synthetic_result = ChatMessage(
role="tool",
contents=[
FunctionResultContent(
call_id=pending_confirm_changes_id,
result="Confirmed" if parsed.get("accepted") else "Rejected",
)
],
)
sanitized.append(synthetic_result)
if pending_tool_call_ids:
pending_tool_call_ids.discard(pending_confirm_changes_id)
pending_confirm_changes_id = None
# Don't add the user message to sanitized - it's been converted to tool result
continue
except (json.JSONDecodeError, KeyError) as e:
# Failed to parse user message as confirm_changes response; continue normal processing
logger.debug(f"Could not parse user message as confirm_changes response: {e}")
# Before processing user message, check if there are pending tool calls without results
# This happens when assistant made multiple tool calls but only some got results
# This is checked AFTER confirm_changes special handling above
if pending_tool_call_ids:
logger.info(
f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - injecting synthetic results"
)
for pending_call_id in pending_tool_call_ids:
logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}")
synthetic_result = ChatMessage(
role="tool",
contents=[
FunctionResultContent(
call_id=pending_call_id,
result="Tool execution skipped - user provided follow-up message",
)
],
)
sanitized.append(synthetic_result)
pending_tool_call_ids = None
pending_confirm_changes_id = None
# Normal user message processing
sanitized.append(msg)
pending_confirm_changes_id = None
continue
if role_value == "tool":
if not pending_tool_call_ids:
continue
keep = False
for content in msg.contents or []:
if isinstance(content, FunctionResultContent):
call_id = str(content.call_id)
if call_id in pending_tool_call_ids:
keep = True
# Note: We do NOT remove call_id from pending here.
# This allows duplicate tool results to pass through sanitization
# so the deduplicator can choose the best one (prefer non-empty results).
# We only clear pending_tool_call_ids when a user message arrives.
if call_id == pending_confirm_changes_id:
# For confirm_changes specifically, we do want to clear it
# since we only expect one response
pending_confirm_changes_id = None
break
if keep:
sanitized.append(msg)
continue
sanitized.append(msg)
pending_tool_call_ids = None
pending_confirm_changes_id = None
return sanitized
# Create event bridge
event_bridge = AgentFrameworkEventBridge(
run_id=context.run_id,
thread_id=context.thread_id,
@@ -421,42 +305,19 @@ class DefaultOrchestrator(Orchestrator):
yield event_bridge.create_run_started_event()
# Emit PredictState custom event if we have predictive state config
if context.config.predict_state_config:
from ag_ui.core import CustomEvent, EventType
predict_event = state_manager.predict_state_event()
if predict_event:
yield predict_event
predict_state_value = [
{
"state_key": state_key,
"tool": config["tool"],
"tool_argument": config["tool_argument"],
}
for state_key, config in context.config.predict_state_config.items()
]
snapshot_event = state_manager.initial_snapshot_event(event_bridge)
if snapshot_event:
yield snapshot_event
yield CustomEvent(
type=EventType.CUSTOM,
name="PredictState",
value=predict_state_value,
)
# If we have a state schema, ensure we emit initial state snapshot
if context.config.state_schema:
# Initialize missing state fields with appropriate empty values based on schema type
for key, schema in context.config.state_schema.items():
if key not in current_state:
# Default to empty object; use empty array if schema specifies "array" type
current_state[key] = [] if isinstance(schema, dict) and schema.get("type") == "array" else {} # type: ignore
yield event_bridge.create_state_snapshot_event(current_state)
# Create thread for context tracking
thread = AgentThread()
thread.metadata = { # type: ignore[attr-defined]
"ag_ui_thread_id": context.thread_id,
"ag_ui_run_id": context.run_id,
}
# Inject current state into thread metadata so agent can access it
if current_state:
thread.metadata["current_state"] = current_state # type: ignore[attr-defined]
@@ -475,90 +336,24 @@ class DefaultOrchestrator(Orchestrator):
for j, content in enumerate(msg.contents):
content_type = type(content).__name__
if isinstance(content, TextContent):
logger.debug(f" Content {j}: {content_type} - {content.text}")
logger.debug(" Content %s: %s - text_length=%s", j, content_type, len(content.text))
elif isinstance(content, FunctionCallContent):
logger.debug(f" Content {j}: {content_type} - {content.name}({content.arguments})")
elif isinstance(content, FunctionResultContent):
arg_length = len(str(content.arguments)) if content.arguments else 0
logger.debug(
f" Content {j}: {content_type} - call_id={content.call_id}, result={content.result}"
" Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length
)
elif isinstance(content, FunctionResultContent):
result_preview = type(content.result).__name__ if content.result is not None else "None"
logger.debug(
" Content %s: %s - call_id=%s, result_type=%s",
j,
content_type,
content.call_id,
result_preview,
)
else:
logger.debug(f" Content {j}: {content_type} - {content}")
logger.debug(f" Content {j}: {content_type}")
# After getting sanitized_messages, deduplicate them
def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
"""Remove duplicate messages while preserving order.
For tool results with the same call_id, prefer the one with actual data.
"""
seen_keys: dict[Any, int] = {} # key -> index in unique_messages (key can be various tuple types)
unique_messages: list[ChatMessage] = []
for idx, msg in enumerate(messages):
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
# For tool messages, use call_id as unique key
if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent):
call_id = str(msg.contents[0].call_id)
key: Any = (role_value, call_id)
# Check if we already have this tool result
if key in seen_keys:
existing_idx = seen_keys[key]
existing_msg = unique_messages[existing_idx]
# Compare results - prefer non-empty over empty
existing_result = None
if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent):
existing_result = existing_msg.contents[0].result
new_result = msg.contents[0].result
# Replace if existing is empty/None and new has data
if (not existing_result or existing_result == "") and new_result:
logger.info(
f"Replacing empty tool result at index {existing_idx} with data from index {idx}"
)
unique_messages[existing_idx] = msg
else:
logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}")
continue
seen_keys[key] = len(unique_messages)
unique_messages.append(msg)
elif (
role_value == "assistant"
and msg.contents
and any(isinstance(c, FunctionCallContent) for c in msg.contents)
):
# For assistant messages with tool_calls, use the tool call IDs
tool_call_ids = tuple(
sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id)
)
key = (role_value, tool_call_ids)
if key in seen_keys:
logger.info(f"Skipping duplicate assistant tool call at index {idx}")
continue
seen_keys[key] = len(unique_messages)
unique_messages.append(msg)
else:
# For other messages (system, user, assistant without tools), hash the content
content_str = str([str(c) for c in msg.contents]) if msg.contents else ""
key = (role_value, hash(content_str))
if key in seen_keys:
logger.info(f"Skipping duplicate message at index {idx}: role={role_value}")
continue
seen_keys[key] = len(unique_messages)
unique_messages.append(msg)
return unique_messages
# Then use it:
sanitized_messages = sanitize_tool_history(raw_messages)
provider_messages = deduplicate_messages(sanitized_messages)
@@ -575,66 +370,45 @@ class DefaultOrchestrator(Orchestrator):
for j, content in enumerate(msg.contents):
content_type = type(content).__name__
if isinstance(content, TextContent):
logger.info(f" Content {j}: {content_type} - {content.text}")
logger.info(f" Content {j}: {content_type} - text_length={len(content.text)}")
elif isinstance(content, FunctionCallContent):
logger.info(f" Content {j}: {content_type} - {content.name}({content.arguments})")
arg_length = len(str(content.arguments)) if content.arguments else 0
logger.info(" Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length)
elif isinstance(content, FunctionResultContent):
result_preview = type(content.result).__name__ if content.result is not None else "None"
logger.info(
f" Content {j}: {content_type} - call_id={content.call_id}, result={content.result}"
" Content %s: %s - call_id=%s, result_type=%s",
j,
content_type,
content.call_id,
result_preview,
)
else:
logger.info(f" Content {j}: {content_type} - {content}")
logger.info(f" Content {j}: {content_type}")
# NOTE: For AG-UI, the client sends the full conversation history on each request.
# We should NOT add to thread.on_new_messages() as that would cause duplication.
# Instead, we pass messages directly to the agent via messages_to_run.
# Inject current state as system message context if we have state and this is a new user turn
messages_to_run: list[Any] = []
# Check if the last message is from the user (new turn) vs assistant/tool (mid-execution)
is_new_user_turn = False
if provider_messages:
last_msg = provider_messages[-1]
is_new_user_turn = last_msg.role.value == "user"
role_value = last_msg.role.value if hasattr(last_msg.role, "value") else str(last_msg.role)
is_new_user_turn = role_value == "user"
# Check if conversation has tool calls (indicates mid-execution)
conversation_has_tool_calls = False
for msg in provider_messages:
if msg.role.value == "assistant" and hasattr(msg, "contents") and msg.contents:
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
if role_value == "assistant" and hasattr(msg, "contents") and msg.contents:
if any(isinstance(content, FunctionCallContent) for content in msg.contents):
conversation_has_tool_calls = True
break
# Only inject state context on new user turns AND when conversation doesn't have tool calls
# (tool calls indicate we're mid-execution, so state context was already injected)
if current_state and context.config.state_schema and is_new_user_turn and not conversation_has_tool_calls:
state_json = json.dumps(current_state, indent=2)
state_context_msg = ChatMessage(
role="system",
contents=[
TextContent(
text=f"""Current state of the application:
{state_json}
When modifying state, you MUST include ALL existing data plus your changes.
For example, if adding one new item to a list, include ALL existing items PLUS the one new item.
Never replace existing data - always preserve and append or merge."""
)
],
)
state_context_msg = state_manager.state_context_message(
is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls
)
if state_context_msg:
messages_to_run.append(state_context_msg)
# Add all provider messages to messages_to_run
# AG-UI sends full conversation history on each request, so we pass it directly to the agent
messages_to_run.extend(provider_messages)
# Handle client tools for hybrid execution
# Client sends tool metadata, server merges with its own tools.
# Client tools have func=None (declaration-only), so @use_function_invocation
# will return the function call without executing (passes back to client).
from agent_framework import BaseChatClient
client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools"))
logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools")
if client_tools:
@@ -643,85 +417,31 @@ class DefaultOrchestrator(Orchestrator):
declaration_only = getattr(tool, "declaration_only", None)
logger.info(f"[TOOLS] - Client tool: {tool_name}, declaration_only={declaration_only}")
# Extract server tools - use type narrowing when possible
server_tools: list[Any] = []
if isinstance(context.agent, ChatAgent):
tools_from_agent = context.agent.chat_options.tools
server_tools = list(tools_from_agent) if tools_from_agent else []
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
for tool in server_tools:
tool_name = getattr(tool, "name", "unknown")
approval_mode = getattr(tool, "approval_mode", None)
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
else:
# AgentProtocol allows duck-typed implementations - fallback to attribute access
# This supports test mocks and custom agent implementations
try:
chat_options_attr = getattr(context.agent, "chat_options", None)
if chat_options_attr is not None:
server_tools = getattr(chat_options_attr, "tools", None) or []
except AttributeError:
pass
server_tools = collect_server_tools(context.agent)
register_additional_client_tools(context.agent, client_tools)
tools_param = merge_tools(server_tools, client_tools)
# Register client tools as additional (declaration-only) so they are not executed on server
if client_tools:
if isinstance(context.agent, ChatAgent):
# Type-safe path for ChatAgent
chat_client = context.agent.chat_client
if (
isinstance(chat_client, BaseChatClient)
and chat_client.function_invocation_configuration is not None
):
chat_client.function_invocation_configuration.additional_tools = client_tools
logger.debug(
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
)
else:
# Fallback for AgentProtocol implementations (test mocks, custom agents)
try:
chat_client_attr = getattr(context.agent, "chat_client", None)
if chat_client_attr is not None:
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
if fic is not None:
fic.additional_tools = client_tools # type: ignore[attr-defined]
logger.debug(
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
)
except AttributeError:
pass
# For tools parameter: only pass if we have client tools to add
# If we pass tools=, it overrides the agent's configured tools and loses metadata like approval_mode
# So only pass tools when we need to add client tools on top of server tools
# IMPORTANT: Don't include client tools that duplicate server tools (same name)
tools_param = None
if client_tools:
# Get server tool names
server_tool_names = {getattr(tool, "name", None) for tool in server_tools}
# Filter out client tools that duplicate server tools
unique_client_tools = [
tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names
]
if unique_client_tools:
combined_tools: list[Any] = []
if server_tools:
combined_tools.extend(server_tools)
combined_tools.extend(unique_client_tools)
tools_param = combined_tools
logger.info(
f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools ({len(server_tools)} server + {len(unique_client_tools)} unique client)"
)
else:
logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter")
else:
logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)")
# Collect all updates to get the final structured output
all_updates: list[Any] = []
update_count = 0
async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=tools_param):
# Prepare metadata for chat client (Azure requires string values)
safe_metadata: dict[str, Any] = {}
thread_metadata = getattr(thread, "metadata", None)
if thread_metadata:
for key, value in thread_metadata.items():
value_str = value if isinstance(value, str) else json.dumps(value)
if len(value_str) > 512:
value_str = value_str[:512]
safe_metadata[key] = value_str
run_kwargs: dict[str, Any] = {
"thread": thread,
"tools": tools_param,
"metadata": safe_metadata,
}
if safe_metadata:
run_kwargs["store"] = True
async for update in context.agent.run_stream(messages_to_run, **run_kwargs):
update_count += 1
logger.info(f"[STREAM] Received update #{update_count} from agent")
all_updates.append(update)
@@ -733,23 +453,19 @@ class DefaultOrchestrator(Orchestrator):
logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}")
# After agent completes, check if we should stop (waiting for user to confirm changes)
if event_bridge.should_stop_after_confirm:
logger.info("Stopping run after confirm_changes - waiting for user response")
yield event_bridge.create_run_finished_event()
return
# Check if there are pending tool calls (declaration-only tools that weren't executed)
# These need ToolCallEndEvent to signal the client to execute them
# Only emit for tool calls that haven't already had ToolCallEndEvent emitted
# (approval-required tools already had their end event emitted)
if event_bridge.pending_tool_calls:
pending_without_end = [
tc for tc in event_bridge.pending_tool_calls if tc.get("id") not in event_bridge.tool_calls_ended
]
if pending_without_end:
logger.info(
f"Found {len(pending_without_end)} pending tool calls without end event - emitting ToolCallEndEvent"
"Found %s pending tool calls without end event - emitting ToolCallEndEvent",
len(pending_without_end),
)
for tool_call in pending_without_end:
tool_call_id = tool_call.get("id")
@@ -760,76 +476,47 @@ class DefaultOrchestrator(Orchestrator):
logger.info(f"Emitting ToolCallEndEvent for declaration-only tool call '{tool_call_id}'")
yield end_event
# After streaming completes, check if agent has response_format and extract structured output
if all_updates and response_format:
from agent_framework import AgentRunResponse
from pydantic import BaseModel
logger.info(f"Processing structured output, update count: {len(all_updates)}")
# Convert streaming updates to final response to get the structured output
final_response = AgentRunResponse.from_agent_run_response_updates(
all_updates, output_format_type=response_format
)
if final_response.value and isinstance(final_response.value, BaseModel):
# Convert Pydantic model to dict
response_dict = final_response.value.model_dump(mode="json", exclude_none=True)
logger.info(f"Received structured output: {list(response_dict.keys())}")
logger.info(f"Received structured output keys: {list(response_dict.keys())}")
# Extract state fields based on state_schema
state_updates: dict[str, Any] = {}
if context.config.state_schema:
# Use state_schema to determine which fields are state
for state_key in context.config.state_schema.keys():
if state_key in response_dict:
state_updates[state_key] = response_dict[state_key]
else:
# No schema: treat all non-message fields as state
state_updates = {k: v for k, v in response_dict.items() if k != "message"}
# Apply state updates if any found
state_updates = state_manager.extract_state_updates(response_dict)
if state_updates:
current_state.update(state_updates)
# Emit StateSnapshotEvent with the updated state
state_manager.apply_state_updates(state_updates)
state_snapshot = event_bridge.create_state_snapshot_event(current_state)
yield state_snapshot
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")
# If there's a message field, emit it as chat text
if "message" in response_dict and response_dict["message"]:
message_id = generate_event_id()
yield TextMessageStartEvent(message_id=message_id, role="assistant")
yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"])
yield TextMessageEndEvent(message_id=message_id)
logger.info(f"Emitted conversational message: {response_dict['message'][:100]}...")
logger.info(f"Emitted conversational message with length={len(response_dict['message'])}")
logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}")
if event_bridge.current_message_id:
logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}")
yield event_bridge.create_message_end_event(event_bridge.current_message_id)
# Emit MessagesSnapshotEvent to persist the final assistant text message
from ._message_adapters import agui_messages_to_snapshot_format
# Build the final assistant message with accumulated text content
assistant_text_message = {
"id": event_bridge.current_message_id,
"role": "assistant",
"content": event_bridge.accumulated_text_content,
}
# Convert input messages to snapshot format (normalize content structure)
# event_bridge.input_messages are already in AG-UI format, just need normalization
converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages)
# Build complete messages array
# Include: input messages + any pending tool calls/results + final text message
all_messages = converted_input_messages.copy()
# Add assistant message with tool calls if any
if event_bridge.pending_tool_calls:
tool_call_message = {
"id": generate_event_id(),
@@ -838,18 +525,16 @@ class DefaultOrchestrator(Orchestrator):
}
all_messages.append(tool_call_message)
# Add tool results if any
all_messages.extend(event_bridge.tool_results.copy())
# Add final text message
all_messages.append(assistant_text_message)
messages_snapshot = MessagesSnapshotEvent(
messages=all_messages, # type: ignore[arg-type]
)
logger.info(
f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages "
f"(text content length: {len(event_bridge.accumulated_text_content)})"
"[FINALIZE] Emitting MessagesSnapshotEvent with %s messages (text content length: %s)",
len(all_messages),
len(event_bridge.accumulated_text_content),
)
yield messages_snapshot
else:
+1
View File
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
+112 -54
View File
@@ -1,10 +1,61 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for AGUIChatClient."""
import json
from collections.abc import AsyncGenerator, AsyncIterable, MutableSequence
from typing import Any
from agent_framework import ChatMessage, ChatOptions, FunctionCallContent, Role, ai_function
from agent_framework import (
ChatMessage,
ChatOptions,
ChatResponseUpdate,
FunctionCallContent,
Role,
TextContent,
ai_function,
)
from agent_framework._types import ChatResponse
from pytest import MonkeyPatch
from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent
from agent_framework_ag_ui._http_service import AGUIHttpService
class TestableAGUIChatClient(AGUIChatClient):
"""Testable wrapper exposing protected helpers."""
@property
def http_service(self) -> AGUIHttpService:
"""Expose http service for monkeypatching."""
return self._http_service
def extract_state_from_messages(
self, messages: list[ChatMessage]
) -> tuple[list[ChatMessage], dict[str, Any] | None]:
"""Expose state extraction helper."""
return self._extract_state_from_messages(messages)
def convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
"""Expose message conversion helper."""
return self._convert_messages_to_agui_format(messages)
def get_thread_id(self, chat_options: ChatOptions) -> str:
"""Expose thread id helper."""
return self._get_thread_id(chat_options)
async def inner_get_streaming_response(
self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
) -> AsyncIterable[ChatResponseUpdate]:
"""Proxy to protected streaming call."""
async for update in self._inner_get_streaming_response(messages=messages, chat_options=chat_options):
yield update
async def inner_get_response(
self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
) -> ChatResponse:
"""Proxy to protected response call."""
return await self._inner_get_response(messages=messages, chat_options=chat_options)
class TestAGUIChatClient:
@@ -12,25 +63,25 @@ class TestAGUIChatClient:
async def test_client_initialization(self) -> None:
"""Test client initialization."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
assert client._http_service is not None
assert client._http_service.endpoint.startswith("http://localhost:8888")
assert client.http_service is not None
assert client.http_service.endpoint.startswith("http://localhost:8888")
async def test_client_context_manager(self) -> None:
"""Test client as async context manager."""
async with AGUIChatClient(endpoint="http://localhost:8888/") as client:
async with TestableAGUIChatClient(endpoint="http://localhost:8888/") as client:
assert client is not None
async def test_extract_state_from_messages_no_state(self) -> None:
"""Test state extraction when no state is present."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
messages = [
ChatMessage(role="user", text="Hello"),
ChatMessage(role="assistant", text="Hi there"),
]
result_messages, state = client._extract_state_from_messages(messages)
result_messages, state = client.extract_state_from_messages(messages)
assert result_messages == messages
assert state is None
@@ -39,7 +90,7 @@ class TestAGUIChatClient:
"""Test state extraction from last message."""
import base64
client = AGUIChatClient(endpoint="http://localhost:8888/")
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
state_data = {"key": "value", "count": 42}
state_json = json.dumps(state_data)
@@ -55,7 +106,7 @@ class TestAGUIChatClient:
),
]
result_messages, state = client._extract_state_from_messages(messages)
result_messages, state = client.extract_state_from_messages(messages)
assert len(result_messages) == 1
assert result_messages[0].text == "Hello"
@@ -65,7 +116,7 @@ class TestAGUIChatClient:
"""Test state extraction with invalid JSON."""
import base64
client = AGUIChatClient(endpoint="http://localhost:8888/")
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
invalid_json = "not valid json"
state_b64 = base64.b64encode(invalid_json.encode("utf-8")).decode("utf-8")
@@ -79,20 +130,20 @@ class TestAGUIChatClient:
),
]
result_messages, state = client._extract_state_from_messages(messages)
result_messages, state = client.extract_state_from_messages(messages)
assert result_messages == messages
assert state is None
async def test_convert_messages_to_agui_format(self) -> None:
"""Test message conversion to AG-UI format."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
messages = [
ChatMessage(role=Role.USER, text="What is the weather?"),
ChatMessage(role=Role.ASSISTANT, text="Let me check.", message_id="msg_123"),
]
agui_messages = client._convert_messages_to_agui_format(messages)
agui_messages = client.convert_messages_to_agui_format(messages)
assert len(agui_messages) == 2
assert agui_messages[0]["role"] == "user"
@@ -103,24 +154,24 @@ class TestAGUIChatClient:
async def test_get_thread_id_from_metadata(self) -> None:
"""Test thread ID extraction from metadata."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"})
thread_id = client._get_thread_id(chat_options)
thread_id = client.get_thread_id(chat_options)
assert thread_id == "existing_thread_123"
async def test_get_thread_id_generation(self) -> None:
"""Test automatic thread ID generation."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
chat_options = ChatOptions()
thread_id = client._get_thread_id(chat_options)
thread_id = client.get_thread_id(chat_options)
assert thread_id.startswith("thread_")
assert len(thread_id) > 7
async def test_get_streaming_response(self, monkeypatch) -> None:
async def test_get_streaming_response(self, monkeypatch: MonkeyPatch) -> None:
"""Test streaming response method."""
mock_events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
@@ -129,26 +180,32 @@ class TestAGUIChatClient:
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test message")]
chat_options = ChatOptions()
updates = []
async for update in client._inner_get_streaming_response(messages=messages, chat_options=chat_options):
updates: list[ChatResponseUpdate] = []
async for update in client.inner_get_streaming_response(messages=messages, chat_options=chat_options):
updates.append(update)
assert len(updates) == 4
assert updates[0].additional_properties is not None
assert updates[0].additional_properties["thread_id"] == "thread_1"
assert updates[1].contents[0].text == "Hello"
assert updates[2].contents[0].text == " world"
async def test_get_response_non_streaming(self, monkeypatch) -> None:
first_content = updates[1].contents[0]
second_content = updates[2].contents[0]
assert isinstance(first_content, TextContent)
assert isinstance(second_content, TextContent)
assert first_content.text == "Hello"
assert second_content.text == " world"
async def test_get_response_non_streaming(self, monkeypatch: MonkeyPatch) -> None:
"""Test non-streaming response method."""
mock_events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
@@ -156,23 +213,23 @@ class TestAGUIChatClient:
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test message")]
chat_options = ChatOptions()
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
response = await client.inner_get_response(messages=messages, chat_options=chat_options)
assert response is not None
assert len(response.messages) > 0
assert "Complete response" in response.text
async def test_tool_handling(self, monkeypatch) -> None:
async def test_tool_handling(self, monkeypatch: MonkeyPatch) -> None:
"""Test that client tool metadata is sent to server.
Client tool metadata (name, description, schema) is sent to server for planning.
@@ -191,28 +248,29 @@ class TestAGUIChatClient:
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
# Client tool metadata should be sent to server
tools = kwargs.get("tools")
tools: list[dict[str, Any]] | None = kwargs.get("tools")
assert tools is not None
assert len(tools) == 1
assert tools[0]["name"] == "test_tool"
assert tools[0]["description"] == "Test tool."
assert "parameters" in tools[0]
tool_entry = tools[0]
assert tool_entry["name"] == "test_tool"
assert tool_entry["description"] == "Test tool."
assert "parameters" in tool_entry
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test with tools")]
chat_options = ChatOptions(tools=[test_tool])
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
response = await client.inner_get_response(messages=messages, chat_options=chat_options)
assert response is not None
async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch) -> None:
async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch: MonkeyPatch) -> None:
"""Ensure server-side tool calls are exposed as FunctionCallContent after processing."""
mock_events = [
@@ -222,17 +280,17 @@ class TestAGUIChatClient:
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test server tool execution")]
chat_options = ChatOptions()
updates = []
updates: list[ChatResponseUpdate] = []
async for update in client.get_streaming_response(messages, chat_options=chat_options):
updates.append(update)
@@ -245,7 +303,7 @@ class TestAGUIChatClient:
isinstance(content, ServerFunctionCallContent) for update in updates for content in update.contents
)
async def test_server_tool_calls_not_executed_locally(self, monkeypatch) -> None:
async def test_server_tool_calls_not_executed_locally(self, monkeypatch: MonkeyPatch) -> None:
"""Server tools should not trigger local function invocation even when client tools exist."""
@ai_function
@@ -260,18 +318,18 @@ class TestAGUIChatClient:
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
for event in mock_events:
yield event
async def fake_auto_invoke(*args, **kwargs):
async def fake_auto_invoke(*args: object, **kwargs: Any) -> None:
function_call = kwargs.get("function_call_content") or args[0]
raise AssertionError(f"Unexpected local execution of server tool: {getattr(function_call, 'name', '?')}")
monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke)
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test server tool execution")]
chat_options = ChatOptions(tool_choice="auto", tools=[client_tool])
@@ -279,7 +337,7 @@ class TestAGUIChatClient:
async for _ in client.get_streaming_response(messages, chat_options=chat_options):
pass
async def test_state_transmission(self, monkeypatch) -> None:
async def test_state_transmission(self, monkeypatch: MonkeyPatch) -> None:
"""Test state is properly transmitted to server."""
import base64
@@ -302,16 +360,16 @@ class TestAGUIChatClient:
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
assert kwargs.get("state") == state_data
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
chat_options = ChatOptions()
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
response = await client.inner_get_response(messages=messages, chat_options=chat_options)
assert response is not None
@@ -3,21 +3,30 @@
"""Comprehensive tests for AgentFrameworkAgent (_agent.py)."""
import json
import sys
from collections.abc import AsyncIterator, MutableSequence
from pathlib import Path
from typing import Any
import pytest
from agent_framework import ChatAgent, TextContent
from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent
from agent_framework._types import ChatResponseUpdate
from pydantic import BaseModel
sys.path.insert(0, str(Path(__file__).parent))
from test_helpers_ag_ui import StreamingChatClientStub
async def test_agent_initialization_basic():
"""Test basic agent initialization without state schema."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
assert wrapper.name == "test_agent"
@@ -30,12 +39,13 @@ async def test_agent_initialization_with_state_schema():
"""Test agent initialization with state_schema."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
state_schema = {"document": {"type": "string"}}
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
state_schema: dict[str, dict[str, Any]] = {"document": {"type": "string"}}
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
assert wrapper.config.state_schema == state_schema
@@ -45,31 +55,56 @@ async def test_agent_initialization_with_predict_state_config():
"""Test agent initialization with predict_state_config."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
wrapper = AgentFrameworkAgent(agent=agent, predict_state_config=predict_config)
assert wrapper.config.predict_state_config == predict_config
async def test_agent_initialization_with_pydantic_state_schema():
"""Test agent initialization when state_schema is provided as Pydantic model/class."""
from agent_framework.ag_ui import AgentFrameworkAgent
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
class MyState(BaseModel):
document: str
tags: list[str] = []
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper_class_schema = AgentFrameworkAgent(agent=agent, state_schema=MyState)
wrapper_instance_schema = AgentFrameworkAgent(agent=agent, state_schema=MyState(document="hi"))
expected_properties = MyState.model_json_schema().get("properties", {})
assert wrapper_class_schema.config.state_schema == expected_properties
assert wrapper_instance_schema.config.state_schema == expected_properties
async def test_run_started_event_emission():
"""Test RunStartedEvent is emitted at start of run."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -83,11 +118,12 @@ async def test_predict_state_custom_event_emission():
"""Test PredictState CustomEvent is emitted when predict_state_config is present."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
predict_config = {
"document": {"tool": "write_doc", "tool_argument": "content"},
"summary": {"tool": "summarize", "tool_argument": "text"},
@@ -96,7 +132,7 @@ async def test_predict_state_custom_event_emission():
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -114,11 +150,12 @@ async def test_initial_state_snapshot_with_schema():
"""Test initial StateSnapshotEvent emission when state_schema present."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
state_schema = {"document": {"type": "string"}}
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
@@ -127,7 +164,7 @@ async def test_initial_state_snapshot_with_schema():
"state": {"document": "Initial content"},
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -143,17 +180,18 @@ async def test_state_initialization_object_type():
"""Test state initialization with object type in schema."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
state_schema = {"recipe": {"type": "object", "properties": {}}}
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
state_schema: dict[str, dict[str, Any]] = {"recipe": {"type": "object", "properties": {}}}
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -169,17 +207,18 @@ async def test_state_initialization_array_type():
"""Test state initialization with array type in schema."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
state_schema = {"steps": {"type": "array", "items": {}}}
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
state_schema: dict[str, dict[str, Any]] = {"steps": {"type": "array", "items": {}}}
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -195,16 +234,17 @@ async def test_run_finished_event_emission():
"""Test RunFinishedEvent is emitted at end of run."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -216,11 +256,12 @@ async def test_tool_result_confirm_changes_accepted():
"""Test confirm_changes tool result handling when accepted."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Document updated")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Document updated")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(
agent=agent,
state_schema={"document": {"type": "string"}},
@@ -228,8 +269,8 @@ async def test_tool_result_confirm_changes_accepted():
)
# Simulate tool result message with acceptance
tool_result = {"accepted": True, "steps": []}
input_data = {
tool_result: dict[str, Any] = {"accepted": True, "steps": []}
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool", # Tool result from UI
@@ -240,7 +281,7 @@ async def test_tool_result_confirm_changes_accepted():
"state": {"document": "Updated content"},
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -262,16 +303,17 @@ async def test_tool_result_confirm_changes_rejected():
"""Test confirm_changes tool result handling when rejected."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
# Simulate tool result message with rejection
tool_result = {"accepted": False, "steps": []}
input_data = {
tool_result: dict[str, Any] = {"accepted": False, "steps": []}
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool",
@@ -281,7 +323,7 @@ async def test_tool_result_confirm_changes_rejected():
],
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -295,22 +337,23 @@ async def test_tool_result_function_approval_accepted():
"""Test function approval tool result when steps are accepted."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
# Simulate tool result with multiple steps
tool_result = {
tool_result: dict[str, Any] = {
"accepted": True,
"steps": [
{"id": "step1", "description": "Send email", "status": "enabled"},
{"id": "step2", "description": "Create calendar event", "status": "enabled"},
],
}
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool",
@@ -320,7 +363,7 @@ async def test_tool_result_function_approval_accepted():
],
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -340,19 +383,20 @@ async def test_tool_result_function_approval_rejected():
"""Test function approval tool result when rejected."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
# Simulate tool result rejection with steps
tool_result = {
tool_result: dict[str, Any] = {
"accepted": False,
"steps": [{"id": "step1", "description": "Send email", "status": "disabled"}],
}
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool",
@@ -362,7 +406,7 @@ async def test_tool_result_function_approval_rejected():
],
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -376,17 +420,16 @@ async def test_thread_metadata_tracking():
"""Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id."""
from agent_framework.ag_ui import AgentFrameworkAgent
thread_metadata = {}
thread_metadata: dict[str, Any] = {}
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
# Capture thread metadata from kwargs
nonlocal thread_metadata
if "thread" in kwargs:
thread_metadata = kwargs["thread"].metadata
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
if chat_options.metadata:
thread_metadata.update(chat_options.metadata)
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {
@@ -395,28 +438,28 @@ async def test_thread_metadata_tracking():
"run_id": "test_run_456",
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# Check thread metadata was set
# Note: This test may need adjustment based on actual thread passing mechanism
assert thread_metadata.get("ag_ui_thread_id") == "test_thread_123"
assert thread_metadata.get("ag_ui_run_id") == "test_run_456"
async def test_state_context_injection():
"""Test that current state is injected into thread metadata."""
from agent_framework.ag_ui import AgentFrameworkAgent
thread_metadata = {}
thread_metadata: dict[str, Any] = {}
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
# Track if state context message was added
nonlocal thread_metadata
# In actual implementation, thread is passed and state is in metadata
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
if chat_options.metadata:
thread_metadata.update(chat_options.metadata)
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(
agent=agent,
state_schema={"document": {"type": "string"}},
@@ -427,27 +470,31 @@ async def test_state_context_injection():
"state": {"document": "Test content"},
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
# State should be injected - this is validated by agent execution flow
current_state = thread_metadata.get("current_state")
if isinstance(current_state, str):
current_state = json.loads(current_state)
assert current_state == {"document": "Test content"}
async def test_no_messages_provided():
"""Test handling when no messages are provided."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": []}
input_data: dict[str, Any] = {"messages": []}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -461,16 +508,17 @@ async def test_message_end_event_emission():
"""Test TextMessageEndEvent is emitted for assistant messages."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello world")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Hello world")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -488,19 +536,20 @@ async def test_error_handling_with_exception():
"""Test that exceptions during agent execution are re-raised."""
from agent_framework.ag_ui import AgentFrameworkAgent
class FailingChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
if False:
yield
raise RuntimeError("Simulated failure")
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
if False:
yield ChatResponseUpdate(contents=[])
raise RuntimeError("Simulated failure")
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=FailingChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}]}
with pytest.raises(RuntimeError, match="Simulated failure"):
async for event in wrapper.run_agent(input_data):
async for _ in wrapper.run_agent(input_data):
pass
@@ -508,18 +557,18 @@ async def test_json_decode_error_in_tool_result():
"""Test handling of orphaned tool result - should be sanitized out."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
# Should not be called since orphaned tool result is dropped
if False:
yield
raise AssertionError("ChatClient should not be called with orphaned tool result")
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
if False:
yield ChatResponseUpdate(contents=[])
raise AssertionError("ChatClient should not be called with orphaned tool result")
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
# Send invalid JSON as tool result without preceding tool call
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool",
@@ -529,7 +578,7 @@ async def test_json_decode_error_in_tool_result():
],
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -545,11 +594,12 @@ async def test_suppressed_summary_with_document_state():
"""Test suppressed summary uses document state for confirmation message."""
from agent_framework.ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Response")])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="Response")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(
agent=agent,
state_schema={"document": {"type": "string"}},
@@ -558,8 +608,8 @@ async def test_suppressed_summary_with_document_state():
)
# Simulate confirmation with document state
tool_result = {"accepted": True, "steps": []}
input_data = {
tool_result: dict[str, Any] = {"accepted": True, "steps": []}
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool",
@@ -570,7 +620,7 @@ async def test_suppressed_summary_with_document_state():
"state": {"document": "This is the beginning of a document. It contains important information."},
}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -2,6 +2,8 @@
"""Tests for backend tool rendering."""
from typing import cast
from ag_ui.core import (
TextMessageContentEvent,
TextMessageStartEvent,
@@ -119,6 +121,9 @@ async def test_multiple_tool_results():
assert isinstance(events[end_idx], ToolCallEndEvent)
assert isinstance(events[result_idx], ToolCallResultEvent)
assert events[end_idx].tool_call_id == f"tool-{i + 1}"
assert events[result_idx].tool_call_id == f"tool-{i + 1}"
assert f"Result {i + 1}" in events[result_idx].content
end_event = cast(ToolCallEndEvent, events[end_idx])
result_event = cast(ToolCallResultEvent, events[result_idx])
assert end_event.tool_call_id == f"tool-{i + 1}"
assert result_event.tool_call_id == f"tool-{i + 1}"
assert f"Result {i + 1}" in result_event.content
@@ -14,7 +14,7 @@ from agent_framework_ag_ui._confirmation_strategies import (
@pytest.fixture
def sample_steps():
def sample_steps() -> list[dict[str, str]]:
"""Sample steps for testing approval messages."""
return [
{"description": "Step 1: Do something", "status": "enabled"},
@@ -24,7 +24,7 @@ def sample_steps():
@pytest.fixture
def all_enabled_steps():
def all_enabled_steps() -> list[dict[str, str]]:
"""All steps enabled."""
return [
{"description": "Task A", "status": "enabled"},
@@ -34,7 +34,7 @@ def all_enabled_steps():
@pytest.fixture
def empty_steps():
def empty_steps() -> list[dict[str, str]]:
"""Empty steps list."""
return []
@@ -42,7 +42,7 @@ def empty_steps():
class TestDefaultConfirmationStrategy:
"""Tests for DefaultConfirmationStrategy."""
def test_on_approval_accepted_with_enabled_steps(self, sample_steps):
def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None:
strategy = DefaultConfirmationStrategy()
message = strategy.on_approval_accepted(sample_steps)
@@ -52,7 +52,7 @@ class TestDefaultConfirmationStrategy:
assert "Step 3" not in message # Disabled step shouldn't appear
assert "All steps completed successfully!" in message
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps):
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None:
strategy = DefaultConfirmationStrategy()
message = strategy.on_approval_accepted(all_enabled_steps)
@@ -61,28 +61,28 @@ class TestDefaultConfirmationStrategy:
assert "Task B" in message
assert "Task C" in message
def test_on_approval_accepted_with_empty_steps(self, empty_steps):
def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None:
strategy = DefaultConfirmationStrategy()
message = strategy.on_approval_accepted(empty_steps)
assert "Executing 0 approved steps" in message
assert "All steps completed successfully!" in message
def test_on_approval_rejected(self, sample_steps):
def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None:
strategy = DefaultConfirmationStrategy()
message = strategy.on_approval_rejected(sample_steps)
assert "No problem!" in message
assert "What would you like me to change" in message
def test_on_state_confirmed(self):
def test_on_state_confirmed(self) -> None:
strategy = DefaultConfirmationStrategy()
message = strategy.on_state_confirmed()
assert "Changes confirmed" in message
assert "successfully" in message
def test_on_state_rejected(self):
def test_on_state_rejected(self) -> None:
strategy = DefaultConfirmationStrategy()
message = strategy.on_state_rejected()
@@ -93,7 +93,7 @@ class TestDefaultConfirmationStrategy:
class TestTaskPlannerConfirmationStrategy:
"""Tests for TaskPlannerConfirmationStrategy."""
def test_on_approval_accepted_with_enabled_steps(self, sample_steps):
def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None:
strategy = TaskPlannerConfirmationStrategy()
message = strategy.on_approval_accepted(sample_steps)
@@ -103,7 +103,7 @@ class TestTaskPlannerConfirmationStrategy:
assert "Step 3" not in message
assert "All tasks completed successfully!" in message
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps):
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None:
strategy = TaskPlannerConfirmationStrategy()
message = strategy.on_approval_accepted(all_enabled_steps)
@@ -112,28 +112,28 @@ class TestTaskPlannerConfirmationStrategy:
assert "2. Task B" in message
assert "3. Task C" in message
def test_on_approval_accepted_with_empty_steps(self, empty_steps):
def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None:
strategy = TaskPlannerConfirmationStrategy()
message = strategy.on_approval_accepted(empty_steps)
assert "Executing your requested tasks" in message
assert "All tasks completed successfully!" in message
def test_on_approval_rejected(self, sample_steps):
def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None:
strategy = TaskPlannerConfirmationStrategy()
message = strategy.on_approval_rejected(sample_steps)
assert "No problem!" in message
assert "revise the plan" in message
def test_on_state_confirmed(self):
def test_on_state_confirmed(self) -> None:
strategy = TaskPlannerConfirmationStrategy()
message = strategy.on_state_confirmed()
assert "Tasks confirmed" in message
assert "ready to execute" in message
def test_on_state_rejected(self):
def test_on_state_rejected(self) -> None:
strategy = TaskPlannerConfirmationStrategy()
message = strategy.on_state_rejected()
@@ -144,7 +144,7 @@ class TestTaskPlannerConfirmationStrategy:
class TestRecipeConfirmationStrategy:
"""Tests for RecipeConfirmationStrategy."""
def test_on_approval_accepted_with_enabled_steps(self, sample_steps):
def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None:
strategy = RecipeConfirmationStrategy()
message = strategy.on_approval_accepted(sample_steps)
@@ -154,7 +154,7 @@ class TestRecipeConfirmationStrategy:
assert "Step 3" not in message
assert "Recipe updated successfully!" in message
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps):
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None:
strategy = RecipeConfirmationStrategy()
message = strategy.on_approval_accepted(all_enabled_steps)
@@ -163,28 +163,28 @@ class TestRecipeConfirmationStrategy:
assert "2. Task B" in message
assert "3. Task C" in message
def test_on_approval_accepted_with_empty_steps(self, empty_steps):
def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None:
strategy = RecipeConfirmationStrategy()
message = strategy.on_approval_accepted(empty_steps)
assert "Updating your recipe" in message
assert "Recipe updated successfully!" in message
def test_on_approval_rejected(self, sample_steps):
def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None:
strategy = RecipeConfirmationStrategy()
message = strategy.on_approval_rejected(sample_steps)
assert "No problem!" in message
assert "ingredients or steps" in message
def test_on_state_confirmed(self):
def test_on_state_confirmed(self) -> None:
strategy = RecipeConfirmationStrategy()
message = strategy.on_state_confirmed()
assert "Recipe changes applied" in message
assert "successfully" in message
def test_on_state_rejected(self):
def test_on_state_rejected(self) -> None:
strategy = RecipeConfirmationStrategy()
message = strategy.on_state_rejected()
@@ -195,7 +195,7 @@ class TestRecipeConfirmationStrategy:
class TestDocumentWriterConfirmationStrategy:
"""Tests for DocumentWriterConfirmationStrategy."""
def test_on_approval_accepted_with_enabled_steps(self, sample_steps):
def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None:
strategy = DocumentWriterConfirmationStrategy()
message = strategy.on_approval_accepted(sample_steps)
@@ -205,7 +205,7 @@ class TestDocumentWriterConfirmationStrategy:
assert "Step 3" not in message
assert "Document updated successfully!" in message
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps):
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None:
strategy = DocumentWriterConfirmationStrategy()
message = strategy.on_approval_accepted(all_enabled_steps)
@@ -214,27 +214,27 @@ class TestDocumentWriterConfirmationStrategy:
assert "2. Task B" in message
assert "3. Task C" in message
def test_on_approval_accepted_with_empty_steps(self, empty_steps):
def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None:
strategy = DocumentWriterConfirmationStrategy()
message = strategy.on_approval_accepted(empty_steps)
assert "Applying your edits" in message
assert "Document updated successfully!" in message
def test_on_approval_rejected(self, sample_steps):
def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None:
strategy = DocumentWriterConfirmationStrategy()
message = strategy.on_approval_rejected(sample_steps)
assert "No problem!" in message
assert "keep or modify" in message
def test_on_state_confirmed(self):
def test_on_state_confirmed(self) -> None:
strategy = DocumentWriterConfirmationStrategy()
message = strategy.on_state_confirmed()
assert "Document edits applied!" in message
def test_on_state_rejected(self):
def test_on_state_rejected(self) -> None:
strategy = DocumentWriterConfirmationStrategy()
message = strategy.on_state_rejected()
@@ -2,7 +2,7 @@
"""Tests for document writer predictive state flow with confirm_changes."""
from ag_ui.core import EventType
from ag_ui.core import EventType, StateDeltaEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent
from agent_framework import FunctionCallContent, FunctionResultContent, TextContent
from agent_framework._types import AgentRunResponseUpdate
@@ -35,16 +35,12 @@ async def test_streaming_document_with_state_deltas():
assert any(e.type == EventType.TOOL_CALL_ARGS for e in events1)
# Second chunk - incomplete JSON, should try partial extraction
tool_call_chunk2 = FunctionCallContent(
call_id="call_123",
name=None, # Name only in first chunk
arguments=" upon a time",
)
tool_call_chunk2 = FunctionCallContent(call_id="call_123", name="write_document_local", arguments=" upon a time")
update2 = AgentRunResponseUpdate(contents=[tool_call_chunk2])
events2 = await bridge.from_agent_run_update(update2)
# Should emit StateDeltaEvent with partial document
state_deltas = [e for e in events2 if e.type == EventType.STATE_DELTA]
state_deltas = [e for e in events2 if isinstance(e, StateDeltaEvent)]
assert len(state_deltas) >= 1
# Check JSON Patch format
@@ -62,7 +58,7 @@ async def test_confirm_changes_emission():
"document": {"tool": "write_document_local", "tool_argument": "document"},
}
current_state = {}
current_state: dict[str, str] = {}
bridge = AgentFrameworkEventBridge(
run_id="test_run",
@@ -90,15 +86,13 @@ async def test_confirm_changes_emission():
assert any(e.type == EventType.STATE_SNAPSHOT for e in events)
# Check for confirm_changes tool call
confirm_starts = [
e for e in events if e.type == EventType.TOOL_CALL_START and e.tool_call_name == "confirm_changes"
]
confirm_starts = [e for e in events if isinstance(e, ToolCallStartEvent) and e.tool_call_name == "confirm_changes"]
assert len(confirm_starts) == 1
confirm_args = [e for e in events if e.type == EventType.TOOL_CALL_ARGS and e.delta == "{}"]
confirm_args = [e for e in events if isinstance(e, ToolCallArgsEvent) and e.delta == "{}"]
assert len(confirm_args) >= 1
confirm_ends = [e for e in events if e.type == EventType.TOOL_CALL_END]
confirm_ends = [e for e in events if isinstance(e, ToolCallEndEvent)]
# At least 2: one for write_document_local, one for confirm_changes
assert len(confirm_ends) >= 2
@@ -141,7 +135,7 @@ async def test_no_confirm_for_non_predictive_tools():
"document": {"tool": "write_document_local", "tool_argument": "document"},
}
current_state = {}
current_state: dict[str, str] = {}
bridge = AgentFrameworkEventBridge(
run_id="test_run",
@@ -162,9 +156,7 @@ async def test_no_confirm_for_non_predictive_tools():
events = await bridge.from_agent_run_update(update)
# Should NOT have confirm_changes
confirm_starts = [
e for e in events if e.type == EventType.TOOL_CALL_START and e.tool_call_name == "confirm_changes"
]
confirm_starts = [e for e in events if isinstance(e, ToolCallStartEvent) and e.tool_call_name == "confirm_changes"]
assert len(confirm_starts) == 0
# Stop flag should NOT be set
@@ -193,14 +185,14 @@ async def test_state_delta_deduplication():
events1 = await bridge.from_agent_run_update(update1)
# Count state deltas
state_deltas_1 = [e for e in events1 if e.type == EventType.STATE_DELTA]
state_deltas_1 = [e for e in events1 if isinstance(e, StateDeltaEvent)]
assert len(state_deltas_1) >= 1
# Second tool call with SAME document (shouldn't emit new delta)
bridge.current_tool_call_name = "write_document_local"
tool_call2 = FunctionCallContent(
call_id="call_2",
name=None,
name="write_document_local",
arguments='{"document":"Same text"}', # Identical content
)
update2 = AgentRunResponseUpdate(contents=[tool_call2])
@@ -234,7 +226,7 @@ async def test_predict_state_config_multiple_fields():
events = await bridge.from_agent_run_update(update)
# Should emit StateDeltaEvent for both fields
state_deltas = [e for e in events if e.type == EventType.STATE_DELTA]
state_deltas = [e for e in events if isinstance(e, StateDeltaEvent)]
assert len(state_deltas) >= 2
# Check both fields are present
+49 -23
View File
@@ -3,7 +3,8 @@
"""Tests for FastAPI endpoint creation (_endpoint.py)."""
import json
from typing import Any
import sys
from pathlib import Path
from agent_framework import ChatAgent, TextContent
from agent_framework._types import ChatResponseUpdate
@@ -13,22 +14,20 @@ from fastapi.testclient import TestClient
from agent_framework_ag_ui._agent import AgentFrameworkAgent
from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint
sys.path.insert(0, str(Path(__file__).parent))
from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates
class MockChatClient:
"""Mock chat client for testing."""
def __init__(self, response_text: str = "Test response"):
self.response_text = response_text
async def get_streaming_response(self, messages: list[Any], chat_options: Any, **kwargs: Any):
"""Mock streaming response."""
yield ChatResponseUpdate(contents=[TextContent(text=self.response_text)])
def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub:
"""Create a typed chat client stub for endpoint tests."""
updates = [ChatResponseUpdate(contents=[TextContent(text=response_text)])]
return StreamingChatClientStub(stream_from_updates(updates))
async def test_add_endpoint_with_agent_protocol():
"""Test adding endpoint with raw AgentProtocol."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
add_agent_framework_fastapi_endpoint(app, agent, path="/test-agent")
@@ -42,7 +41,7 @@ async def test_add_endpoint_with_agent_protocol():
async def test_add_endpoint_with_wrapped_agent():
"""Test adding endpoint with pre-wrapped AgentFrameworkAgent."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
wrapped_agent = AgentFrameworkAgent(agent=agent, name="wrapped")
add_agent_framework_fastapi_endpoint(app, wrapped_agent, path="/wrapped-agent")
@@ -57,7 +56,7 @@ async def test_add_endpoint_with_wrapped_agent():
async def test_endpoint_with_state_schema():
"""Test endpoint with state_schema parameter."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
state_schema = {"document": {"type": "string"}}
add_agent_framework_fastapi_endpoint(app, agent, path="/stateful", state_schema=state_schema)
@@ -70,10 +69,37 @@ async def test_endpoint_with_state_schema():
assert response.status_code == 200
async def test_endpoint_with_default_state_seed():
"""Test endpoint seeds default state when client omits it."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
state_schema = {"proverbs": {"type": "array"}}
default_state = {"proverbs": ["Keep the original."]}
add_agent_framework_fastapi_endpoint(
app,
agent,
path="/default-state",
state_schema=state_schema,
default_state=default_state,
)
client = TestClient(app)
response = client.post("/default-state", json={"messages": [{"role": "user", "content": "Hello"}]})
assert response.status_code == 200
content = response.content.decode("utf-8")
lines = [line for line in content.split("\n") if line.startswith("data: ")]
snapshots = [json.loads(line[6:]) for line in lines if json.loads(line[6:]).get("type") == "STATE_SNAPSHOT"]
assert snapshots, "Expected a STATE_SNAPSHOT event"
assert snapshots[0]["snapshot"]["proverbs"] == default_state["proverbs"]
async def test_endpoint_with_predict_state_config():
"""Test endpoint with predict_state_config parameter."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
add_agent_framework_fastapi_endpoint(app, agent, path="/predictive", predict_state_config=predict_config)
@@ -87,7 +113,7 @@ async def test_endpoint_with_predict_state_config():
async def test_endpoint_request_logging():
"""Test that endpoint logs request details."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
add_agent_framework_fastapi_endpoint(app, agent, path="/logged")
@@ -107,7 +133,7 @@ async def test_endpoint_request_logging():
async def test_endpoint_event_streaming():
"""Test that endpoint streams events correctly."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient("Streamed response"))
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client("Streamed response"))
add_agent_framework_fastapi_endpoint(app, agent, path="/stream")
@@ -141,14 +167,14 @@ async def test_endpoint_event_streaming():
async def test_endpoint_error_handling():
"""Test endpoint error handling during request parsing."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
add_agent_framework_fastapi_endpoint(app, agent, path="/failing")
client = TestClient(app)
# Send invalid JSON to trigger parsing error before streaming
response = client.post("/failing", data="invalid json", headers={"content-type": "application/json"})
response = client.post("/failing", data=b"invalid json", headers={"content-type": "application/json"}) # type: ignore
# The exception handler catches it and returns JSON error
assert response.status_code == 200
@@ -160,8 +186,8 @@ async def test_endpoint_error_handling():
async def test_endpoint_multiple_paths():
"""Test adding multiple endpoints with different paths."""
app = FastAPI()
agent1 = ChatAgent(name="agent1", instructions="First agent", chat_client=MockChatClient("Response 1"))
agent2 = ChatAgent(name="agent2", instructions="Second agent", chat_client=MockChatClient("Response 2"))
agent1 = ChatAgent(name="agent1", instructions="First agent", chat_client=build_chat_client("Response 1"))
agent2 = ChatAgent(name="agent2", instructions="Second agent", chat_client=build_chat_client("Response 2"))
add_agent_framework_fastapi_endpoint(app, agent1, path="/agent1")
add_agent_framework_fastapi_endpoint(app, agent2, path="/agent2")
@@ -178,7 +204,7 @@ async def test_endpoint_multiple_paths():
async def test_endpoint_default_path():
"""Test endpoint with default path."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
add_agent_framework_fastapi_endpoint(app, agent)
@@ -191,7 +217,7 @@ async def test_endpoint_default_path():
async def test_endpoint_response_headers():
"""Test that endpoint sets correct response headers."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
add_agent_framework_fastapi_endpoint(app, agent, path="/headers")
@@ -207,7 +233,7 @@ async def test_endpoint_response_headers():
async def test_endpoint_empty_messages():
"""Test endpoint with empty messages list."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
add_agent_framework_fastapi_endpoint(app, agent, path="/empty")
@@ -220,7 +246,7 @@ async def test_endpoint_empty_messages():
async def test_endpoint_complex_input():
"""Test endpoint with complex input data."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
add_agent_framework_fastapi_endpoint(app, agent, path="/complex")
@@ -1,3 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for AG-UI event converter."""
from agent_framework import FinishReason, Role
@@ -0,0 +1,138 @@
# Copyright (c) Microsoft. All rights reserved.
"""Shared test stubs for AG-UI tests."""
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, MutableSequence
from types import SimpleNamespace
from typing import Any
from agent_framework import (
AgentProtocol,
AgentRunResponse,
AgentRunResponseUpdate,
AgentThread,
ChatMessage,
ChatOptions,
TextContent,
)
from agent_framework._clients import BaseChatClient
from agent_framework._types import ChatResponse, ChatResponseUpdate
from agent_framework_ag_ui._orchestrators import ExecutionContext
StreamFn = Callable[..., AsyncIterator[ChatResponseUpdate]]
ResponseFn = Callable[..., Awaitable[ChatResponse]]
class StreamingChatClientStub(BaseChatClient):
"""Typed streaming stub that satisfies ChatClientProtocol."""
def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None:
super().__init__()
self._stream_fn = stream_fn
self._response_fn = response_fn
async def _inner_get_streaming_response(
self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
async for update in self._stream_fn(messages, chat_options, **kwargs):
yield update
async def _inner_get_response(
self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> ChatResponse:
if self._response_fn is not None:
return await self._response_fn(messages, chat_options, **kwargs)
contents: list[Any] = []
async for update in self._stream_fn(messages, chat_options, **kwargs):
contents.extend(update.contents)
return ChatResponse(
messages=[ChatMessage(role="assistant", contents=contents)],
response_id="stub-response",
)
def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn:
"""Create a stream function that yields from a static list of updates."""
async def _stream(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
for update in updates:
yield update
return _stream
class StubAgent(AgentProtocol):
"""Minimal AgentProtocol stub for orchestrator tests."""
def __init__(
self,
updates: list[AgentRunResponseUpdate] | None = None,
*,
agent_id: str = "stub-agent",
agent_name: str | None = "stub-agent",
chat_options: Any | None = None,
chat_client: Any | None = None,
) -> None:
self._id = agent_id
self._name = agent_name
self._description = "stub agent"
self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")]
self.chat_options = chat_options or SimpleNamespace(tools=None, response_format=None)
self.chat_client = chat_client or SimpleNamespace(function_invocation_configuration=None)
self.messages_received: list[Any] = []
self.tools_received: list[Any] | None = None
@property
def id(self) -> str:
return self._id
@property
def name(self) -> str | None:
return self._name
@property
def display_name(self) -> str:
return self._name or self._id
@property
def description(self) -> str | None:
return self._description
async def run(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
) -> AgentRunResponse:
return AgentRunResponse(messages=[], response_id="stub-response")
def run_stream(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentRunResponseUpdate]:
async def _stream() -> AsyncIterator[AgentRunResponseUpdate]:
self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type]
self.tools_received = kwargs.get("tools")
for update in self.updates:
yield update
return _stream()
def get_new_thread(self, **kwargs: Any) -> AgentThread:
return AgentThread()
class TestExecutionContext(ExecutionContext):
"""ExecutionContext helper that allows setting messages for tests."""
def set_messages(self, messages: list[ChatMessage]) -> None:
self._messages = messages
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft. All rights reserved.
from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent
from agent_framework_ag_ui._orchestration._message_hygiene import (
deduplicate_messages,
sanitize_tool_history,
)
def test_sanitize_tool_history_injects_confirm_changes_result() -> None:
messages = [
ChatMessage(
role="assistant",
contents=[
FunctionCallContent(
name="confirm_changes",
call_id="call_confirm_123",
arguments='{"changes": "test"}',
)
],
),
ChatMessage(
role="user",
contents=[TextContent(text='{"accepted": true}')],
),
]
sanitized = sanitize_tool_history(messages)
tool_messages = [
msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool"
]
assert len(tool_messages) == 1
assert str(tool_messages[0].contents[0].call_id) == "call_confirm_123"
assert tool_messages[0].contents[0].result == "Confirmed"
def test_deduplicate_messages_prefers_non_empty_tool_results() -> None:
messages = [
ChatMessage(
role="tool",
contents=[FunctionResultContent(call_id="call1", result="")],
),
ChatMessage(
role="tool",
contents=[FunctionResultContent(call_id="call1", result="result data")],
),
]
deduped = deduplicate_messages(messages)
assert len(deduped) == 1
assert deduped[0].contents[0].result == "result data"
@@ -1,3 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for AG-UI orchestrators."""
from collections.abc import AsyncGenerator
@@ -34,6 +36,7 @@ class DummyAgent:
*,
thread: Any,
tools: list[Any] | None = None,
**kwargs: Any,
) -> AsyncGenerator[AgentRunResponseUpdate, None]:
self.seen_tools = tools
yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
@@ -2,7 +2,9 @@
"""Comprehensive tests for orchestrator coverage."""
import sys
from collections.abc import AsyncGenerator
from pathlib import Path
from types import SimpleNamespace
from typing import Any
@@ -15,11 +17,10 @@ from agent_framework import (
from pydantic import BaseModel
from agent_framework_ag_ui._agent import AgentConfig
from agent_framework_ag_ui._orchestrators import (
DefaultOrchestrator,
ExecutionContext,
HumanInTheLoopOrchestrator,
)
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, HumanInTheLoopOrchestrator
sys.path.insert(0, str(Path(__file__).parent))
from test_helpers_ag_ui import StubAgent, TestExecutionContext
@ai_function(approval_mode="always_require")
@@ -28,34 +29,14 @@ def approval_tool(param: str) -> str:
return f"executed: {param}"
class MockAgent:
"""Mock agent for testing."""
def __init__(self, updates: list[AgentRunResponseUpdate] | None = None) -> None:
self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")]
self.chat_options = SimpleNamespace(tools=[approval_tool], response_format=None)
self.chat_client = SimpleNamespace(function_invocation_configuration=None)
self.messages_received: list[Any] = []
self.tools_received: list[Any] | None = None
async def run_stream(
self,
messages: list[Any],
*,
thread: Any = None,
tools: list[Any] | None = None,
) -> AsyncGenerator[AgentRunResponseUpdate, None]:
self.messages_received = messages
self.tools_received = tools
for update in self.updates:
yield update
DEFAULT_CHAT_OPTIONS = SimpleNamespace(tools=[approval_tool], response_format=None)
async def test_human_in_the_loop_json_decode_error() -> None:
"""Test HumanInTheLoopOrchestrator handles invalid JSON in tool result."""
orchestrator = HumanInTheLoopOrchestrator()
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool",
@@ -72,21 +53,25 @@ async def test_human_in_the_loop_json_decode_error() -> None:
)
]
context = ExecutionContext(
agent = StubAgent(
chat_options=SimpleNamespace(tools=[approval_tool], response_format=None),
updates=[AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")],
)
context = TestExecutionContext(
input_data=input_data,
agent=MockAgent(),
agent=agent,
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
assert orchestrator.can_handle(context)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
# Should emit RunErrorEvent for invalid JSON
error_events = [e for e in events if e.type == "RUN_ERROR"]
error_events: list[Any] = [e for e in events if e.type == "RUN_ERROR"]
assert len(error_events) == 1
assert "Invalid tool result format" in error_events[0].message
@@ -118,18 +103,20 @@ async def test_sanitize_tool_history_confirm_changes() -> None:
orchestrator = DefaultOrchestrator()
# Use pre-constructed ChatMessage objects to bypass message adapter
input_data = {"messages": []}
input_data: dict[str, Any] = {"messages": []}
agent = MockAgent()
context = ExecutionContext(
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
# Override the messages property to use our pre-constructed messages
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -162,16 +149,18 @@ async def test_sanitize_tool_history_orphaned_tool_result() -> None:
]
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
agent = MockAgent()
context = ExecutionContext(
input_data: dict[str, Any] = {"messages": []}
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -188,7 +177,7 @@ async def test_orphaned_tool_result_sanitization() -> None:
"""Test that orphaned tool results are filtered out."""
orchestrator = DefaultOrchestrator()
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool",
@@ -201,14 +190,16 @@ async def test_orphaned_tool_result_sanitization() -> None:
],
}
agent = MockAgent()
context = ExecutionContext(
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -241,16 +232,18 @@ async def test_deduplicate_messages_empty_tool_results() -> None:
]
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
agent = MockAgent()
context = ExecutionContext(
input_data: dict[str, Any] = {"messages": []}
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -284,16 +277,18 @@ async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None:
]
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
agent = MockAgent()
context = ExecutionContext(
input_data: dict[str, Any] = {"messages": []}
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -326,16 +321,18 @@ async def test_deduplicate_messages_duplicate_system_messages() -> None:
]
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
agent = MockAgent()
context = ExecutionContext(
input_data: dict[str, Any] = {"messages": []}
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -354,7 +351,7 @@ async def test_state_context_injection() -> None:
"""Test state context message injection for first request."""
orchestrator = DefaultOrchestrator()
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "user",
@@ -364,14 +361,16 @@ async def test_state_context_injection() -> None:
"state": {"items": ["apple", "banana"]},
}
agent = MockAgent()
context = ExecutionContext(
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(state_schema={"items": {"type": "array"}}),
)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -406,16 +405,18 @@ async def test_no_state_context_injection_with_tool_calls() -> None:
]
orchestrator = DefaultOrchestrator()
input_data = {"messages": [], "state": {"weather": "sunny"}}
agent = MockAgent()
context = ExecutionContext(
input_data: dict[str, Any] = {"messages": [], "state": {"weather": "sunny"}}
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(state_schema={"weather": {"type": "string"}}),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -437,7 +438,7 @@ async def test_structured_output_processing() -> None:
orchestrator = DefaultOrchestrator()
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "user",
@@ -447,32 +448,33 @@ async def test_structured_output_processing() -> None:
}
# Agent with structured output
agent = MockAgent(
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
updates=[
AgentRunResponseUpdate(
contents=[TextContent(text='{"ingredients": ["tomato"], "message": "Added tomato"}')],
role="assistant",
)
]
],
)
agent.chat_options.response_format = RecipeState
context = ExecutionContext(
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(state_schema={"ingredients": {"type": "array"}}),
)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
# Should emit StateSnapshotEvent with ingredients
state_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
state_events: list[Any] = [e for e in events if e.type == "STATE_SNAPSHOT"]
assert len(state_events) >= 1
# Should emit TextMessage with message field
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
text_content_events: list[Any] = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
assert len(text_content_events) >= 1
assert any("Added tomato" in e.delta for e in text_content_events)
@@ -487,7 +489,7 @@ async def test_duplicate_client_tools_filtered() -> None:
orchestrator = DefaultOrchestrator()
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "user",
@@ -507,16 +509,18 @@ async def test_duplicate_client_tools_filtered() -> None:
],
}
agent = MockAgent()
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
agent.chat_options.tools = [get_weather]
context = ExecutionContext(
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -534,7 +538,7 @@ async def test_unique_client_tools_merged() -> None:
orchestrator = DefaultOrchestrator()
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "user",
@@ -554,16 +558,18 @@ async def test_unique_client_tools_merged() -> None:
],
}
agent = MockAgent()
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
agent.chat_options.tools = [server_tool]
context = ExecutionContext(
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -578,16 +584,18 @@ async def test_empty_messages_handling() -> None:
"""Test orchestrator handles empty message list gracefully."""
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
input_data: dict[str, Any] = {"messages": []}
agent = MockAgent()
context = ExecutionContext(
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -603,7 +611,7 @@ async def test_all_messages_filtered_handling() -> None:
"""Test orchestrator handles case where all messages are filtered out."""
orchestrator = DefaultOrchestrator()
input_data = {
input_data: dict[str, Any] = {
"messages": [
{
"role": "tool",
@@ -612,14 +620,16 @@ async def test_all_messages_filtered_handling() -> None:
]
}
agent = MockAgent()
context = ExecutionContext(
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -651,16 +661,18 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None:
]
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
agent = MockAgent()
context = ExecutionContext(
input_data: dict[str, Any] = {"messages": []}
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -689,16 +701,18 @@ async def test_tool_result_kept_when_call_id_matches() -> None:
]
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
agent = MockAgent()
context = ExecutionContext(
input_data: dict[str, Any] = {"messages": []}
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -729,6 +743,7 @@ async def test_agent_protocol_fallback_paths() -> None:
*,
thread: Any = None,
tools: list[Any] | None = None,
**kwargs: Any,
) -> AsyncGenerator[AgentRunResponseUpdate, None]:
self.messages_received = messages
yield AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")
@@ -738,16 +753,16 @@ async def test_agent_protocol_fallback_paths() -> None:
messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
input_data: dict[str, Any] = {"messages": []}
agent = CustomAgent()
context = ExecutionContext(
context = TestExecutionContext(
input_data=input_data,
agent=agent, # type: ignore
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -762,21 +777,23 @@ async def test_initial_state_snapshot_with_array_schema() -> None:
messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
orchestrator = DefaultOrchestrator()
input_data = {"messages": [], "state": {}}
agent = MockAgent()
context = ExecutionContext(
input_data: dict[str, Any] = {"messages": [], "state": {}}
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(state_schema={"items": {"type": "array"}}),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
# Should emit state snapshot with empty array for items
state_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
state_events: list[Any] = [e for e in events if e.type == "STATE_SNAPSHOT"]
assert len(state_events) >= 1
@@ -791,19 +808,21 @@ async def test_response_format_skip_text_content() -> None:
messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
orchestrator = DefaultOrchestrator()
input_data = {"messages": []}
input_data: dict[str, Any] = {"messages": []}
agent = MockAgent()
agent = StubAgent(
chat_options=DEFAULT_CHAT_OPTIONS,
)
agent.chat_options.response_format = OutputModel
context = ExecutionContext(
context = TestExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
context._messages = messages
context.set_messages(messages)
events = []
events: list[Any] = []
async for event in orchestrator.run(context):
events.append(event)
@@ -2,6 +2,10 @@
"""Tests for shared state management."""
import sys
from pathlib import Path
from typing import Any
import pytest
from ag_ui.core import StateSnapshotEvent
from agent_framework import ChatAgent, TextContent
@@ -10,20 +14,16 @@ from agent_framework._types import ChatResponseUpdate
from agent_framework_ag_ui._agent import AgentFrameworkAgent
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
sys.path.insert(0, str(Path(__file__).parent))
from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates
@pytest.fixture
def mock_agent():
def mock_agent() -> ChatAgent:
"""Create a mock agent for testing."""
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Hello!")])
return ChatAgent(
name="test_agent",
instructions="Test agent",
chat_client=MockChatClient(),
)
updates = [ChatResponseUpdate(contents=[TextContent(text="Hello!")])]
chat_client = StreamingChatClientStub(stream_from_updates(updates))
return ChatAgent(name="test_agent", instructions="Test agent", chat_client=chat_client)
def test_state_snapshot_event():
@@ -65,9 +65,9 @@ def test_state_delta_event():
assert event.delta[1]["op"] == "replace"
async def test_agent_with_initial_state(mock_agent):
async def test_agent_with_initial_state(mock_agent: ChatAgent) -> None:
"""Test agent emits state snapshot when initial state provided."""
state_schema = {"recipe": {"type": "object", "properties": {"name": {"type": "string"}}}}
state_schema: dict[str, Any] = {"recipe": {"type": "object", "properties": {"name": {"type": "string"}}}}
agent = AgentFrameworkAgent(
agent=mock_agent,
@@ -76,12 +76,12 @@ async def test_agent_with_initial_state(mock_agent):
initial_state = {"recipe": {"name": "Test Recipe"}}
input_data = {
input_data: dict[str, Any] = {
"messages": [{"role": "user", "content": "Hello"}],
"state": initial_state,
}
events = []
events: list[Any] = []
async for event in agent.run_agent(input_data):
events.append(event)
@@ -91,16 +91,16 @@ async def test_agent_with_initial_state(mock_agent):
assert snapshot_events[0].snapshot == initial_state
async def test_agent_without_state_schema(mock_agent):
async def test_agent_without_state_schema(mock_agent: ChatAgent) -> None:
"""Test agent doesn't emit state events without state schema."""
agent = AgentFrameworkAgent(agent=mock_agent)
input_data = {
input_data: dict[str, Any] = {
"messages": [{"role": "user", "content": "Hello"}],
"state": {"some": "state"},
}
events = []
events: list[Any] = []
async for event in agent.run_agent(input_data):
events.append(event)
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft. All rights reserved.
from ag_ui.core import CustomEvent, EventType
from agent_framework import ChatMessage, TextContent
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
from agent_framework_ag_ui._orchestration._state_manager import StateManager
def test_state_manager_initializes_defaults_and_snapshot() -> None:
state_manager = StateManager(
state_schema={"items": {"type": "array"}, "metadata": {"type": "object"}},
predict_state_config=None,
require_confirmation=True,
)
current_state = state_manager.initialize({"metadata": {"a": 1}})
bridge = AgentFrameworkEventBridge(run_id="run", thread_id="thread", current_state=current_state)
snapshot_event = state_manager.initial_snapshot_event(bridge)
assert snapshot_event is not None
assert snapshot_event.snapshot["items"] == []
assert snapshot_event.snapshot["metadata"] == {"a": 1}
def test_state_manager_predict_state_event_shape() -> None:
state_manager = StateManager(
state_schema=None,
predict_state_config={"doc": {"tool": "write_document_local", "tool_argument": "document"}},
require_confirmation=True,
)
predict_event = state_manager.predict_state_event()
assert isinstance(predict_event, CustomEvent)
assert predict_event.type == EventType.CUSTOM
assert predict_event.name == "PredictState"
assert predict_event.value[0]["state_key"] == "doc"
def test_state_context_only_when_new_user_turn() -> None:
state_manager = StateManager(
state_schema={"items": {"type": "array"}},
predict_state_config=None,
require_confirmation=True,
)
state_manager.initialize({"items": [1]})
assert state_manager.state_context_message(is_new_user_turn=False, conversation_has_tool_calls=False) is None
message = state_manager.state_context_message(is_new_user_turn=True, conversation_has_tool_calls=False)
assert isinstance(message, ChatMessage)
assert isinstance(message.contents[0], TextContent)
assert "Current state of the application" in message.contents[0].text
@@ -3,12 +3,18 @@
"""Tests for structured output handling in _agent.py."""
import json
import sys
from collections.abc import AsyncIterator, MutableSequence
from pathlib import Path
from typing import Any
from agent_framework import ChatAgent, ChatOptions, TextContent
from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent
from agent_framework._types import ChatResponseUpdate
from pydantic import BaseModel
sys.path.insert(0, str(Path(__file__).parent))
from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates
class RecipeOutput(BaseModel):
"""Test Pydantic model for recipe output."""
@@ -34,14 +40,14 @@ async def test_structured_output_with_recipe():
"""Test structured output processing with recipe state."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
# Simulate structured output
yield ChatResponseUpdate(
contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')]
)
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(
contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')]
)
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
agent.chat_options = ChatOptions(response_format=RecipeOutput)
wrapper = AgentFrameworkAgent(
@@ -51,7 +57,7 @@ async def test_structured_output_with_recipe():
input_data = {"messages": [{"role": "user", "content": "Make pasta"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -72,17 +78,18 @@ async def test_structured_output_with_steps():
"""Test structured output processing with steps state."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
steps_data = {
"steps": [
{"id": "1", "description": "Step 1", "status": "pending"},
{"id": "2", "description": "Step 2", "status": "pending"},
]
}
yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
steps_data = {
"steps": [
{"id": "1", "description": "Step 1", "status": "pending"},
{"id": "2", "description": "Step 2", "status": "pending"},
]
}
yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))])
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
agent.chat_options = ChatOptions(response_format=StepsOutput)
wrapper = AgentFrameworkAgent(
@@ -92,7 +99,7 @@ async def test_structured_output_with_steps():
input_data = {"messages": [{"role": "user", "content": "Do steps"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -111,12 +118,13 @@ async def test_structured_output_with_no_schema_match():
"""Test structured output when response fields don't match state_schema keys."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
# Response has "data" field but schema expects "result" field
yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')])
updates = [
ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')]),
]
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(
name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_from_updates(updates))
)
agent.chat_options = ChatOptions(response_format=GenericOutput)
wrapper = AgentFrameworkAgent(
@@ -126,7 +134,7 @@ async def test_structured_output_with_no_schema_match():
input_data = {"messages": [{"role": "user", "content": "Generate data"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -146,11 +154,12 @@ async def test_structured_output_without_schema():
data: dict[str, Any]
info: str
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')])
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
agent.chat_options = ChatOptions(response_format=DataOutput)
wrapper = AgentFrameworkAgent(
@@ -160,7 +169,7 @@ async def test_structured_output_without_schema():
input_data = {"messages": [{"role": "user", "content": "Generate data"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -177,18 +186,20 @@ async def test_no_structured_output_when_no_response_format():
"""Test that structured output path is skipped when no response_format."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
yield ChatResponseUpdate(contents=[TextContent(text="Regular text")])
updates = [ChatResponseUpdate(contents=[TextContent(text="Regular text")])]
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(
name="test",
instructions="Test",
chat_client=StreamingChatClientStub(stream_from_updates(updates)),
)
# No response_format set
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -202,12 +213,13 @@ async def test_structured_output_with_message_field():
"""Test structured output that includes a message field."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"}
yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))])
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"}
yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))])
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
agent.chat_options = ChatOptions(response_format=RecipeOutput)
wrapper = AgentFrameworkAgent(
@@ -217,7 +229,7 @@ async def test_structured_output_with_message_field():
input_data = {"messages": [{"role": "user", "content": "Make salad"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -236,20 +248,20 @@ async def test_empty_updates_no_structured_processing():
"""Test that empty updates don't trigger structured output processing."""
from agent_framework.ag_ui import AgentFrameworkAgent
class MockChatClient:
async def get_streaming_response(self, messages, chat_options, **kwargs):
# Return nothing
if False:
yield
async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
if False:
yield ChatResponseUpdate(contents=[])
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
agent.chat_options = ChatOptions(response_format=RecipeOutput)
wrapper = AgentFrameworkAgent(agent=agent)
input_data = {"messages": [{"role": "user", "content": "Test"}]}
events = []
events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
@@ -0,0 +1,36 @@
# Copyright (c) Microsoft. All rights reserved.
from types import SimpleNamespace
from agent_framework_ag_ui._orchestration._tooling import merge_tools, register_additional_client_tools
class DummyTool:
def __init__(self, name: str) -> None:
self.name = name
self.declaration_only = True
def test_merge_tools_filters_duplicates() -> None:
server = [DummyTool("a"), DummyTool("b")]
client = [DummyTool("b"), DummyTool("c")]
merged = merge_tools(server, client)
assert merged is not None
names = [getattr(t, "name", None) for t in merged]
assert names == ["a", "b", "c"]
def test_register_additional_client_tools_assigns_when_configured() -> None:
class Fic:
def __init__(self) -> None:
self.additional_tools = None
holder = SimpleNamespace(function_invocation_configuration=Fic())
agent = SimpleNamespace(chat_client=holder)
tools = [DummyTool("x")]
register_additional_client_tools(agent, tools)
assert holder.function_invocation_configuration.additional_tools == tools
+8 -8
View File
@@ -20,8 +20,8 @@ def test_generate_event_id():
def test_merge_state():
"""Test state merging."""
current = {"a": 1, "b": 2}
update = {"b": 3, "c": 4}
current: dict[str, int] = {"a": 1, "b": 2}
update: dict[str, int] = {"b": 3, "c": 4}
result = merge_state(current, update)
@@ -32,8 +32,8 @@ def test_merge_state():
def test_merge_state_empty_update():
"""Test merging with empty update."""
current = {"x": 10, "y": 20}
update = {}
current: dict[str, int] = {"x": 10, "y": 20}
update: dict[str, int] = {}
result = merge_state(current, update)
@@ -43,8 +43,8 @@ def test_merge_state_empty_update():
def test_merge_state_empty_current():
"""Test merging with empty current state."""
current = {}
update = {"a": 1, "b": 2}
current: dict[str, int] = {}
update: dict[str, int] = {"a": 1, "b": 2}
result = merge_state(current, update)
@@ -53,8 +53,8 @@ def test_merge_state_empty_current():
def test_merge_state_deep_copy():
"""Test that merge_state creates a deep copy preventing mutation of original."""
current = {"recipe": {"name": "Cake", "ingredients": ["flour", "sugar"]}}
update = {"other": "value"}
current: dict[str, dict[str, object]] = {"recipe": {"name": "Cake", "ingredients": ["flour", "sugar"]}}
update: dict[str, str] = {"other": "value"}
result = merge_state(current, update)
@@ -876,7 +876,11 @@ class ChatAgent(BaseAgent):
)
# Filter chat_options from kwargs to prevent duplicate keyword argument
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **filtered_kwargs)
response = await self.chat_client.get_response(
messages=thread_messages,
chat_options=co,
**filtered_kwargs,
)
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
@@ -1013,7 +1017,9 @@ class ChatAgent(BaseAgent):
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
response_updates: list[ChatResponseUpdate] = []
async for update in self.chat_client.get_streaming_response(
messages=thread_messages, chat_options=co, **filtered_kwargs
messages=thread_messages,
chat_options=co,
**filtered_kwargs,
):
response_updates.append(update)