mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Refactor ag-ui to clean up some patterns (#2363)
* Refactor ag-ui to clean up some patterns * Mypy fixes * Fix imports, typing, tests, logging. * Fix test import error * Fix imports again * Fix thread handling
This commit is contained in:
committed by
GitHub
Unverified
parent
6c624319db
commit
8cf8b0f995
@@ -3,7 +3,7 @@
|
||||
"""AgentFrameworkAgent wrapper for AG-UI protocol - Clean Architecture."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from ag_ui.core import BaseEvent
|
||||
from agent_framework import AgentProtocol
|
||||
@@ -22,21 +22,48 @@ class AgentConfig:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_schema: dict[str, Any] | None = None,
|
||||
state_schema: Any | None = None,
|
||||
predict_state_config: dict[str, dict[str, str]] | None = None,
|
||||
require_confirmation: bool = True,
|
||||
):
|
||||
"""Initialize agent configuration.
|
||||
|
||||
Args:
|
||||
state_schema: Optional state schema for state management
|
||||
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
|
||||
predict_state_config: Configuration for predictive state updates
|
||||
require_confirmation: Whether predictive updates require confirmation
|
||||
"""
|
||||
self.state_schema = state_schema or {}
|
||||
self.state_schema = self._normalize_state_schema(state_schema)
|
||||
self.predict_state_config = predict_state_config or {}
|
||||
self.require_confirmation = require_confirmation
|
||||
|
||||
@staticmethod
|
||||
def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]:
|
||||
"""Accept dict or Pydantic model/class and return a properties dict."""
|
||||
if state_schema is None:
|
||||
return {}
|
||||
|
||||
if isinstance(state_schema, dict):
|
||||
return cast(dict[str, Any], state_schema)
|
||||
|
||||
base_model_type: type[Any] | None
|
||||
try:
|
||||
from pydantic import BaseModel as ImportedBaseModel
|
||||
|
||||
base_model_type = ImportedBaseModel
|
||||
except Exception: # pragma: no cover
|
||||
base_model_type = None
|
||||
|
||||
if base_model_type is not None and isinstance(state_schema, base_model_type):
|
||||
schema_dict = state_schema.__class__.model_json_schema()
|
||||
return schema_dict.get("properties", {}) or {}
|
||||
|
||||
if base_model_type is not None and isinstance(state_schema, type) and issubclass(state_schema, base_model_type):
|
||||
schema_dict = state_schema.model_json_schema()
|
||||
return schema_dict.get("properties", {}) or {}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
class AgentFrameworkAgent:
|
||||
"""Wraps Agent Framework agents for AG-UI protocol compatibility.
|
||||
@@ -55,7 +82,7 @@ class AgentFrameworkAgent:
|
||||
agent: AgentProtocol,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
state_schema: dict[str, Any] | None = None,
|
||||
state_schema: Any | None = None,
|
||||
predict_state_config: dict[str, dict[str, str]] | None = None,
|
||||
require_confirmation: bool = True,
|
||||
orchestrators: list[Orchestrator] | None = None,
|
||||
@@ -67,7 +94,7 @@ class AgentFrameworkAgent:
|
||||
agent: The Agent Framework agent to wrap
|
||||
name: Optional name for the agent
|
||||
description: Optional description
|
||||
state_schema: Optional state schema for state management
|
||||
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
|
||||
predict_state_config: Configuration for predictive state updates.
|
||||
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
|
||||
require_confirmation: Whether predictive updates require confirmation.
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
"""FastAPI endpoint creation for AG-UI agents."""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -19,9 +20,10 @@ def add_agent_framework_fastapi_endpoint(
|
||||
app: FastAPI,
|
||||
agent: AgentProtocol | AgentFrameworkAgent,
|
||||
path: str = "/",
|
||||
state_schema: dict[str, Any] | None = None,
|
||||
state_schema: Any | None = None,
|
||||
predict_state_config: dict[str, dict[str, str]] | None = None,
|
||||
allow_origins: list[str] | None = None,
|
||||
default_state: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Add an AG-UI endpoint to a FastAPI app.
|
||||
|
||||
@@ -29,10 +31,11 @@ def add_agent_framework_fastapi_endpoint(
|
||||
app: The FastAPI application
|
||||
agent: The agent to expose (can be raw AgentProtocol or wrapped)
|
||||
path: The endpoint path
|
||||
state_schema: Optional state schema for shared state management
|
||||
state_schema: Optional state schema for shared state management; accepts dict or Pydantic model/class
|
||||
predict_state_config: Optional predictive state update configuration.
|
||||
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
|
||||
allow_origins: CORS origins (not yet implemented)
|
||||
default_state: Optional initial state to seed when the client does not provide state keys
|
||||
"""
|
||||
if isinstance(agent, AgentProtocol):
|
||||
wrapped_agent = AgentFrameworkAgent(
|
||||
@@ -52,6 +55,11 @@ def add_agent_framework_fastapi_endpoint(
|
||||
"""
|
||||
try:
|
||||
input_data = await request.json()
|
||||
if default_state:
|
||||
state = input_data.setdefault("state", {})
|
||||
for key, value in default_state.items():
|
||||
if key not in state:
|
||||
state[key] = copy.deepcopy(value)
|
||||
logger.debug(
|
||||
f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, "
|
||||
f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, "
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Message hygiene utilities for orchestrators."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
"""Normalize tool ordering and inject synthetic results for AG-UI edge cases."""
|
||||
sanitized: list[ChatMessage] = []
|
||||
pending_tool_call_ids: set[str] | None = None
|
||||
pending_confirm_changes_id: str | None = None
|
||||
|
||||
for msg in messages:
|
||||
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
|
||||
if role_value == "assistant":
|
||||
tool_ids = {
|
||||
str(content.call_id)
|
||||
for content in msg.contents or []
|
||||
if isinstance(content, FunctionCallContent) and content.call_id
|
||||
}
|
||||
confirm_changes_call = None
|
||||
for content in msg.contents or []:
|
||||
if isinstance(content, FunctionCallContent) and content.name == "confirm_changes":
|
||||
confirm_changes_call = content
|
||||
break
|
||||
|
||||
sanitized.append(msg)
|
||||
pending_tool_call_ids = tool_ids if tool_ids else None
|
||||
pending_confirm_changes_id = (
|
||||
str(confirm_changes_call.call_id) if confirm_changes_call and confirm_changes_call.call_id else None
|
||||
)
|
||||
continue
|
||||
|
||||
if role_value == "user":
|
||||
if pending_confirm_changes_id:
|
||||
user_text = ""
|
||||
for content in msg.contents or []:
|
||||
if isinstance(content, TextContent):
|
||||
user_text = content.text
|
||||
break
|
||||
|
||||
try:
|
||||
parsed = json.loads(user_text)
|
||||
if "accepted" in parsed:
|
||||
logger.info(
|
||||
f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}"
|
||||
)
|
||||
synthetic_result = ChatMessage(
|
||||
role="tool",
|
||||
contents=[
|
||||
FunctionResultContent(
|
||||
call_id=pending_confirm_changes_id,
|
||||
result="Confirmed" if parsed.get("accepted") else "Rejected",
|
||||
)
|
||||
],
|
||||
)
|
||||
sanitized.append(synthetic_result)
|
||||
if pending_tool_call_ids:
|
||||
pending_tool_call_ids.discard(pending_confirm_changes_id)
|
||||
pending_confirm_changes_id = None
|
||||
continue
|
||||
except (json.JSONDecodeError, KeyError) as exc:
|
||||
logger.debug("Could not parse user message as confirm_changes response: %s", type(exc).__name__)
|
||||
|
||||
if pending_tool_call_ids:
|
||||
logger.info(
|
||||
f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - injecting synthetic results"
|
||||
)
|
||||
for pending_call_id in pending_tool_call_ids:
|
||||
logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}")
|
||||
synthetic_result = ChatMessage(
|
||||
role="tool",
|
||||
contents=[
|
||||
FunctionResultContent(
|
||||
call_id=pending_call_id,
|
||||
result="Tool execution skipped - user provided follow-up message",
|
||||
)
|
||||
],
|
||||
)
|
||||
sanitized.append(synthetic_result)
|
||||
pending_tool_call_ids = None
|
||||
pending_confirm_changes_id = None
|
||||
|
||||
sanitized.append(msg)
|
||||
pending_confirm_changes_id = None
|
||||
continue
|
||||
|
||||
if role_value == "tool":
|
||||
if not pending_tool_call_ids:
|
||||
continue
|
||||
keep = False
|
||||
for content in msg.contents or []:
|
||||
if isinstance(content, FunctionResultContent):
|
||||
call_id = str(content.call_id)
|
||||
if call_id in pending_tool_call_ids:
|
||||
keep = True
|
||||
if call_id == pending_confirm_changes_id:
|
||||
pending_confirm_changes_id = None
|
||||
break
|
||||
if keep:
|
||||
sanitized.append(msg)
|
||||
continue
|
||||
|
||||
sanitized.append(msg)
|
||||
pending_tool_call_ids = None
|
||||
pending_confirm_changes_id = None
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
"""Remove duplicate messages while preserving order."""
|
||||
seen_keys: dict[Any, int] = {}
|
||||
unique_messages: list[ChatMessage] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
|
||||
if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent):
|
||||
call_id = str(msg.contents[0].call_id)
|
||||
key: Any = (role_value, call_id)
|
||||
|
||||
if key in seen_keys:
|
||||
existing_idx = seen_keys[key]
|
||||
existing_msg = unique_messages[existing_idx]
|
||||
|
||||
existing_result = None
|
||||
if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent):
|
||||
existing_result = existing_msg.contents[0].result
|
||||
new_result = msg.contents[0].result
|
||||
|
||||
if (not existing_result or existing_result == "") and new_result:
|
||||
logger.info(f"Replacing empty tool result at index {existing_idx} with data from index {idx}")
|
||||
unique_messages[existing_idx] = msg
|
||||
else:
|
||||
logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}")
|
||||
continue
|
||||
|
||||
seen_keys[key] = len(unique_messages)
|
||||
unique_messages.append(msg)
|
||||
|
||||
elif (
|
||||
role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents)
|
||||
):
|
||||
tool_call_ids = tuple(
|
||||
sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id)
|
||||
)
|
||||
key = (role_value, tool_call_ids)
|
||||
|
||||
if key in seen_keys:
|
||||
logger.info(f"Skipping duplicate assistant tool call at index {idx}")
|
||||
continue
|
||||
|
||||
seen_keys[key] = len(unique_messages)
|
||||
unique_messages.append(msg)
|
||||
|
||||
else:
|
||||
content_str = str([str(c) for c in msg.contents]) if msg.contents else ""
|
||||
key = (role_value, hash(content_str))
|
||||
|
||||
if key in seen_keys:
|
||||
logger.info(f"Skipping duplicate message at index {idx}: role={role_value}")
|
||||
continue
|
||||
|
||||
seen_keys[key] = len(unique_messages)
|
||||
unique_messages.append(msg)
|
||||
|
||||
return unique_messages
|
||||
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""State orchestration utilities."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from ag_ui.core import CustomEvent, EventType
|
||||
from agent_framework import ChatMessage, TextContent
|
||||
|
||||
|
||||
class StateManager:
|
||||
"""Coordinates state defaults, snapshots, and structured updates."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_schema: dict[str, Any] | None,
|
||||
predict_state_config: dict[str, dict[str, str]] | None,
|
||||
require_confirmation: bool,
|
||||
) -> None:
|
||||
self.state_schema = state_schema or {}
|
||||
self.predict_state_config = predict_state_config or {}
|
||||
self.require_confirmation = require_confirmation
|
||||
self.current_state: dict[str, Any] = {}
|
||||
|
||||
def initialize(self, initial_state: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Initialize state with schema defaults."""
|
||||
self.current_state = (initial_state or {}).copy()
|
||||
self._apply_schema_defaults()
|
||||
return self.current_state
|
||||
|
||||
def predict_state_event(self) -> CustomEvent | None:
|
||||
"""Create predict-state custom event when configured."""
|
||||
if not self.predict_state_config:
|
||||
return None
|
||||
|
||||
predict_state_value = [
|
||||
{
|
||||
"state_key": state_key,
|
||||
"tool": config["tool"],
|
||||
"tool_argument": config["tool_argument"],
|
||||
}
|
||||
for state_key, config in self.predict_state_config.items()
|
||||
]
|
||||
|
||||
return CustomEvent(
|
||||
type=EventType.CUSTOM,
|
||||
name="PredictState",
|
||||
value=predict_state_value,
|
||||
)
|
||||
|
||||
def initial_snapshot_event(self, event_bridge: Any) -> Any:
|
||||
"""Emit initial snapshot when schema and state present."""
|
||||
if not self.state_schema:
|
||||
return None
|
||||
self._apply_schema_defaults()
|
||||
return event_bridge.create_state_snapshot_event(self.current_state)
|
||||
|
||||
def state_context_message(self, is_new_user_turn: bool, conversation_has_tool_calls: bool) -> ChatMessage | None:
|
||||
"""Inject state context only when starting a new user turn."""
|
||||
if not self.current_state or not self.state_schema:
|
||||
return None
|
||||
if not is_new_user_turn or conversation_has_tool_calls:
|
||||
return None
|
||||
|
||||
state_json = json.dumps(self.current_state, indent=2)
|
||||
return ChatMessage(
|
||||
role="system",
|
||||
contents=[
|
||||
TextContent(
|
||||
text=(
|
||||
"Current state of the application:\n"
|
||||
f"{state_json}\n\n"
|
||||
"When modifying state, you MUST include ALL existing data plus your changes.\n"
|
||||
"For example, if adding one new item to a list, include ALL existing items PLUS the one new item.\n"
|
||||
"Never replace existing data - always preserve and append or merge."
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def extract_state_updates(self, response_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract state updates from structured response payloads."""
|
||||
if self.state_schema:
|
||||
return {key: response_dict[key] for key in self.state_schema.keys() if key in response_dict}
|
||||
return {k: v for k, v in response_dict.items() if k != "message"}
|
||||
|
||||
def apply_state_updates(self, updates: dict[str, Any]) -> None:
|
||||
"""Merge state updates into current state."""
|
||||
if not updates:
|
||||
return
|
||||
self.current_state.update(updates)
|
||||
|
||||
def _apply_schema_defaults(self) -> None:
|
||||
"""Fill missing state fields based on schema hints."""
|
||||
for key, schema in self.state_schema.items():
|
||||
if key in self.current_state:
|
||||
continue
|
||||
if isinstance(schema, dict) and schema.get("type") == "array": # type: ignore
|
||||
self.current_state[key] = []
|
||||
else:
|
||||
self.current_state[key] = {}
|
||||
@@ -0,0 +1,80 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tool handling helpers."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import BaseChatClient, ChatAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def collect_server_tools(agent: Any) -> list[Any]:
|
||||
"""Collect server tools from ChatAgent or duck-typed agent."""
|
||||
if isinstance(agent, ChatAgent):
|
||||
tools_from_agent = agent.chat_options.tools
|
||||
server_tools = list(tools_from_agent) if tools_from_agent else []
|
||||
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
|
||||
for tool in server_tools:
|
||||
tool_name = getattr(tool, "name", "unknown")
|
||||
approval_mode = getattr(tool, "approval_mode", None)
|
||||
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
|
||||
return server_tools
|
||||
|
||||
try:
|
||||
chat_options_attr = getattr(agent, "chat_options", None)
|
||||
if chat_options_attr is not None:
|
||||
return getattr(chat_options_attr, "tools", None) or []
|
||||
except AttributeError:
|
||||
return []
|
||||
return []
|
||||
|
||||
|
||||
def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) -> None:
|
||||
"""Register client tools as additional declaration-only tools to avoid server execution."""
|
||||
if not client_tools:
|
||||
return
|
||||
|
||||
if isinstance(agent, ChatAgent):
|
||||
chat_client = agent.chat_client
|
||||
if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None:
|
||||
chat_client.function_invocation_configuration.additional_tools = client_tools
|
||||
logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)")
|
||||
return
|
||||
|
||||
try:
|
||||
chat_client_attr = getattr(agent, "chat_client", None)
|
||||
if chat_client_attr is not None:
|
||||
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
|
||||
if fic is not None:
|
||||
fic.additional_tools = client_tools # type: ignore[attr-defined]
|
||||
logger.debug(
|
||||
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
|
||||
)
|
||||
except AttributeError:
|
||||
return
|
||||
|
||||
|
||||
def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None:
|
||||
"""Combine server and client tools without overriding server metadata."""
|
||||
if not client_tools:
|
||||
logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)")
|
||||
return None
|
||||
|
||||
server_tool_names = {getattr(tool, "name", None) for tool in server_tools}
|
||||
unique_client_tools = [tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names]
|
||||
|
||||
if not unique_client_tools:
|
||||
logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter")
|
||||
return None
|
||||
|
||||
combined_tools: list[Any] = []
|
||||
if server_tools:
|
||||
combined_tools.extend(server_tools)
|
||||
combined_tools.extend(unique_client_tools)
|
||||
logger.info(
|
||||
f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools "
|
||||
f"({len(server_tools)} server + {len(unique_client_tools)} unique client)"
|
||||
)
|
||||
return combined_tools
|
||||
@@ -21,7 +21,6 @@ from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentThread,
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
TextContent,
|
||||
@@ -271,144 +270,29 @@ class DefaultOrchestrator(Orchestrator):
|
||||
AG-UI events
|
||||
"""
|
||||
from ._events import AgentFrameworkEventBridge
|
||||
from ._message_adapters import agui_messages_to_snapshot_format
|
||||
from ._orchestration._message_hygiene import deduplicate_messages, sanitize_tool_history
|
||||
from ._orchestration._state_manager import StateManager
|
||||
from ._orchestration._tooling import (
|
||||
collect_server_tools,
|
||||
merge_tools,
|
||||
register_additional_client_tools,
|
||||
)
|
||||
|
||||
logger.info(f"Starting default agent run for thread_id={context.thread_id}, run_id={context.run_id}")
|
||||
|
||||
# Initialize state tracking
|
||||
initial_state = context.input_data.get("state", {})
|
||||
current_state: dict[str, Any] = initial_state.copy() if initial_state else {}
|
||||
|
||||
# Check if agent uses structured outputs (response_format)
|
||||
# Use isinstance to narrow type for proper attribute access
|
||||
response_format = None
|
||||
if isinstance(context.agent, ChatAgent):
|
||||
response_format = context.agent.chat_options.response_format
|
||||
skip_text_content = response_format is not None
|
||||
|
||||
# Sanitizer: ensure tool results only follow assistant tool calls
|
||||
# Also inject synthetic tool results for confirm_changes
|
||||
def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
sanitized: list[ChatMessage] = []
|
||||
pending_tool_call_ids: set[str] | None = None
|
||||
pending_confirm_changes_id: str | None = None
|
||||
state_manager = StateManager(
|
||||
state_schema=context.config.state_schema,
|
||||
predict_state_config=context.config.predict_state_config,
|
||||
require_confirmation=context.config.require_confirmation,
|
||||
)
|
||||
current_state = state_manager.initialize(context.input_data.get("state", {}))
|
||||
|
||||
for msg in messages:
|
||||
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
|
||||
if role_value == "assistant":
|
||||
tool_ids = {
|
||||
str(content.call_id)
|
||||
for content in msg.contents or []
|
||||
if isinstance(content, FunctionCallContent) and content.call_id
|
||||
}
|
||||
# Check for confirm_changes tool call
|
||||
confirm_changes_call = None
|
||||
for content in msg.contents or []:
|
||||
if isinstance(content, FunctionCallContent) and content.name == "confirm_changes":
|
||||
confirm_changes_call = content
|
||||
break
|
||||
|
||||
sanitized.append(msg)
|
||||
pending_tool_call_ids = tool_ids if tool_ids else None
|
||||
pending_confirm_changes_id = (
|
||||
str(confirm_changes_call.call_id)
|
||||
if confirm_changes_call and confirm_changes_call.call_id
|
||||
else None
|
||||
)
|
||||
continue
|
||||
|
||||
if role_value == "user":
|
||||
# Check if this user message is a confirm_changes response (JSON with "accepted" field)
|
||||
# This must be checked BEFORE injecting synthetic results for pending tool calls
|
||||
if pending_confirm_changes_id:
|
||||
user_text = ""
|
||||
for content in msg.contents or []:
|
||||
if isinstance(content, TextContent):
|
||||
user_text = content.text
|
||||
break
|
||||
|
||||
try:
|
||||
parsed = json.loads(user_text)
|
||||
if "accepted" in parsed:
|
||||
# This is a confirm_changes response - inject synthetic tool result
|
||||
logger.info(
|
||||
f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}"
|
||||
)
|
||||
synthetic_result = ChatMessage(
|
||||
role="tool",
|
||||
contents=[
|
||||
FunctionResultContent(
|
||||
call_id=pending_confirm_changes_id,
|
||||
result="Confirmed" if parsed.get("accepted") else "Rejected",
|
||||
)
|
||||
],
|
||||
)
|
||||
sanitized.append(synthetic_result)
|
||||
if pending_tool_call_ids:
|
||||
pending_tool_call_ids.discard(pending_confirm_changes_id)
|
||||
pending_confirm_changes_id = None
|
||||
# Don't add the user message to sanitized - it's been converted to tool result
|
||||
continue
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
# Failed to parse user message as confirm_changes response; continue normal processing
|
||||
logger.debug(f"Could not parse user message as confirm_changes response: {e}")
|
||||
|
||||
# Before processing user message, check if there are pending tool calls without results
|
||||
# This happens when assistant made multiple tool calls but only some got results
|
||||
# This is checked AFTER confirm_changes special handling above
|
||||
if pending_tool_call_ids:
|
||||
logger.info(
|
||||
f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - injecting synthetic results"
|
||||
)
|
||||
for pending_call_id in pending_tool_call_ids:
|
||||
logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}")
|
||||
synthetic_result = ChatMessage(
|
||||
role="tool",
|
||||
contents=[
|
||||
FunctionResultContent(
|
||||
call_id=pending_call_id,
|
||||
result="Tool execution skipped - user provided follow-up message",
|
||||
)
|
||||
],
|
||||
)
|
||||
sanitized.append(synthetic_result)
|
||||
pending_tool_call_ids = None
|
||||
pending_confirm_changes_id = None
|
||||
|
||||
# Normal user message processing
|
||||
sanitized.append(msg)
|
||||
pending_confirm_changes_id = None
|
||||
continue
|
||||
|
||||
if role_value == "tool":
|
||||
if not pending_tool_call_ids:
|
||||
continue
|
||||
keep = False
|
||||
for content in msg.contents or []:
|
||||
if isinstance(content, FunctionResultContent):
|
||||
call_id = str(content.call_id)
|
||||
if call_id in pending_tool_call_ids:
|
||||
keep = True
|
||||
# Note: We do NOT remove call_id from pending here.
|
||||
# This allows duplicate tool results to pass through sanitization
|
||||
# so the deduplicator can choose the best one (prefer non-empty results).
|
||||
# We only clear pending_tool_call_ids when a user message arrives.
|
||||
if call_id == pending_confirm_changes_id:
|
||||
# For confirm_changes specifically, we do want to clear it
|
||||
# since we only expect one response
|
||||
pending_confirm_changes_id = None
|
||||
break
|
||||
if keep:
|
||||
sanitized.append(msg)
|
||||
continue
|
||||
|
||||
sanitized.append(msg)
|
||||
pending_tool_call_ids = None
|
||||
pending_confirm_changes_id = None
|
||||
|
||||
return sanitized
|
||||
|
||||
# Create event bridge
|
||||
event_bridge = AgentFrameworkEventBridge(
|
||||
run_id=context.run_id,
|
||||
thread_id=context.thread_id,
|
||||
@@ -421,42 +305,19 @@ class DefaultOrchestrator(Orchestrator):
|
||||
|
||||
yield event_bridge.create_run_started_event()
|
||||
|
||||
# Emit PredictState custom event if we have predictive state config
|
||||
if context.config.predict_state_config:
|
||||
from ag_ui.core import CustomEvent, EventType
|
||||
predict_event = state_manager.predict_state_event()
|
||||
if predict_event:
|
||||
yield predict_event
|
||||
|
||||
predict_state_value = [
|
||||
{
|
||||
"state_key": state_key,
|
||||
"tool": config["tool"],
|
||||
"tool_argument": config["tool_argument"],
|
||||
}
|
||||
for state_key, config in context.config.predict_state_config.items()
|
||||
]
|
||||
snapshot_event = state_manager.initial_snapshot_event(event_bridge)
|
||||
if snapshot_event:
|
||||
yield snapshot_event
|
||||
|
||||
yield CustomEvent(
|
||||
type=EventType.CUSTOM,
|
||||
name="PredictState",
|
||||
value=predict_state_value,
|
||||
)
|
||||
|
||||
# If we have a state schema, ensure we emit initial state snapshot
|
||||
if context.config.state_schema:
|
||||
# Initialize missing state fields with appropriate empty values based on schema type
|
||||
for key, schema in context.config.state_schema.items():
|
||||
if key not in current_state:
|
||||
# Default to empty object; use empty array if schema specifies "array" type
|
||||
current_state[key] = [] if isinstance(schema, dict) and schema.get("type") == "array" else {} # type: ignore
|
||||
yield event_bridge.create_state_snapshot_event(current_state)
|
||||
|
||||
# Create thread for context tracking
|
||||
thread = AgentThread()
|
||||
thread.metadata = { # type: ignore[attr-defined]
|
||||
"ag_ui_thread_id": context.thread_id,
|
||||
"ag_ui_run_id": context.run_id,
|
||||
}
|
||||
|
||||
# Inject current state into thread metadata so agent can access it
|
||||
if current_state:
|
||||
thread.metadata["current_state"] = current_state # type: ignore[attr-defined]
|
||||
|
||||
@@ -475,90 +336,24 @@ class DefaultOrchestrator(Orchestrator):
|
||||
for j, content in enumerate(msg.contents):
|
||||
content_type = type(content).__name__
|
||||
if isinstance(content, TextContent):
|
||||
logger.debug(f" Content {j}: {content_type} - {content.text}")
|
||||
logger.debug(" Content %s: %s - text_length=%s", j, content_type, len(content.text))
|
||||
elif isinstance(content, FunctionCallContent):
|
||||
logger.debug(f" Content {j}: {content_type} - {content.name}({content.arguments})")
|
||||
elif isinstance(content, FunctionResultContent):
|
||||
arg_length = len(str(content.arguments)) if content.arguments else 0
|
||||
logger.debug(
|
||||
f" Content {j}: {content_type} - call_id={content.call_id}, result={content.result}"
|
||||
" Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length
|
||||
)
|
||||
elif isinstance(content, FunctionResultContent):
|
||||
result_preview = type(content.result).__name__ if content.result is not None else "None"
|
||||
logger.debug(
|
||||
" Content %s: %s - call_id=%s, result_type=%s",
|
||||
j,
|
||||
content_type,
|
||||
content.call_id,
|
||||
result_preview,
|
||||
)
|
||||
else:
|
||||
logger.debug(f" Content {j}: {content_type} - {content}")
|
||||
logger.debug(f" Content {j}: {content_type}")
|
||||
|
||||
# After getting sanitized_messages, deduplicate them
|
||||
def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
"""Remove duplicate messages while preserving order.
|
||||
|
||||
For tool results with the same call_id, prefer the one with actual data.
|
||||
"""
|
||||
seen_keys: dict[Any, int] = {} # key -> index in unique_messages (key can be various tuple types)
|
||||
unique_messages: list[ChatMessage] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
|
||||
# For tool messages, use call_id as unique key
|
||||
if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent):
|
||||
call_id = str(msg.contents[0].call_id)
|
||||
key: Any = (role_value, call_id)
|
||||
|
||||
# Check if we already have this tool result
|
||||
if key in seen_keys:
|
||||
existing_idx = seen_keys[key]
|
||||
existing_msg = unique_messages[existing_idx]
|
||||
|
||||
# Compare results - prefer non-empty over empty
|
||||
existing_result = None
|
||||
if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent):
|
||||
existing_result = existing_msg.contents[0].result
|
||||
new_result = msg.contents[0].result
|
||||
|
||||
# Replace if existing is empty/None and new has data
|
||||
if (not existing_result or existing_result == "") and new_result:
|
||||
logger.info(
|
||||
f"Replacing empty tool result at index {existing_idx} with data from index {idx}"
|
||||
)
|
||||
unique_messages[existing_idx] = msg
|
||||
else:
|
||||
logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}")
|
||||
continue
|
||||
|
||||
seen_keys[key] = len(unique_messages)
|
||||
unique_messages.append(msg)
|
||||
|
||||
elif (
|
||||
role_value == "assistant"
|
||||
and msg.contents
|
||||
and any(isinstance(c, FunctionCallContent) for c in msg.contents)
|
||||
):
|
||||
# For assistant messages with tool_calls, use the tool call IDs
|
||||
tool_call_ids = tuple(
|
||||
sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id)
|
||||
)
|
||||
key = (role_value, tool_call_ids)
|
||||
|
||||
if key in seen_keys:
|
||||
logger.info(f"Skipping duplicate assistant tool call at index {idx}")
|
||||
continue
|
||||
|
||||
seen_keys[key] = len(unique_messages)
|
||||
unique_messages.append(msg)
|
||||
|
||||
else:
|
||||
# For other messages (system, user, assistant without tools), hash the content
|
||||
content_str = str([str(c) for c in msg.contents]) if msg.contents else ""
|
||||
key = (role_value, hash(content_str))
|
||||
|
||||
if key in seen_keys:
|
||||
logger.info(f"Skipping duplicate message at index {idx}: role={role_value}")
|
||||
continue
|
||||
|
||||
seen_keys[key] = len(unique_messages)
|
||||
unique_messages.append(msg)
|
||||
|
||||
return unique_messages
|
||||
|
||||
# Then use it:
|
||||
sanitized_messages = sanitize_tool_history(raw_messages)
|
||||
provider_messages = deduplicate_messages(sanitized_messages)
|
||||
|
||||
@@ -575,66 +370,45 @@ class DefaultOrchestrator(Orchestrator):
|
||||
for j, content in enumerate(msg.contents):
|
||||
content_type = type(content).__name__
|
||||
if isinstance(content, TextContent):
|
||||
logger.info(f" Content {j}: {content_type} - {content.text}")
|
||||
logger.info(f" Content {j}: {content_type} - text_length={len(content.text)}")
|
||||
elif isinstance(content, FunctionCallContent):
|
||||
logger.info(f" Content {j}: {content_type} - {content.name}({content.arguments})")
|
||||
arg_length = len(str(content.arguments)) if content.arguments else 0
|
||||
logger.info(" Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length)
|
||||
elif isinstance(content, FunctionResultContent):
|
||||
result_preview = type(content.result).__name__ if content.result is not None else "None"
|
||||
logger.info(
|
||||
f" Content {j}: {content_type} - call_id={content.call_id}, result={content.result}"
|
||||
" Content %s: %s - call_id=%s, result_type=%s",
|
||||
j,
|
||||
content_type,
|
||||
content.call_id,
|
||||
result_preview,
|
||||
)
|
||||
else:
|
||||
logger.info(f" Content {j}: {content_type} - {content}")
|
||||
logger.info(f" Content {j}: {content_type}")
|
||||
|
||||
# NOTE: For AG-UI, the client sends the full conversation history on each request.
|
||||
# We should NOT add to thread.on_new_messages() as that would cause duplication.
|
||||
# Instead, we pass messages directly to the agent via messages_to_run.
|
||||
|
||||
# Inject current state as system message context if we have state and this is a new user turn
|
||||
messages_to_run: list[Any] = []
|
||||
|
||||
# Check if the last message is from the user (new turn) vs assistant/tool (mid-execution)
|
||||
is_new_user_turn = False
|
||||
if provider_messages:
|
||||
last_msg = provider_messages[-1]
|
||||
is_new_user_turn = last_msg.role.value == "user"
|
||||
role_value = last_msg.role.value if hasattr(last_msg.role, "value") else str(last_msg.role)
|
||||
is_new_user_turn = role_value == "user"
|
||||
|
||||
# Check if conversation has tool calls (indicates mid-execution)
|
||||
conversation_has_tool_calls = False
|
||||
for msg in provider_messages:
|
||||
if msg.role.value == "assistant" and hasattr(msg, "contents") and msg.contents:
|
||||
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
if role_value == "assistant" and hasattr(msg, "contents") and msg.contents:
|
||||
if any(isinstance(content, FunctionCallContent) for content in msg.contents):
|
||||
conversation_has_tool_calls = True
|
||||
break
|
||||
|
||||
# Only inject state context on new user turns AND when conversation doesn't have tool calls
|
||||
# (tool calls indicate we're mid-execution, so state context was already injected)
|
||||
if current_state and context.config.state_schema and is_new_user_turn and not conversation_has_tool_calls:
|
||||
state_json = json.dumps(current_state, indent=2)
|
||||
state_context_msg = ChatMessage(
|
||||
role="system",
|
||||
contents=[
|
||||
TextContent(
|
||||
text=f"""Current state of the application:
|
||||
{state_json}
|
||||
|
||||
When modifying state, you MUST include ALL existing data plus your changes.
|
||||
For example, if adding one new item to a list, include ALL existing items PLUS the one new item.
|
||||
Never replace existing data - always preserve and append or merge."""
|
||||
)
|
||||
],
|
||||
)
|
||||
state_context_msg = state_manager.state_context_message(
|
||||
is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls
|
||||
)
|
||||
if state_context_msg:
|
||||
messages_to_run.append(state_context_msg)
|
||||
|
||||
# Add all provider messages to messages_to_run
|
||||
# AG-UI sends full conversation history on each request, so we pass it directly to the agent
|
||||
messages_to_run.extend(provider_messages)
|
||||
|
||||
# Handle client tools for hybrid execution
|
||||
# Client sends tool metadata, server merges with its own tools.
|
||||
# Client tools have func=None (declaration-only), so @use_function_invocation
|
||||
# will return the function call without executing (passes back to client).
|
||||
from agent_framework import BaseChatClient
|
||||
|
||||
client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools"))
|
||||
logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools")
|
||||
if client_tools:
|
||||
@@ -643,85 +417,31 @@ class DefaultOrchestrator(Orchestrator):
|
||||
declaration_only = getattr(tool, "declaration_only", None)
|
||||
logger.info(f"[TOOLS] - Client tool: {tool_name}, declaration_only={declaration_only}")
|
||||
|
||||
# Extract server tools - use type narrowing when possible
|
||||
server_tools: list[Any] = []
|
||||
if isinstance(context.agent, ChatAgent):
|
||||
tools_from_agent = context.agent.chat_options.tools
|
||||
server_tools = list(tools_from_agent) if tools_from_agent else []
|
||||
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
|
||||
for tool in server_tools:
|
||||
tool_name = getattr(tool, "name", "unknown")
|
||||
approval_mode = getattr(tool, "approval_mode", None)
|
||||
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
|
||||
else:
|
||||
# AgentProtocol allows duck-typed implementations - fallback to attribute access
|
||||
# This supports test mocks and custom agent implementations
|
||||
try:
|
||||
chat_options_attr = getattr(context.agent, "chat_options", None)
|
||||
if chat_options_attr is not None:
|
||||
server_tools = getattr(chat_options_attr, "tools", None) or []
|
||||
except AttributeError:
|
||||
pass
|
||||
server_tools = collect_server_tools(context.agent)
|
||||
register_additional_client_tools(context.agent, client_tools)
|
||||
tools_param = merge_tools(server_tools, client_tools)
|
||||
|
||||
# Register client tools as additional (declaration-only) so they are not executed on server
|
||||
if client_tools:
|
||||
if isinstance(context.agent, ChatAgent):
|
||||
# Type-safe path for ChatAgent
|
||||
chat_client = context.agent.chat_client
|
||||
if (
|
||||
isinstance(chat_client, BaseChatClient)
|
||||
and chat_client.function_invocation_configuration is not None
|
||||
):
|
||||
chat_client.function_invocation_configuration.additional_tools = client_tools
|
||||
logger.debug(
|
||||
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
|
||||
)
|
||||
else:
|
||||
# Fallback for AgentProtocol implementations (test mocks, custom agents)
|
||||
try:
|
||||
chat_client_attr = getattr(context.agent, "chat_client", None)
|
||||
if chat_client_attr is not None:
|
||||
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
|
||||
if fic is not None:
|
||||
fic.additional_tools = client_tools # type: ignore[attr-defined]
|
||||
logger.debug(
|
||||
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# For tools parameter: only pass if we have client tools to add
|
||||
# If we pass tools=, it overrides the agent's configured tools and loses metadata like approval_mode
|
||||
# So only pass tools when we need to add client tools on top of server tools
|
||||
# IMPORTANT: Don't include client tools that duplicate server tools (same name)
|
||||
tools_param = None
|
||||
if client_tools:
|
||||
# Get server tool names
|
||||
server_tool_names = {getattr(tool, "name", None) for tool in server_tools}
|
||||
|
||||
# Filter out client tools that duplicate server tools
|
||||
unique_client_tools = [
|
||||
tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names
|
||||
]
|
||||
|
||||
if unique_client_tools:
|
||||
combined_tools: list[Any] = []
|
||||
if server_tools:
|
||||
combined_tools.extend(server_tools)
|
||||
combined_tools.extend(unique_client_tools)
|
||||
tools_param = combined_tools
|
||||
logger.info(
|
||||
f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools ({len(server_tools)} server + {len(unique_client_tools)} unique client)"
|
||||
)
|
||||
else:
|
||||
logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter")
|
||||
else:
|
||||
logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)")
|
||||
|
||||
# Collect all updates to get the final structured output
|
||||
all_updates: list[Any] = []
|
||||
update_count = 0
|
||||
async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=tools_param):
|
||||
# Prepare metadata for chat client (Azure requires string values)
|
||||
safe_metadata: dict[str, Any] = {}
|
||||
thread_metadata = getattr(thread, "metadata", None)
|
||||
if thread_metadata:
|
||||
for key, value in thread_metadata.items():
|
||||
value_str = value if isinstance(value, str) else json.dumps(value)
|
||||
if len(value_str) > 512:
|
||||
value_str = value_str[:512]
|
||||
safe_metadata[key] = value_str
|
||||
|
||||
run_kwargs: dict[str, Any] = {
|
||||
"thread": thread,
|
||||
"tools": tools_param,
|
||||
"metadata": safe_metadata,
|
||||
}
|
||||
if safe_metadata:
|
||||
run_kwargs["store"] = True
|
||||
|
||||
async for update in context.agent.run_stream(messages_to_run, **run_kwargs):
|
||||
update_count += 1
|
||||
logger.info(f"[STREAM] Received update #{update_count} from agent")
|
||||
all_updates.append(update)
|
||||
@@ -733,23 +453,19 @@ class DefaultOrchestrator(Orchestrator):
|
||||
|
||||
logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}")
|
||||
|
||||
# After agent completes, check if we should stop (waiting for user to confirm changes)
|
||||
if event_bridge.should_stop_after_confirm:
|
||||
logger.info("Stopping run after confirm_changes - waiting for user response")
|
||||
yield event_bridge.create_run_finished_event()
|
||||
return
|
||||
|
||||
# Check if there are pending tool calls (declaration-only tools that weren't executed)
|
||||
# These need ToolCallEndEvent to signal the client to execute them
|
||||
# Only emit for tool calls that haven't already had ToolCallEndEvent emitted
|
||||
# (approval-required tools already had their end event emitted)
|
||||
if event_bridge.pending_tool_calls:
|
||||
pending_without_end = [
|
||||
tc for tc in event_bridge.pending_tool_calls if tc.get("id") not in event_bridge.tool_calls_ended
|
||||
]
|
||||
if pending_without_end:
|
||||
logger.info(
|
||||
f"Found {len(pending_without_end)} pending tool calls without end event - emitting ToolCallEndEvent"
|
||||
"Found %s pending tool calls without end event - emitting ToolCallEndEvent",
|
||||
len(pending_without_end),
|
||||
)
|
||||
for tool_call in pending_without_end:
|
||||
tool_call_id = tool_call.get("id")
|
||||
@@ -760,76 +476,47 @@ class DefaultOrchestrator(Orchestrator):
|
||||
logger.info(f"Emitting ToolCallEndEvent for declaration-only tool call '{tool_call_id}'")
|
||||
yield end_event
|
||||
|
||||
# After streaming completes, check if agent has response_format and extract structured output
|
||||
if all_updates and response_format:
|
||||
from agent_framework import AgentRunResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger.info(f"Processing structured output, update count: {len(all_updates)}")
|
||||
|
||||
# Convert streaming updates to final response to get the structured output
|
||||
final_response = AgentRunResponse.from_agent_run_response_updates(
|
||||
all_updates, output_format_type=response_format
|
||||
)
|
||||
|
||||
if final_response.value and isinstance(final_response.value, BaseModel):
|
||||
# Convert Pydantic model to dict
|
||||
response_dict = final_response.value.model_dump(mode="json", exclude_none=True)
|
||||
logger.info(f"Received structured output: {list(response_dict.keys())}")
|
||||
logger.info(f"Received structured output keys: {list(response_dict.keys())}")
|
||||
|
||||
# Extract state fields based on state_schema
|
||||
state_updates: dict[str, Any] = {}
|
||||
|
||||
if context.config.state_schema:
|
||||
# Use state_schema to determine which fields are state
|
||||
for state_key in context.config.state_schema.keys():
|
||||
if state_key in response_dict:
|
||||
state_updates[state_key] = response_dict[state_key]
|
||||
else:
|
||||
# No schema: treat all non-message fields as state
|
||||
state_updates = {k: v for k, v in response_dict.items() if k != "message"}
|
||||
|
||||
# Apply state updates if any found
|
||||
state_updates = state_manager.extract_state_updates(response_dict)
|
||||
if state_updates:
|
||||
current_state.update(state_updates)
|
||||
|
||||
# Emit StateSnapshotEvent with the updated state
|
||||
state_manager.apply_state_updates(state_updates)
|
||||
state_snapshot = event_bridge.create_state_snapshot_event(current_state)
|
||||
yield state_snapshot
|
||||
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")
|
||||
|
||||
# If there's a message field, emit it as chat text
|
||||
if "message" in response_dict and response_dict["message"]:
|
||||
message_id = generate_event_id()
|
||||
yield TextMessageStartEvent(message_id=message_id, role="assistant")
|
||||
yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"])
|
||||
yield TextMessageEndEvent(message_id=message_id)
|
||||
logger.info(f"Emitted conversational message: {response_dict['message'][:100]}...")
|
||||
logger.info(f"Emitted conversational message with length={len(response_dict['message'])}")
|
||||
|
||||
logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}")
|
||||
if event_bridge.current_message_id:
|
||||
logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}")
|
||||
yield event_bridge.create_message_end_event(event_bridge.current_message_id)
|
||||
|
||||
# Emit MessagesSnapshotEvent to persist the final assistant text message
|
||||
from ._message_adapters import agui_messages_to_snapshot_format
|
||||
|
||||
# Build the final assistant message with accumulated text content
|
||||
assistant_text_message = {
|
||||
"id": event_bridge.current_message_id,
|
||||
"role": "assistant",
|
||||
"content": event_bridge.accumulated_text_content,
|
||||
}
|
||||
|
||||
# Convert input messages to snapshot format (normalize content structure)
|
||||
# event_bridge.input_messages are already in AG-UI format, just need normalization
|
||||
converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages)
|
||||
|
||||
# Build complete messages array
|
||||
# Include: input messages + any pending tool calls/results + final text message
|
||||
all_messages = converted_input_messages.copy()
|
||||
|
||||
# Add assistant message with tool calls if any
|
||||
if event_bridge.pending_tool_calls:
|
||||
tool_call_message = {
|
||||
"id": generate_event_id(),
|
||||
@@ -838,18 +525,16 @@ class DefaultOrchestrator(Orchestrator):
|
||||
}
|
||||
all_messages.append(tool_call_message)
|
||||
|
||||
# Add tool results if any
|
||||
all_messages.extend(event_bridge.tool_results.copy())
|
||||
|
||||
# Add final text message
|
||||
all_messages.append(assistant_text_message)
|
||||
|
||||
messages_snapshot = MessagesSnapshotEvent(
|
||||
messages=all_messages, # type: ignore[arg-type]
|
||||
)
|
||||
logger.info(
|
||||
f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages "
|
||||
f"(text content length: {len(event_bridge.accumulated_text_content)})"
|
||||
"[FINALIZE] Emitting MessagesSnapshotEvent with %s messages (text content length: %s)",
|
||||
len(all_messages),
|
||||
len(event_bridge.accumulated_text_content),
|
||||
)
|
||||
yield messages_snapshot
|
||||
else:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
@@ -1,10 +1,61 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for AGUIChatClient."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, MutableSequence
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatMessage, ChatOptions, FunctionCallContent, Role, ai_function
|
||||
from agent_framework import (
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponseUpdate,
|
||||
FunctionCallContent,
|
||||
Role,
|
||||
TextContent,
|
||||
ai_function,
|
||||
)
|
||||
from agent_framework._types import ChatResponse
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent
|
||||
from agent_framework_ag_ui._http_service import AGUIHttpService
|
||||
|
||||
|
||||
class TestableAGUIChatClient(AGUIChatClient):
|
||||
"""Testable wrapper exposing protected helpers."""
|
||||
|
||||
@property
|
||||
def http_service(self) -> AGUIHttpService:
|
||||
"""Expose http service for monkeypatching."""
|
||||
return self._http_service
|
||||
|
||||
def extract_state_from_messages(
|
||||
self, messages: list[ChatMessage]
|
||||
) -> tuple[list[ChatMessage], dict[str, Any] | None]:
|
||||
"""Expose state extraction helper."""
|
||||
return self._extract_state_from_messages(messages)
|
||||
|
||||
def convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
|
||||
"""Expose message conversion helper."""
|
||||
return self._convert_messages_to_agui_format(messages)
|
||||
|
||||
def get_thread_id(self, chat_options: ChatOptions) -> str:
|
||||
"""Expose thread id helper."""
|
||||
return self._get_thread_id(chat_options)
|
||||
|
||||
async def inner_get_streaming_response(
|
||||
self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
"""Proxy to protected streaming call."""
|
||||
async for update in self._inner_get_streaming_response(messages=messages, chat_options=chat_options):
|
||||
yield update
|
||||
|
||||
async def inner_get_response(
|
||||
self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
|
||||
) -> ChatResponse:
|
||||
"""Proxy to protected response call."""
|
||||
return await self._inner_get_response(messages=messages, chat_options=chat_options)
|
||||
|
||||
|
||||
class TestAGUIChatClient:
|
||||
@@ -12,25 +63,25 @@ class TestAGUIChatClient:
|
||||
|
||||
async def test_client_initialization(self) -> None:
|
||||
"""Test client initialization."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
|
||||
assert client._http_service is not None
|
||||
assert client._http_service.endpoint.startswith("http://localhost:8888")
|
||||
assert client.http_service is not None
|
||||
assert client.http_service.endpoint.startswith("http://localhost:8888")
|
||||
|
||||
async def test_client_context_manager(self) -> None:
|
||||
"""Test client as async context manager."""
|
||||
async with AGUIChatClient(endpoint="http://localhost:8888/") as client:
|
||||
async with TestableAGUIChatClient(endpoint="http://localhost:8888/") as client:
|
||||
assert client is not None
|
||||
|
||||
async def test_extract_state_from_messages_no_state(self) -> None:
|
||||
"""Test state extraction when no state is present."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
messages = [
|
||||
ChatMessage(role="user", text="Hello"),
|
||||
ChatMessage(role="assistant", text="Hi there"),
|
||||
]
|
||||
|
||||
result_messages, state = client._extract_state_from_messages(messages)
|
||||
result_messages, state = client.extract_state_from_messages(messages)
|
||||
|
||||
assert result_messages == messages
|
||||
assert state is None
|
||||
@@ -39,7 +90,7 @@ class TestAGUIChatClient:
|
||||
"""Test state extraction from last message."""
|
||||
import base64
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
|
||||
state_data = {"key": "value", "count": 42}
|
||||
state_json = json.dumps(state_data)
|
||||
@@ -55,7 +106,7 @@ class TestAGUIChatClient:
|
||||
),
|
||||
]
|
||||
|
||||
result_messages, state = client._extract_state_from_messages(messages)
|
||||
result_messages, state = client.extract_state_from_messages(messages)
|
||||
|
||||
assert len(result_messages) == 1
|
||||
assert result_messages[0].text == "Hello"
|
||||
@@ -65,7 +116,7 @@ class TestAGUIChatClient:
|
||||
"""Test state extraction with invalid JSON."""
|
||||
import base64
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
|
||||
invalid_json = "not valid json"
|
||||
state_b64 = base64.b64encode(invalid_json.encode("utf-8")).decode("utf-8")
|
||||
@@ -79,20 +130,20 @@ class TestAGUIChatClient:
|
||||
),
|
||||
]
|
||||
|
||||
result_messages, state = client._extract_state_from_messages(messages)
|
||||
result_messages, state = client.extract_state_from_messages(messages)
|
||||
|
||||
assert result_messages == messages
|
||||
assert state is None
|
||||
|
||||
async def test_convert_messages_to_agui_format(self) -> None:
|
||||
"""Test message conversion to AG-UI format."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, text="What is the weather?"),
|
||||
ChatMessage(role=Role.ASSISTANT, text="Let me check.", message_id="msg_123"),
|
||||
]
|
||||
|
||||
agui_messages = client._convert_messages_to_agui_format(messages)
|
||||
agui_messages = client.convert_messages_to_agui_format(messages)
|
||||
|
||||
assert len(agui_messages) == 2
|
||||
assert agui_messages[0]["role"] == "user"
|
||||
@@ -103,24 +154,24 @@ class TestAGUIChatClient:
|
||||
|
||||
async def test_get_thread_id_from_metadata(self) -> None:
|
||||
"""Test thread ID extraction from metadata."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"})
|
||||
|
||||
thread_id = client._get_thread_id(chat_options)
|
||||
thread_id = client.get_thread_id(chat_options)
|
||||
|
||||
assert thread_id == "existing_thread_123"
|
||||
|
||||
async def test_get_thread_id_generation(self) -> None:
|
||||
"""Test automatic thread ID generation."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
chat_options = ChatOptions()
|
||||
|
||||
thread_id = client._get_thread_id(chat_options)
|
||||
thread_id = client.get_thread_id(chat_options)
|
||||
|
||||
assert thread_id.startswith("thread_")
|
||||
assert len(thread_id) > 7
|
||||
|
||||
async def test_get_streaming_response(self, monkeypatch) -> None:
|
||||
async def test_get_streaming_response(self, monkeypatch: MonkeyPatch) -> None:
|
||||
"""Test streaming response method."""
|
||||
mock_events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
@@ -129,26 +180,32 @@ class TestAGUIChatClient:
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test message")]
|
||||
chat_options = ChatOptions()
|
||||
|
||||
updates = []
|
||||
async for update in client._inner_get_streaming_response(messages=messages, chat_options=chat_options):
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
async for update in client.inner_get_streaming_response(messages=messages, chat_options=chat_options):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 4
|
||||
assert updates[0].additional_properties is not None
|
||||
assert updates[0].additional_properties["thread_id"] == "thread_1"
|
||||
assert updates[1].contents[0].text == "Hello"
|
||||
assert updates[2].contents[0].text == " world"
|
||||
|
||||
async def test_get_response_non_streaming(self, monkeypatch) -> None:
|
||||
first_content = updates[1].contents[0]
|
||||
second_content = updates[2].contents[0]
|
||||
assert isinstance(first_content, TextContent)
|
||||
assert isinstance(second_content, TextContent)
|
||||
assert first_content.text == "Hello"
|
||||
assert second_content.text == " world"
|
||||
|
||||
async def test_get_response_non_streaming(self, monkeypatch: MonkeyPatch) -> None:
|
||||
"""Test non-streaming response method."""
|
||||
mock_events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
@@ -156,23 +213,23 @@ class TestAGUIChatClient:
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test message")]
|
||||
chat_options = ChatOptions()
|
||||
|
||||
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
|
||||
response = await client.inner_get_response(messages=messages, chat_options=chat_options)
|
||||
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert "Complete response" in response.text
|
||||
|
||||
async def test_tool_handling(self, monkeypatch) -> None:
|
||||
async def test_tool_handling(self, monkeypatch: MonkeyPatch) -> None:
|
||||
"""Test that client tool metadata is sent to server.
|
||||
|
||||
Client tool metadata (name, description, schema) is sent to server for planning.
|
||||
@@ -191,28 +248,29 @@ class TestAGUIChatClient:
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
# Client tool metadata should be sent to server
|
||||
tools = kwargs.get("tools")
|
||||
tools: list[dict[str, Any]] | None = kwargs.get("tools")
|
||||
assert tools is not None
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "test_tool"
|
||||
assert tools[0]["description"] == "Test tool."
|
||||
assert "parameters" in tools[0]
|
||||
tool_entry = tools[0]
|
||||
assert tool_entry["name"] == "test_tool"
|
||||
assert tool_entry["description"] == "Test tool."
|
||||
assert "parameters" in tool_entry
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test with tools")]
|
||||
chat_options = ChatOptions(tools=[test_tool])
|
||||
|
||||
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
|
||||
response = await client.inner_get_response(messages=messages, chat_options=chat_options)
|
||||
|
||||
assert response is not None
|
||||
|
||||
async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch) -> None:
|
||||
async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch: MonkeyPatch) -> None:
|
||||
"""Ensure server-side tool calls are exposed as FunctionCallContent after processing."""
|
||||
|
||||
mock_events = [
|
||||
@@ -222,17 +280,17 @@ class TestAGUIChatClient:
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test server tool execution")]
|
||||
chat_options = ChatOptions()
|
||||
|
||||
updates = []
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
async for update in client.get_streaming_response(messages, chat_options=chat_options):
|
||||
updates.append(update)
|
||||
|
||||
@@ -245,7 +303,7 @@ class TestAGUIChatClient:
|
||||
isinstance(content, ServerFunctionCallContent) for update in updates for content in update.contents
|
||||
)
|
||||
|
||||
async def test_server_tool_calls_not_executed_locally(self, monkeypatch) -> None:
|
||||
async def test_server_tool_calls_not_executed_locally(self, monkeypatch: MonkeyPatch) -> None:
|
||||
"""Server tools should not trigger local function invocation even when client tools exist."""
|
||||
|
||||
@ai_function
|
||||
@@ -260,18 +318,18 @@ class TestAGUIChatClient:
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
async def fake_auto_invoke(*args, **kwargs):
|
||||
async def fake_auto_invoke(*args: object, **kwargs: Any) -> None:
|
||||
function_call = kwargs.get("function_call_content") or args[0]
|
||||
raise AssertionError(f"Unexpected local execution of server tool: {getattr(function_call, 'name', '?')}")
|
||||
|
||||
monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke)
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test server tool execution")]
|
||||
chat_options = ChatOptions(tool_choice="auto", tools=[client_tool])
|
||||
@@ -279,7 +337,7 @@ class TestAGUIChatClient:
|
||||
async for _ in client.get_streaming_response(messages, chat_options=chat_options):
|
||||
pass
|
||||
|
||||
async def test_state_transmission(self, monkeypatch) -> None:
|
||||
async def test_state_transmission(self, monkeypatch: MonkeyPatch) -> None:
|
||||
"""Test state is properly transmitted to server."""
|
||||
import base64
|
||||
|
||||
@@ -302,16 +360,16 @@ class TestAGUIChatClient:
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
assert kwargs.get("state") == state_data
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
client = TestableAGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client.http_service, "post_run", mock_post_run)
|
||||
|
||||
chat_options = ChatOptions()
|
||||
|
||||
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
|
||||
response = await client.inner_get_response(messages=messages, chat_options=chat_options)
|
||||
|
||||
assert response is not None
|
||||
|
||||
@@ -3,21 +3,30 @@
|
||||
"""Comprehensive tests for AgentFrameworkAgent (_agent.py)."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import AsyncIterator, MutableSequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from agent_framework import ChatAgent, TextContent
|
||||
from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent
|
||||
from agent_framework._types import ChatResponseUpdate
|
||||
from pydantic import BaseModel
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from test_helpers_ag_ui import StreamingChatClientStub
|
||||
|
||||
|
||||
async def test_agent_initialization_basic():
|
||||
"""Test basic agent initialization without state schema."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
assert wrapper.name == "test_agent"
|
||||
@@ -30,12 +39,13 @@ async def test_agent_initialization_with_state_schema():
|
||||
"""Test agent initialization with state_schema."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
state_schema = {"document": {"type": "string"}}
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
state_schema: dict[str, dict[str, Any]] = {"document": {"type": "string"}}
|
||||
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
|
||||
|
||||
assert wrapper.config.state_schema == state_schema
|
||||
@@ -45,31 +55,56 @@ async def test_agent_initialization_with_predict_state_config():
|
||||
"""Test agent initialization with predict_state_config."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
|
||||
wrapper = AgentFrameworkAgent(agent=agent, predict_state_config=predict_config)
|
||||
|
||||
assert wrapper.config.predict_state_config == predict_config
|
||||
|
||||
|
||||
async def test_agent_initialization_with_pydantic_state_schema():
|
||||
"""Test agent initialization when state_schema is provided as Pydantic model/class."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
class MyState(BaseModel):
|
||||
document: str
|
||||
tags: list[str] = []
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
|
||||
wrapper_class_schema = AgentFrameworkAgent(agent=agent, state_schema=MyState)
|
||||
wrapper_instance_schema = AgentFrameworkAgent(agent=agent, state_schema=MyState(document="hi"))
|
||||
|
||||
expected_properties = MyState.model_json_schema().get("properties", {})
|
||||
assert wrapper_class_schema.config.state_schema == expected_properties
|
||||
assert wrapper_instance_schema.config.state_schema == expected_properties
|
||||
|
||||
|
||||
async def test_run_started_event_emission():
|
||||
"""Test RunStartedEvent is emitted at start of run."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -83,11 +118,12 @@ async def test_predict_state_custom_event_emission():
|
||||
"""Test PredictState CustomEvent is emitted when predict_state_config is present."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
predict_config = {
|
||||
"document": {"tool": "write_doc", "tool_argument": "content"},
|
||||
"summary": {"tool": "summarize", "tool_argument": "text"},
|
||||
@@ -96,7 +132,7 @@ async def test_predict_state_custom_event_emission():
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -114,11 +150,12 @@ async def test_initial_state_snapshot_with_schema():
|
||||
"""Test initial StateSnapshotEvent emission when state_schema present."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
state_schema = {"document": {"type": "string"}}
|
||||
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
|
||||
|
||||
@@ -127,7 +164,7 @@ async def test_initial_state_snapshot_with_schema():
|
||||
"state": {"document": "Initial content"},
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -143,17 +180,18 @@ async def test_state_initialization_object_type():
|
||||
"""Test state initialization with object type in schema."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
state_schema = {"recipe": {"type": "object", "properties": {}}}
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
state_schema: dict[str, dict[str, Any]] = {"recipe": {"type": "object", "properties": {}}}
|
||||
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -169,17 +207,18 @@ async def test_state_initialization_array_type():
|
||||
"""Test state initialization with array type in schema."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
state_schema = {"steps": {"type": "array", "items": {}}}
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
state_schema: dict[str, dict[str, Any]] = {"steps": {"type": "array", "items": {}}}
|
||||
wrapper = AgentFrameworkAgent(agent=agent, state_schema=state_schema)
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -195,16 +234,17 @@ async def test_run_finished_event_emission():
|
||||
"""Test RunFinishedEvent is emitted at end of run."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -216,11 +256,12 @@ async def test_tool_result_confirm_changes_accepted():
|
||||
"""Test confirm_changes tool result handling when accepted."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Document updated")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Document updated")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
state_schema={"document": {"type": "string"}},
|
||||
@@ -228,8 +269,8 @@ async def test_tool_result_confirm_changes_accepted():
|
||||
)
|
||||
|
||||
# Simulate tool result message with acceptance
|
||||
tool_result = {"accepted": True, "steps": []}
|
||||
input_data = {
|
||||
tool_result: dict[str, Any] = {"accepted": True, "steps": []}
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool", # Tool result from UI
|
||||
@@ -240,7 +281,7 @@ async def test_tool_result_confirm_changes_accepted():
|
||||
"state": {"document": "Updated content"},
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -262,16 +303,17 @@ async def test_tool_result_confirm_changes_rejected():
|
||||
"""Test confirm_changes tool result handling when rejected."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
# Simulate tool result message with rejection
|
||||
tool_result = {"accepted": False, "steps": []}
|
||||
input_data = {
|
||||
tool_result: dict[str, Any] = {"accepted": False, "steps": []}
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -281,7 +323,7 @@ async def test_tool_result_confirm_changes_rejected():
|
||||
],
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -295,22 +337,23 @@ async def test_tool_result_function_approval_accepted():
|
||||
"""Test function approval tool result when steps are accepted."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
# Simulate tool result with multiple steps
|
||||
tool_result = {
|
||||
tool_result: dict[str, Any] = {
|
||||
"accepted": True,
|
||||
"steps": [
|
||||
{"id": "step1", "description": "Send email", "status": "enabled"},
|
||||
{"id": "step2", "description": "Create calendar event", "status": "enabled"},
|
||||
],
|
||||
}
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -320,7 +363,7 @@ async def test_tool_result_function_approval_accepted():
|
||||
],
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -340,19 +383,20 @@ async def test_tool_result_function_approval_rejected():
|
||||
"""Test function approval tool result when rejected."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="OK")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
# Simulate tool result rejection with steps
|
||||
tool_result = {
|
||||
tool_result: dict[str, Any] = {
|
||||
"accepted": False,
|
||||
"steps": [{"id": "step1", "description": "Send email", "status": "disabled"}],
|
||||
}
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -362,7 +406,7 @@ async def test_tool_result_function_approval_rejected():
|
||||
],
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -376,17 +420,16 @@ async def test_thread_metadata_tracking():
|
||||
"""Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
thread_metadata = {}
|
||||
thread_metadata: dict[str, Any] = {}
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
# Capture thread metadata from kwargs
|
||||
nonlocal thread_metadata
|
||||
if "thread" in kwargs:
|
||||
thread_metadata = kwargs["thread"].metadata
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
if chat_options.metadata:
|
||||
thread_metadata.update(chat_options.metadata)
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
input_data = {
|
||||
@@ -395,28 +438,28 @@ async def test_thread_metadata_tracking():
|
||||
"run_id": "test_run_456",
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
# Check thread metadata was set
|
||||
# Note: This test may need adjustment based on actual thread passing mechanism
|
||||
assert thread_metadata.get("ag_ui_thread_id") == "test_thread_123"
|
||||
assert thread_metadata.get("ag_ui_run_id") == "test_run_456"
|
||||
|
||||
|
||||
async def test_state_context_injection():
|
||||
"""Test that current state is injected into thread metadata."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
thread_metadata = {}
|
||||
thread_metadata: dict[str, Any] = {}
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
# Track if state context message was added
|
||||
nonlocal thread_metadata
|
||||
# In actual implementation, thread is passed and state is in metadata
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
if chat_options.metadata:
|
||||
thread_metadata.update(chat_options.metadata)
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
state_schema={"document": {"type": "string"}},
|
||||
@@ -427,27 +470,31 @@ async def test_state_context_injection():
|
||||
"state": {"document": "Test content"},
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
# State should be injected - this is validated by agent execution flow
|
||||
current_state = thread_metadata.get("current_state")
|
||||
if isinstance(current_state, str):
|
||||
current_state = json.loads(current_state)
|
||||
assert current_state == {"document": "Test content"}
|
||||
|
||||
|
||||
async def test_no_messages_provided():
|
||||
"""Test handling when no messages are provided."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
input_data = {"messages": []}
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -461,16 +508,17 @@ async def test_message_end_event_emission():
|
||||
"""Test TextMessageEndEvent is emitted for assistant messages."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello world")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello world")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -488,19 +536,20 @@ async def test_error_handling_with_exception():
|
||||
"""Test that exceptions during agent execution are re-raised."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class FailingChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
if False:
|
||||
yield
|
||||
raise RuntimeError("Simulated failure")
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
if False:
|
||||
yield ChatResponseUpdate(contents=[])
|
||||
raise RuntimeError("Simulated failure")
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=FailingChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
|
||||
with pytest.raises(RuntimeError, match="Simulated failure"):
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
async for _ in wrapper.run_agent(input_data):
|
||||
pass
|
||||
|
||||
|
||||
@@ -508,18 +557,18 @@ async def test_json_decode_error_in_tool_result():
|
||||
"""Test handling of orphaned tool result - should be sanitized out."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
# Should not be called since orphaned tool result is dropped
|
||||
if False:
|
||||
yield
|
||||
raise AssertionError("ChatClient should not be called with orphaned tool result")
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
if False:
|
||||
yield ChatResponseUpdate(contents=[])
|
||||
raise AssertionError("ChatClient should not be called with orphaned tool result")
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
# Send invalid JSON as tool result without preceding tool call
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -529,7 +578,7 @@ async def test_json_decode_error_in_tool_result():
|
||||
],
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -545,11 +594,12 @@ async def test_suppressed_summary_with_document_state():
|
||||
"""Test suppressed summary uses document state for confirmation message."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Response")])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Response")])
|
||||
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
wrapper = AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
state_schema={"document": {"type": "string"}},
|
||||
@@ -558,8 +608,8 @@ async def test_suppressed_summary_with_document_state():
|
||||
)
|
||||
|
||||
# Simulate confirmation with document state
|
||||
tool_result = {"accepted": True, "steps": []}
|
||||
input_data = {
|
||||
tool_result: dict[str, Any] = {"accepted": True, "steps": []}
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -570,7 +620,7 @@ async def test_suppressed_summary_with_document_state():
|
||||
"state": {"document": "This is the beginning of a document. It contains important information."},
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
"""Tests for backend tool rendering."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from ag_ui.core import (
|
||||
TextMessageContentEvent,
|
||||
TextMessageStartEvent,
|
||||
@@ -119,6 +121,9 @@ async def test_multiple_tool_results():
|
||||
assert isinstance(events[end_idx], ToolCallEndEvent)
|
||||
assert isinstance(events[result_idx], ToolCallResultEvent)
|
||||
|
||||
assert events[end_idx].tool_call_id == f"tool-{i + 1}"
|
||||
assert events[result_idx].tool_call_id == f"tool-{i + 1}"
|
||||
assert f"Result {i + 1}" in events[result_idx].content
|
||||
end_event = cast(ToolCallEndEvent, events[end_idx])
|
||||
result_event = cast(ToolCallResultEvent, events[result_idx])
|
||||
|
||||
assert end_event.tool_call_id == f"tool-{i + 1}"
|
||||
assert result_event.tool_call_id == f"tool-{i + 1}"
|
||||
assert f"Result {i + 1}" in result_event.content
|
||||
|
||||
@@ -14,7 +14,7 @@ from agent_framework_ag_ui._confirmation_strategies import (
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_steps():
|
||||
def sample_steps() -> list[dict[str, str]]:
|
||||
"""Sample steps for testing approval messages."""
|
||||
return [
|
||||
{"description": "Step 1: Do something", "status": "enabled"},
|
||||
@@ -24,7 +24,7 @@ def sample_steps():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def all_enabled_steps():
|
||||
def all_enabled_steps() -> list[dict[str, str]]:
|
||||
"""All steps enabled."""
|
||||
return [
|
||||
{"description": "Task A", "status": "enabled"},
|
||||
@@ -34,7 +34,7 @@ def all_enabled_steps():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_steps():
|
||||
def empty_steps() -> list[dict[str, str]]:
|
||||
"""Empty steps list."""
|
||||
return []
|
||||
|
||||
@@ -42,7 +42,7 @@ def empty_steps():
|
||||
class TestDefaultConfirmationStrategy:
|
||||
"""Tests for DefaultConfirmationStrategy."""
|
||||
|
||||
def test_on_approval_accepted_with_enabled_steps(self, sample_steps):
|
||||
def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None:
|
||||
strategy = DefaultConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(sample_steps)
|
||||
|
||||
@@ -52,7 +52,7 @@ class TestDefaultConfirmationStrategy:
|
||||
assert "Step 3" not in message # Disabled step shouldn't appear
|
||||
assert "All steps completed successfully!" in message
|
||||
|
||||
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps):
|
||||
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None:
|
||||
strategy = DefaultConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(all_enabled_steps)
|
||||
|
||||
@@ -61,28 +61,28 @@ class TestDefaultConfirmationStrategy:
|
||||
assert "Task B" in message
|
||||
assert "Task C" in message
|
||||
|
||||
def test_on_approval_accepted_with_empty_steps(self, empty_steps):
|
||||
def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None:
|
||||
strategy = DefaultConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(empty_steps)
|
||||
|
||||
assert "Executing 0 approved steps" in message
|
||||
assert "All steps completed successfully!" in message
|
||||
|
||||
def test_on_approval_rejected(self, sample_steps):
|
||||
def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None:
|
||||
strategy = DefaultConfirmationStrategy()
|
||||
message = strategy.on_approval_rejected(sample_steps)
|
||||
|
||||
assert "No problem!" in message
|
||||
assert "What would you like me to change" in message
|
||||
|
||||
def test_on_state_confirmed(self):
|
||||
def test_on_state_confirmed(self) -> None:
|
||||
strategy = DefaultConfirmationStrategy()
|
||||
message = strategy.on_state_confirmed()
|
||||
|
||||
assert "Changes confirmed" in message
|
||||
assert "successfully" in message
|
||||
|
||||
def test_on_state_rejected(self):
|
||||
def test_on_state_rejected(self) -> None:
|
||||
strategy = DefaultConfirmationStrategy()
|
||||
message = strategy.on_state_rejected()
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestDefaultConfirmationStrategy:
|
||||
class TestTaskPlannerConfirmationStrategy:
|
||||
"""Tests for TaskPlannerConfirmationStrategy."""
|
||||
|
||||
def test_on_approval_accepted_with_enabled_steps(self, sample_steps):
|
||||
def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None:
|
||||
strategy = TaskPlannerConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(sample_steps)
|
||||
|
||||
@@ -103,7 +103,7 @@ class TestTaskPlannerConfirmationStrategy:
|
||||
assert "Step 3" not in message
|
||||
assert "All tasks completed successfully!" in message
|
||||
|
||||
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps):
|
||||
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None:
|
||||
strategy = TaskPlannerConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(all_enabled_steps)
|
||||
|
||||
@@ -112,28 +112,28 @@ class TestTaskPlannerConfirmationStrategy:
|
||||
assert "2. Task B" in message
|
||||
assert "3. Task C" in message
|
||||
|
||||
def test_on_approval_accepted_with_empty_steps(self, empty_steps):
|
||||
def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None:
|
||||
strategy = TaskPlannerConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(empty_steps)
|
||||
|
||||
assert "Executing your requested tasks" in message
|
||||
assert "All tasks completed successfully!" in message
|
||||
|
||||
def test_on_approval_rejected(self, sample_steps):
|
||||
def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None:
|
||||
strategy = TaskPlannerConfirmationStrategy()
|
||||
message = strategy.on_approval_rejected(sample_steps)
|
||||
|
||||
assert "No problem!" in message
|
||||
assert "revise the plan" in message
|
||||
|
||||
def test_on_state_confirmed(self):
|
||||
def test_on_state_confirmed(self) -> None:
|
||||
strategy = TaskPlannerConfirmationStrategy()
|
||||
message = strategy.on_state_confirmed()
|
||||
|
||||
assert "Tasks confirmed" in message
|
||||
assert "ready to execute" in message
|
||||
|
||||
def test_on_state_rejected(self):
|
||||
def test_on_state_rejected(self) -> None:
|
||||
strategy = TaskPlannerConfirmationStrategy()
|
||||
message = strategy.on_state_rejected()
|
||||
|
||||
@@ -144,7 +144,7 @@ class TestTaskPlannerConfirmationStrategy:
|
||||
class TestRecipeConfirmationStrategy:
|
||||
"""Tests for RecipeConfirmationStrategy."""
|
||||
|
||||
def test_on_approval_accepted_with_enabled_steps(self, sample_steps):
|
||||
def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None:
|
||||
strategy = RecipeConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(sample_steps)
|
||||
|
||||
@@ -154,7 +154,7 @@ class TestRecipeConfirmationStrategy:
|
||||
assert "Step 3" not in message
|
||||
assert "Recipe updated successfully!" in message
|
||||
|
||||
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps):
|
||||
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None:
|
||||
strategy = RecipeConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(all_enabled_steps)
|
||||
|
||||
@@ -163,28 +163,28 @@ class TestRecipeConfirmationStrategy:
|
||||
assert "2. Task B" in message
|
||||
assert "3. Task C" in message
|
||||
|
||||
def test_on_approval_accepted_with_empty_steps(self, empty_steps):
|
||||
def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None:
|
||||
strategy = RecipeConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(empty_steps)
|
||||
|
||||
assert "Updating your recipe" in message
|
||||
assert "Recipe updated successfully!" in message
|
||||
|
||||
def test_on_approval_rejected(self, sample_steps):
|
||||
def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None:
|
||||
strategy = RecipeConfirmationStrategy()
|
||||
message = strategy.on_approval_rejected(sample_steps)
|
||||
|
||||
assert "No problem!" in message
|
||||
assert "ingredients or steps" in message
|
||||
|
||||
def test_on_state_confirmed(self):
|
||||
def test_on_state_confirmed(self) -> None:
|
||||
strategy = RecipeConfirmationStrategy()
|
||||
message = strategy.on_state_confirmed()
|
||||
|
||||
assert "Recipe changes applied" in message
|
||||
assert "successfully" in message
|
||||
|
||||
def test_on_state_rejected(self):
|
||||
def test_on_state_rejected(self) -> None:
|
||||
strategy = RecipeConfirmationStrategy()
|
||||
message = strategy.on_state_rejected()
|
||||
|
||||
@@ -195,7 +195,7 @@ class TestRecipeConfirmationStrategy:
|
||||
class TestDocumentWriterConfirmationStrategy:
|
||||
"""Tests for DocumentWriterConfirmationStrategy."""
|
||||
|
||||
def test_on_approval_accepted_with_enabled_steps(self, sample_steps):
|
||||
def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None:
|
||||
strategy = DocumentWriterConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(sample_steps)
|
||||
|
||||
@@ -205,7 +205,7 @@ class TestDocumentWriterConfirmationStrategy:
|
||||
assert "Step 3" not in message
|
||||
assert "Document updated successfully!" in message
|
||||
|
||||
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps):
|
||||
def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None:
|
||||
strategy = DocumentWriterConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(all_enabled_steps)
|
||||
|
||||
@@ -214,27 +214,27 @@ class TestDocumentWriterConfirmationStrategy:
|
||||
assert "2. Task B" in message
|
||||
assert "3. Task C" in message
|
||||
|
||||
def test_on_approval_accepted_with_empty_steps(self, empty_steps):
|
||||
def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None:
|
||||
strategy = DocumentWriterConfirmationStrategy()
|
||||
message = strategy.on_approval_accepted(empty_steps)
|
||||
|
||||
assert "Applying your edits" in message
|
||||
assert "Document updated successfully!" in message
|
||||
|
||||
def test_on_approval_rejected(self, sample_steps):
|
||||
def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None:
|
||||
strategy = DocumentWriterConfirmationStrategy()
|
||||
message = strategy.on_approval_rejected(sample_steps)
|
||||
|
||||
assert "No problem!" in message
|
||||
assert "keep or modify" in message
|
||||
|
||||
def test_on_state_confirmed(self):
|
||||
def test_on_state_confirmed(self) -> None:
|
||||
strategy = DocumentWriterConfirmationStrategy()
|
||||
message = strategy.on_state_confirmed()
|
||||
|
||||
assert "Document edits applied!" in message
|
||||
|
||||
def test_on_state_rejected(self):
|
||||
def test_on_state_rejected(self) -> None:
|
||||
strategy = DocumentWriterConfirmationStrategy()
|
||||
message = strategy.on_state_rejected()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
"""Tests for document writer predictive state flow with confirm_changes."""
|
||||
|
||||
from ag_ui.core import EventType
|
||||
from ag_ui.core import EventType, StateDeltaEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent
|
||||
from agent_framework import FunctionCallContent, FunctionResultContent, TextContent
|
||||
from agent_framework._types import AgentRunResponseUpdate
|
||||
|
||||
@@ -35,16 +35,12 @@ async def test_streaming_document_with_state_deltas():
|
||||
assert any(e.type == EventType.TOOL_CALL_ARGS for e in events1)
|
||||
|
||||
# Second chunk - incomplete JSON, should try partial extraction
|
||||
tool_call_chunk2 = FunctionCallContent(
|
||||
call_id="call_123",
|
||||
name=None, # Name only in first chunk
|
||||
arguments=" upon a time",
|
||||
)
|
||||
tool_call_chunk2 = FunctionCallContent(call_id="call_123", name="write_document_local", arguments=" upon a time")
|
||||
update2 = AgentRunResponseUpdate(contents=[tool_call_chunk2])
|
||||
events2 = await bridge.from_agent_run_update(update2)
|
||||
|
||||
# Should emit StateDeltaEvent with partial document
|
||||
state_deltas = [e for e in events2 if e.type == EventType.STATE_DELTA]
|
||||
state_deltas = [e for e in events2 if isinstance(e, StateDeltaEvent)]
|
||||
assert len(state_deltas) >= 1
|
||||
|
||||
# Check JSON Patch format
|
||||
@@ -62,7 +58,7 @@ async def test_confirm_changes_emission():
|
||||
"document": {"tool": "write_document_local", "tool_argument": "document"},
|
||||
}
|
||||
|
||||
current_state = {}
|
||||
current_state: dict[str, str] = {}
|
||||
|
||||
bridge = AgentFrameworkEventBridge(
|
||||
run_id="test_run",
|
||||
@@ -90,15 +86,13 @@ async def test_confirm_changes_emission():
|
||||
assert any(e.type == EventType.STATE_SNAPSHOT for e in events)
|
||||
|
||||
# Check for confirm_changes tool call
|
||||
confirm_starts = [
|
||||
e for e in events if e.type == EventType.TOOL_CALL_START and e.tool_call_name == "confirm_changes"
|
||||
]
|
||||
confirm_starts = [e for e in events if isinstance(e, ToolCallStartEvent) and e.tool_call_name == "confirm_changes"]
|
||||
assert len(confirm_starts) == 1
|
||||
|
||||
confirm_args = [e for e in events if e.type == EventType.TOOL_CALL_ARGS and e.delta == "{}"]
|
||||
confirm_args = [e for e in events if isinstance(e, ToolCallArgsEvent) and e.delta == "{}"]
|
||||
assert len(confirm_args) >= 1
|
||||
|
||||
confirm_ends = [e for e in events if e.type == EventType.TOOL_CALL_END]
|
||||
confirm_ends = [e for e in events if isinstance(e, ToolCallEndEvent)]
|
||||
# At least 2: one for write_document_local, one for confirm_changes
|
||||
assert len(confirm_ends) >= 2
|
||||
|
||||
@@ -141,7 +135,7 @@ async def test_no_confirm_for_non_predictive_tools():
|
||||
"document": {"tool": "write_document_local", "tool_argument": "document"},
|
||||
}
|
||||
|
||||
current_state = {}
|
||||
current_state: dict[str, str] = {}
|
||||
|
||||
bridge = AgentFrameworkEventBridge(
|
||||
run_id="test_run",
|
||||
@@ -162,9 +156,7 @@ async def test_no_confirm_for_non_predictive_tools():
|
||||
events = await bridge.from_agent_run_update(update)
|
||||
|
||||
# Should NOT have confirm_changes
|
||||
confirm_starts = [
|
||||
e for e in events if e.type == EventType.TOOL_CALL_START and e.tool_call_name == "confirm_changes"
|
||||
]
|
||||
confirm_starts = [e for e in events if isinstance(e, ToolCallStartEvent) and e.tool_call_name == "confirm_changes"]
|
||||
assert len(confirm_starts) == 0
|
||||
|
||||
# Stop flag should NOT be set
|
||||
@@ -193,14 +185,14 @@ async def test_state_delta_deduplication():
|
||||
events1 = await bridge.from_agent_run_update(update1)
|
||||
|
||||
# Count state deltas
|
||||
state_deltas_1 = [e for e in events1 if e.type == EventType.STATE_DELTA]
|
||||
state_deltas_1 = [e for e in events1 if isinstance(e, StateDeltaEvent)]
|
||||
assert len(state_deltas_1) >= 1
|
||||
|
||||
# Second tool call with SAME document (shouldn't emit new delta)
|
||||
bridge.current_tool_call_name = "write_document_local"
|
||||
tool_call2 = FunctionCallContent(
|
||||
call_id="call_2",
|
||||
name=None,
|
||||
name="write_document_local",
|
||||
arguments='{"document":"Same text"}', # Identical content
|
||||
)
|
||||
update2 = AgentRunResponseUpdate(contents=[tool_call2])
|
||||
@@ -234,7 +226,7 @@ async def test_predict_state_config_multiple_fields():
|
||||
events = await bridge.from_agent_run_update(update)
|
||||
|
||||
# Should emit StateDeltaEvent for both fields
|
||||
state_deltas = [e for e in events if e.type == EventType.STATE_DELTA]
|
||||
state_deltas = [e for e in events if isinstance(e, StateDeltaEvent)]
|
||||
assert len(state_deltas) >= 2
|
||||
|
||||
# Check both fields are present
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
"""Tests for FastAPI endpoint creation (_endpoint.py)."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from agent_framework import ChatAgent, TextContent
|
||||
from agent_framework._types import ChatResponseUpdate
|
||||
@@ -13,22 +14,20 @@ from fastapi.testclient import TestClient
|
||||
from agent_framework_ag_ui._agent import AgentFrameworkAgent
|
||||
from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates
|
||||
|
||||
class MockChatClient:
|
||||
"""Mock chat client for testing."""
|
||||
|
||||
def __init__(self, response_text: str = "Test response"):
|
||||
self.response_text = response_text
|
||||
|
||||
async def get_streaming_response(self, messages: list[Any], chat_options: Any, **kwargs: Any):
|
||||
"""Mock streaming response."""
|
||||
yield ChatResponseUpdate(contents=[TextContent(text=self.response_text)])
|
||||
def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub:
|
||||
"""Create a typed chat client stub for endpoint tests."""
|
||||
updates = [ChatResponseUpdate(contents=[TextContent(text=response_text)])]
|
||||
return StreamingChatClientStub(stream_from_updates(updates))
|
||||
|
||||
|
||||
async def test_add_endpoint_with_agent_protocol():
|
||||
"""Test adding endpoint with raw AgentProtocol."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/test-agent")
|
||||
|
||||
@@ -42,7 +41,7 @@ async def test_add_endpoint_with_agent_protocol():
|
||||
async def test_add_endpoint_with_wrapped_agent():
|
||||
"""Test adding endpoint with pre-wrapped AgentFrameworkAgent."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
wrapped_agent = AgentFrameworkAgent(agent=agent, name="wrapped")
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, wrapped_agent, path="/wrapped-agent")
|
||||
@@ -57,7 +56,7 @@ async def test_add_endpoint_with_wrapped_agent():
|
||||
async def test_endpoint_with_state_schema():
|
||||
"""Test endpoint with state_schema parameter."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
state_schema = {"document": {"type": "string"}}
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/stateful", state_schema=state_schema)
|
||||
@@ -70,10 +69,37 @@ async def test_endpoint_with_state_schema():
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
async def test_endpoint_with_default_state_seed():
|
||||
"""Test endpoint seeds default state when client omits it."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
state_schema = {"proverbs": {"type": "array"}}
|
||||
default_state = {"proverbs": ["Keep the original."]}
|
||||
|
||||
add_agent_framework_fastapi_endpoint(
|
||||
app,
|
||||
agent,
|
||||
path="/default-state",
|
||||
state_schema=state_schema,
|
||||
default_state=default_state,
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post("/default-state", json={"messages": [{"role": "user", "content": "Hello"}]})
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
content = response.content.decode("utf-8")
|
||||
lines = [line for line in content.split("\n") if line.startswith("data: ")]
|
||||
snapshots = [json.loads(line[6:]) for line in lines if json.loads(line[6:]).get("type") == "STATE_SNAPSHOT"]
|
||||
assert snapshots, "Expected a STATE_SNAPSHOT event"
|
||||
assert snapshots[0]["snapshot"]["proverbs"] == default_state["proverbs"]
|
||||
|
||||
|
||||
async def test_endpoint_with_predict_state_config():
|
||||
"""Test endpoint with predict_state_config parameter."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/predictive", predict_state_config=predict_config)
|
||||
@@ -87,7 +113,7 @@ async def test_endpoint_with_predict_state_config():
|
||||
async def test_endpoint_request_logging():
|
||||
"""Test that endpoint logs request details."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/logged")
|
||||
|
||||
@@ -107,7 +133,7 @@ async def test_endpoint_request_logging():
|
||||
async def test_endpoint_event_streaming():
|
||||
"""Test that endpoint streams events correctly."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient("Streamed response"))
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client("Streamed response"))
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/stream")
|
||||
|
||||
@@ -141,14 +167,14 @@ async def test_endpoint_event_streaming():
|
||||
async def test_endpoint_error_handling():
|
||||
"""Test endpoint error handling during request parsing."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/failing")
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Send invalid JSON to trigger parsing error before streaming
|
||||
response = client.post("/failing", data="invalid json", headers={"content-type": "application/json"})
|
||||
response = client.post("/failing", data=b"invalid json", headers={"content-type": "application/json"}) # type: ignore
|
||||
|
||||
# The exception handler catches it and returns JSON error
|
||||
assert response.status_code == 200
|
||||
@@ -160,8 +186,8 @@ async def test_endpoint_error_handling():
|
||||
async def test_endpoint_multiple_paths():
|
||||
"""Test adding multiple endpoints with different paths."""
|
||||
app = FastAPI()
|
||||
agent1 = ChatAgent(name="agent1", instructions="First agent", chat_client=MockChatClient("Response 1"))
|
||||
agent2 = ChatAgent(name="agent2", instructions="Second agent", chat_client=MockChatClient("Response 2"))
|
||||
agent1 = ChatAgent(name="agent1", instructions="First agent", chat_client=build_chat_client("Response 1"))
|
||||
agent2 = ChatAgent(name="agent2", instructions="Second agent", chat_client=build_chat_client("Response 2"))
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent1, path="/agent1")
|
||||
add_agent_framework_fastapi_endpoint(app, agent2, path="/agent2")
|
||||
@@ -178,7 +204,7 @@ async def test_endpoint_multiple_paths():
|
||||
async def test_endpoint_default_path():
|
||||
"""Test endpoint with default path."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent)
|
||||
|
||||
@@ -191,7 +217,7 @@ async def test_endpoint_default_path():
|
||||
async def test_endpoint_response_headers():
|
||||
"""Test that endpoint sets correct response headers."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/headers")
|
||||
|
||||
@@ -207,7 +233,7 @@ async def test_endpoint_response_headers():
|
||||
async def test_endpoint_empty_messages():
|
||||
"""Test endpoint with empty messages list."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/empty")
|
||||
|
||||
@@ -220,7 +246,7 @@ async def test_endpoint_empty_messages():
|
||||
async def test_endpoint_complex_input():
|
||||
"""Test endpoint with complex input data."""
|
||||
app = FastAPI()
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, path="/complex")
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for AG-UI event converter."""
|
||||
|
||||
from agent_framework import FinishReason, Role
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Shared test stubs for AG-UI tests."""
|
||||
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, MutableSequence
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
AgentThread,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
TextContent,
|
||||
)
|
||||
from agent_framework._clients import BaseChatClient
|
||||
from agent_framework._types import ChatResponse, ChatResponseUpdate
|
||||
|
||||
from agent_framework_ag_ui._orchestrators import ExecutionContext
|
||||
|
||||
StreamFn = Callable[..., AsyncIterator[ChatResponseUpdate]]
|
||||
ResponseFn = Callable[..., Awaitable[ChatResponse]]
|
||||
|
||||
|
||||
class StreamingChatClientStub(BaseChatClient):
|
||||
"""Typed streaming stub that satisfies ChatClientProtocol."""
|
||||
|
||||
def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None:
|
||||
super().__init__()
|
||||
self._stream_fn = stream_fn
|
||||
self._response_fn = response_fn
|
||||
|
||||
async def _inner_get_streaming_response(
|
||||
self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
async for update in self._stream_fn(messages, chat_options, **kwargs):
|
||||
yield update
|
||||
|
||||
async def _inner_get_response(
|
||||
self, *, messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> ChatResponse:
|
||||
if self._response_fn is not None:
|
||||
return await self._response_fn(messages, chat_options, **kwargs)
|
||||
|
||||
contents: list[Any] = []
|
||||
async for update in self._stream_fn(messages, chat_options, **kwargs):
|
||||
contents.extend(update.contents)
|
||||
|
||||
return ChatResponse(
|
||||
messages=[ChatMessage(role="assistant", contents=contents)],
|
||||
response_id="stub-response",
|
||||
)
|
||||
|
||||
|
||||
def stream_from_updates(updates: list[ChatResponseUpdate]) -> StreamFn:
|
||||
"""Create a stream function that yields from a static list of updates."""
|
||||
|
||||
async def _stream(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
for update in updates:
|
||||
yield update
|
||||
|
||||
return _stream
|
||||
|
||||
|
||||
class StubAgent(AgentProtocol):
|
||||
"""Minimal AgentProtocol stub for orchestrator tests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
updates: list[AgentRunResponseUpdate] | None = None,
|
||||
*,
|
||||
agent_id: str = "stub-agent",
|
||||
agent_name: str | None = "stub-agent",
|
||||
chat_options: Any | None = None,
|
||||
chat_client: Any | None = None,
|
||||
) -> None:
|
||||
self._id = agent_id
|
||||
self._name = agent_name
|
||||
self._description = "stub agent"
|
||||
self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")]
|
||||
self.chat_options = chat_options or SimpleNamespace(tools=None, response_format=None)
|
||||
self.chat_client = chat_client or SimpleNamespace(function_invocation_configuration=None)
|
||||
self.messages_received: list[Any] = []
|
||||
self.tools_received: list[Any] | None = None
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._name or self._id
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
return self._description
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResponse:
|
||||
return AgentRunResponse(messages=[], response_id="stub-response")
|
||||
|
||||
def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
async def _stream() -> AsyncIterator[AgentRunResponseUpdate]:
|
||||
self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type]
|
||||
self.tools_received = kwargs.get("tools")
|
||||
for update in self.updates:
|
||||
yield update
|
||||
|
||||
return _stream()
|
||||
|
||||
def get_new_thread(self, **kwargs: Any) -> AgentThread:
|
||||
return AgentThread()
|
||||
|
||||
|
||||
class TestExecutionContext(ExecutionContext):
|
||||
"""ExecutionContext helper that allows setting messages for tests."""
|
||||
|
||||
def set_messages(self, messages: list[ChatMessage]) -> None:
|
||||
self._messages = messages
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent
|
||||
|
||||
from agent_framework_ag_ui._orchestration._message_hygiene import (
|
||||
deduplicate_messages,
|
||||
sanitize_tool_history,
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_tool_history_injects_confirm_changes_result() -> None:
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
name="confirm_changes",
|
||||
call_id="call_confirm_123",
|
||||
arguments='{"changes": "test"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
contents=[TextContent(text='{"accepted": true}')],
|
||||
),
|
||||
]
|
||||
|
||||
sanitized = sanitize_tool_history(messages)
|
||||
|
||||
tool_messages = [
|
||||
msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool"
|
||||
]
|
||||
assert len(tool_messages) == 1
|
||||
assert str(tool_messages[0].contents[0].call_id) == "call_confirm_123"
|
||||
assert tool_messages[0].contents[0].result == "Confirmed"
|
||||
|
||||
|
||||
def test_deduplicate_messages_prefers_non_empty_tool_results() -> None:
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
contents=[FunctionResultContent(call_id="call1", result="")],
|
||||
),
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
contents=[FunctionResultContent(call_id="call1", result="result data")],
|
||||
),
|
||||
]
|
||||
|
||||
deduped = deduplicate_messages(messages)
|
||||
assert len(deduped) == 1
|
||||
assert deduped[0].contents[0].result == "result data"
|
||||
@@ -1,3 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for AG-UI orchestrators."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -34,6 +36,7 @@ class DummyAgent:
|
||||
*,
|
||||
thread: Any,
|
||||
tools: list[Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[AgentRunResponseUpdate, None]:
|
||||
self.seen_tools = tools
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
"""Comprehensive tests for orchestrator coverage."""
|
||||
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
@@ -15,11 +17,10 @@ from agent_framework import (
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_ag_ui._agent import AgentConfig
|
||||
from agent_framework_ag_ui._orchestrators import (
|
||||
DefaultOrchestrator,
|
||||
ExecutionContext,
|
||||
HumanInTheLoopOrchestrator,
|
||||
)
|
||||
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, HumanInTheLoopOrchestrator
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from test_helpers_ag_ui import StubAgent, TestExecutionContext
|
||||
|
||||
|
||||
@ai_function(approval_mode="always_require")
|
||||
@@ -28,34 +29,14 @@ def approval_tool(param: str) -> str:
|
||||
return f"executed: {param}"
|
||||
|
||||
|
||||
class MockAgent:
|
||||
"""Mock agent for testing."""
|
||||
|
||||
def __init__(self, updates: list[AgentRunResponseUpdate] | None = None) -> None:
|
||||
self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")]
|
||||
self.chat_options = SimpleNamespace(tools=[approval_tool], response_format=None)
|
||||
self.chat_client = SimpleNamespace(function_invocation_configuration=None)
|
||||
self.messages_received: list[Any] = []
|
||||
self.tools_received: list[Any] | None = None
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: list[Any],
|
||||
*,
|
||||
thread: Any = None,
|
||||
tools: list[Any] | None = None,
|
||||
) -> AsyncGenerator[AgentRunResponseUpdate, None]:
|
||||
self.messages_received = messages
|
||||
self.tools_received = tools
|
||||
for update in self.updates:
|
||||
yield update
|
||||
DEFAULT_CHAT_OPTIONS = SimpleNamespace(tools=[approval_tool], response_format=None)
|
||||
|
||||
|
||||
async def test_human_in_the_loop_json_decode_error() -> None:
|
||||
"""Test HumanInTheLoopOrchestrator handles invalid JSON in tool result."""
|
||||
orchestrator = HumanInTheLoopOrchestrator()
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -72,21 +53,25 @@ async def test_human_in_the_loop_json_decode_error() -> None:
|
||||
)
|
||||
]
|
||||
|
||||
context = ExecutionContext(
|
||||
agent = StubAgent(
|
||||
chat_options=SimpleNamespace(tools=[approval_tool], response_format=None),
|
||||
updates=[AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")],
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=MockAgent(),
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
assert orchestrator.can_handle(context)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
# Should emit RunErrorEvent for invalid JSON
|
||||
error_events = [e for e in events if e.type == "RUN_ERROR"]
|
||||
error_events: list[Any] = [e for e in events if e.type == "RUN_ERROR"]
|
||||
assert len(error_events) == 1
|
||||
assert "Invalid tool result format" in error_events[0].message
|
||||
|
||||
@@ -118,18 +103,20 @@ async def test_sanitize_tool_history_confirm_changes() -> None:
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
# Use pre-constructed ChatMessage objects to bypass message adapter
|
||||
input_data = {"messages": []}
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
# Override the messages property to use our pre-constructed messages
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -162,16 +149,18 @@ async def test_sanitize_tool_history_orphaned_tool_result() -> None:
|
||||
]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": []}
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -188,7 +177,7 @@ async def test_orphaned_tool_result_sanitization() -> None:
|
||||
"""Test that orphaned tool results are filtered out."""
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -201,14 +190,16 @@ async def test_orphaned_tool_result_sanitization() -> None:
|
||||
],
|
||||
}
|
||||
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -241,16 +232,18 @@ async def test_deduplicate_messages_empty_tool_results() -> None:
|
||||
]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": []}
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -284,16 +277,18 @@ async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None:
|
||||
]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": []}
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -326,16 +321,18 @@ async def test_deduplicate_messages_duplicate_system_messages() -> None:
|
||||
]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": []}
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -354,7 +351,7 @@ async def test_state_context_injection() -> None:
|
||||
"""Test state context message injection for first request."""
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -364,14 +361,16 @@ async def test_state_context_injection() -> None:
|
||||
"state": {"items": ["apple", "banana"]},
|
||||
}
|
||||
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(state_schema={"items": {"type": "array"}}),
|
||||
)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -406,16 +405,18 @@ async def test_no_state_context_injection_with_tool_calls() -> None:
|
||||
]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": [], "state": {"weather": "sunny"}}
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
input_data: dict[str, Any] = {"messages": [], "state": {"weather": "sunny"}}
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(state_schema={"weather": {"type": "string"}}),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -437,7 +438,7 @@ async def test_structured_output_processing() -> None:
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -447,32 +448,33 @@ async def test_structured_output_processing() -> None:
|
||||
}
|
||||
|
||||
# Agent with structured output
|
||||
agent = MockAgent(
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
updates=[
|
||||
AgentRunResponseUpdate(
|
||||
contents=[TextContent(text='{"ingredients": ["tomato"], "message": "Added tomato"}')],
|
||||
role="assistant",
|
||||
)
|
||||
]
|
||||
],
|
||||
)
|
||||
agent.chat_options.response_format = RecipeState
|
||||
|
||||
context = ExecutionContext(
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(state_schema={"ingredients": {"type": "array"}}),
|
||||
)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
# Should emit StateSnapshotEvent with ingredients
|
||||
state_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
|
||||
state_events: list[Any] = [e for e in events if e.type == "STATE_SNAPSHOT"]
|
||||
assert len(state_events) >= 1
|
||||
|
||||
# Should emit TextMessage with message field
|
||||
text_content_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
|
||||
text_content_events: list[Any] = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"]
|
||||
assert len(text_content_events) >= 1
|
||||
assert any("Added tomato" in e.delta for e in text_content_events)
|
||||
|
||||
@@ -487,7 +489,7 @@ async def test_duplicate_client_tools_filtered() -> None:
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -507,16 +509,18 @@ async def test_duplicate_client_tools_filtered() -> None:
|
||||
],
|
||||
}
|
||||
|
||||
agent = MockAgent()
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
agent.chat_options.tools = [get_weather]
|
||||
|
||||
context = ExecutionContext(
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -534,7 +538,7 @@ async def test_unique_client_tools_merged() -> None:
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -554,16 +558,18 @@ async def test_unique_client_tools_merged() -> None:
|
||||
],
|
||||
}
|
||||
|
||||
agent = MockAgent()
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
agent.chat_options.tools = [server_tool]
|
||||
|
||||
context = ExecutionContext(
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -578,16 +584,18 @@ async def test_empty_messages_handling() -> None:
|
||||
"""Test orchestrator handles empty message list gracefully."""
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
input_data = {"messages": []}
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -603,7 +611,7 @@ async def test_all_messages_filtered_handling() -> None:
|
||||
"""Test orchestrator handles case where all messages are filtered out."""
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -612,14 +620,16 @@ async def test_all_messages_filtered_handling() -> None:
|
||||
]
|
||||
}
|
||||
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -651,16 +661,18 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None:
|
||||
]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": []}
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -689,16 +701,18 @@ async def test_tool_result_kept_when_call_id_matches() -> None:
|
||||
]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": []}
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -729,6 +743,7 @@ async def test_agent_protocol_fallback_paths() -> None:
|
||||
*,
|
||||
thread: Any = None,
|
||||
tools: list[Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[AgentRunResponseUpdate, None]:
|
||||
self.messages_received = messages
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")
|
||||
@@ -738,16 +753,16 @@ async def test_agent_protocol_fallback_paths() -> None:
|
||||
messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": []}
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
agent = CustomAgent()
|
||||
context = ExecutionContext(
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent, # type: ignore
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
@@ -762,21 +777,23 @@ async def test_initial_state_snapshot_with_array_schema() -> None:
|
||||
messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": [], "state": {}}
|
||||
agent = MockAgent()
|
||||
context = ExecutionContext(
|
||||
input_data: dict[str, Any] = {"messages": [], "state": {}}
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(state_schema={"items": {"type": "array"}}),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
# Should emit state snapshot with empty array for items
|
||||
state_events = [e for e in events if e.type == "STATE_SNAPSHOT"]
|
||||
state_events: list[Any] = [e for e in events if e.type == "STATE_SNAPSHOT"]
|
||||
assert len(state_events) >= 1
|
||||
|
||||
|
||||
@@ -791,19 +808,21 @@ async def test_response_format_skip_text_content() -> None:
|
||||
messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
|
||||
|
||||
orchestrator = DefaultOrchestrator()
|
||||
input_data = {"messages": []}
|
||||
input_data: dict[str, Any] = {"messages": []}
|
||||
|
||||
agent = MockAgent()
|
||||
agent = StubAgent(
|
||||
chat_options=DEFAULT_CHAT_OPTIONS,
|
||||
)
|
||||
agent.chat_options.response_format = OutputModel
|
||||
|
||||
context = ExecutionContext(
|
||||
context = TestExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
context._messages = messages
|
||||
context.set_messages(messages)
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
"""Tests for shared state management."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from ag_ui.core import StateSnapshotEvent
|
||||
from agent_framework import ChatAgent, TextContent
|
||||
@@ -10,20 +14,16 @@ from agent_framework._types import ChatResponseUpdate
|
||||
from agent_framework_ag_ui._agent import AgentFrameworkAgent
|
||||
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
def mock_agent() -> ChatAgent:
|
||||
"""Create a mock agent for testing."""
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Hello!")])
|
||||
|
||||
return ChatAgent(
|
||||
name="test_agent",
|
||||
instructions="Test agent",
|
||||
chat_client=MockChatClient(),
|
||||
)
|
||||
updates = [ChatResponseUpdate(contents=[TextContent(text="Hello!")])]
|
||||
chat_client = StreamingChatClientStub(stream_from_updates(updates))
|
||||
return ChatAgent(name="test_agent", instructions="Test agent", chat_client=chat_client)
|
||||
|
||||
|
||||
def test_state_snapshot_event():
|
||||
@@ -65,9 +65,9 @@ def test_state_delta_event():
|
||||
assert event.delta[1]["op"] == "replace"
|
||||
|
||||
|
||||
async def test_agent_with_initial_state(mock_agent):
|
||||
async def test_agent_with_initial_state(mock_agent: ChatAgent) -> None:
|
||||
"""Test agent emits state snapshot when initial state provided."""
|
||||
state_schema = {"recipe": {"type": "object", "properties": {"name": {"type": "string"}}}}
|
||||
state_schema: dict[str, Any] = {"recipe": {"type": "object", "properties": {"name": {"type": "string"}}}}
|
||||
|
||||
agent = AgentFrameworkAgent(
|
||||
agent=mock_agent,
|
||||
@@ -76,12 +76,12 @@ async def test_agent_with_initial_state(mock_agent):
|
||||
|
||||
initial_state = {"recipe": {"name": "Test Recipe"}}
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"state": initial_state,
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in agent.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -91,16 +91,16 @@ async def test_agent_with_initial_state(mock_agent):
|
||||
assert snapshot_events[0].snapshot == initial_state
|
||||
|
||||
|
||||
async def test_agent_without_state_schema(mock_agent):
|
||||
async def test_agent_without_state_schema(mock_agent: ChatAgent) -> None:
|
||||
"""Test agent doesn't emit state events without state schema."""
|
||||
agent = AgentFrameworkAgent(agent=mock_agent)
|
||||
|
||||
input_data = {
|
||||
input_data: dict[str, Any] = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"state": {"some": "state"},
|
||||
}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in agent.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from ag_ui.core import CustomEvent, EventType
|
||||
from agent_framework import ChatMessage, TextContent
|
||||
|
||||
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
|
||||
from agent_framework_ag_ui._orchestration._state_manager import StateManager
|
||||
|
||||
|
||||
def test_state_manager_initializes_defaults_and_snapshot() -> None:
|
||||
state_manager = StateManager(
|
||||
state_schema={"items": {"type": "array"}, "metadata": {"type": "object"}},
|
||||
predict_state_config=None,
|
||||
require_confirmation=True,
|
||||
)
|
||||
current_state = state_manager.initialize({"metadata": {"a": 1}})
|
||||
bridge = AgentFrameworkEventBridge(run_id="run", thread_id="thread", current_state=current_state)
|
||||
|
||||
snapshot_event = state_manager.initial_snapshot_event(bridge)
|
||||
assert snapshot_event is not None
|
||||
assert snapshot_event.snapshot["items"] == []
|
||||
assert snapshot_event.snapshot["metadata"] == {"a": 1}
|
||||
|
||||
|
||||
def test_state_manager_predict_state_event_shape() -> None:
|
||||
state_manager = StateManager(
|
||||
state_schema=None,
|
||||
predict_state_config={"doc": {"tool": "write_document_local", "tool_argument": "document"}},
|
||||
require_confirmation=True,
|
||||
)
|
||||
predict_event = state_manager.predict_state_event()
|
||||
assert isinstance(predict_event, CustomEvent)
|
||||
assert predict_event.type == EventType.CUSTOM
|
||||
assert predict_event.name == "PredictState"
|
||||
assert predict_event.value[0]["state_key"] == "doc"
|
||||
|
||||
|
||||
def test_state_context_only_when_new_user_turn() -> None:
|
||||
state_manager = StateManager(
|
||||
state_schema={"items": {"type": "array"}},
|
||||
predict_state_config=None,
|
||||
require_confirmation=True,
|
||||
)
|
||||
state_manager.initialize({"items": [1]})
|
||||
|
||||
assert state_manager.state_context_message(is_new_user_turn=False, conversation_has_tool_calls=False) is None
|
||||
|
||||
message = state_manager.state_context_message(is_new_user_turn=True, conversation_has_tool_calls=False)
|
||||
assert isinstance(message, ChatMessage)
|
||||
assert isinstance(message.contents[0], TextContent)
|
||||
assert "Current state of the application" in message.contents[0].text
|
||||
@@ -3,12 +3,18 @@
|
||||
"""Tests for structured output handling in _agent.py."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import AsyncIterator, MutableSequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatAgent, ChatOptions, TextContent
|
||||
from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent
|
||||
from agent_framework._types import ChatResponseUpdate
|
||||
from pydantic import BaseModel
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates
|
||||
|
||||
|
||||
class RecipeOutput(BaseModel):
|
||||
"""Test Pydantic model for recipe output."""
|
||||
@@ -34,14 +40,14 @@ async def test_structured_output_with_recipe():
|
||||
"""Test structured output processing with recipe state."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
# Simulate structured output
|
||||
yield ChatResponseUpdate(
|
||||
contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')]
|
||||
)
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(
|
||||
contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')]
|
||||
)
|
||||
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
agent.chat_options = ChatOptions(response_format=RecipeOutput)
|
||||
|
||||
wrapper = AgentFrameworkAgent(
|
||||
@@ -51,7 +57,7 @@ async def test_structured_output_with_recipe():
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Make pasta"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -72,17 +78,18 @@ async def test_structured_output_with_steps():
|
||||
"""Test structured output processing with steps state."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
steps_data = {
|
||||
"steps": [
|
||||
{"id": "1", "description": "Step 1", "status": "pending"},
|
||||
{"id": "2", "description": "Step 2", "status": "pending"},
|
||||
]
|
||||
}
|
||||
yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
steps_data = {
|
||||
"steps": [
|
||||
{"id": "1", "description": "Step 1", "status": "pending"},
|
||||
{"id": "2", "description": "Step 2", "status": "pending"},
|
||||
]
|
||||
}
|
||||
yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))])
|
||||
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
agent.chat_options = ChatOptions(response_format=StepsOutput)
|
||||
|
||||
wrapper = AgentFrameworkAgent(
|
||||
@@ -92,7 +99,7 @@ async def test_structured_output_with_steps():
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Do steps"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -111,12 +118,13 @@ async def test_structured_output_with_no_schema_match():
|
||||
"""Test structured output when response fields don't match state_schema keys."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
# Response has "data" field but schema expects "result" field
|
||||
yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')])
|
||||
updates = [
|
||||
ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')]),
|
||||
]
|
||||
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(
|
||||
name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_from_updates(updates))
|
||||
)
|
||||
agent.chat_options = ChatOptions(response_format=GenericOutput)
|
||||
|
||||
wrapper = AgentFrameworkAgent(
|
||||
@@ -126,7 +134,7 @@ async def test_structured_output_with_no_schema_match():
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Generate data"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -146,11 +154,12 @@ async def test_structured_output_without_schema():
|
||||
data: dict[str, Any]
|
||||
info: str
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')])
|
||||
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
agent.chat_options = ChatOptions(response_format=DataOutput)
|
||||
|
||||
wrapper = AgentFrameworkAgent(
|
||||
@@ -160,7 +169,7 @@ async def test_structured_output_without_schema():
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Generate data"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -177,18 +186,20 @@ async def test_no_structured_output_when_no_response_format():
|
||||
"""Test that structured output path is skipped when no response_format."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="Regular text")])
|
||||
updates = [ChatResponseUpdate(contents=[TextContent(text="Regular text")])]
|
||||
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(
|
||||
name="test",
|
||||
instructions="Test",
|
||||
chat_client=StreamingChatClientStub(stream_from_updates(updates)),
|
||||
)
|
||||
# No response_format set
|
||||
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -202,12 +213,13 @@ async def test_structured_output_with_message_field():
|
||||
"""Test structured output that includes a message field."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"}
|
||||
yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))])
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"}
|
||||
yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))])
|
||||
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
agent.chat_options = ChatOptions(response_format=RecipeOutput)
|
||||
|
||||
wrapper = AgentFrameworkAgent(
|
||||
@@ -217,7 +229,7 @@ async def test_structured_output_with_message_field():
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Make salad"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
@@ -236,20 +248,20 @@ async def test_empty_updates_no_structured_processing():
|
||||
"""Test that empty updates don't trigger structured output processing."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
class MockChatClient:
|
||||
async def get_streaming_response(self, messages, chat_options, **kwargs):
|
||||
# Return nothing
|
||||
if False:
|
||||
yield
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
if False:
|
||||
yield ChatResponseUpdate(contents=[])
|
||||
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=MockChatClient())
|
||||
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
|
||||
agent.chat_options = ChatOptions(response_format=RecipeOutput)
|
||||
|
||||
wrapper = AgentFrameworkAgent(agent=agent)
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Test"}]}
|
||||
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run_agent(input_data):
|
||||
events.append(event)
|
||||
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent_framework_ag_ui._orchestration._tooling import merge_tools, register_additional_client_tools
|
||||
|
||||
|
||||
class DummyTool:
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
self.declaration_only = True
|
||||
|
||||
|
||||
def test_merge_tools_filters_duplicates() -> None:
|
||||
server = [DummyTool("a"), DummyTool("b")]
|
||||
client = [DummyTool("b"), DummyTool("c")]
|
||||
|
||||
merged = merge_tools(server, client)
|
||||
|
||||
assert merged is not None
|
||||
names = [getattr(t, "name", None) for t in merged]
|
||||
assert names == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_register_additional_client_tools_assigns_when_configured() -> None:
|
||||
class Fic:
|
||||
def __init__(self) -> None:
|
||||
self.additional_tools = None
|
||||
|
||||
holder = SimpleNamespace(function_invocation_configuration=Fic())
|
||||
agent = SimpleNamespace(chat_client=holder)
|
||||
|
||||
tools = [DummyTool("x")]
|
||||
register_additional_client_tools(agent, tools)
|
||||
|
||||
assert holder.function_invocation_configuration.additional_tools == tools
|
||||
@@ -20,8 +20,8 @@ def test_generate_event_id():
|
||||
|
||||
def test_merge_state():
|
||||
"""Test state merging."""
|
||||
current = {"a": 1, "b": 2}
|
||||
update = {"b": 3, "c": 4}
|
||||
current: dict[str, int] = {"a": 1, "b": 2}
|
||||
update: dict[str, int] = {"b": 3, "c": 4}
|
||||
|
||||
result = merge_state(current, update)
|
||||
|
||||
@@ -32,8 +32,8 @@ def test_merge_state():
|
||||
|
||||
def test_merge_state_empty_update():
|
||||
"""Test merging with empty update."""
|
||||
current = {"x": 10, "y": 20}
|
||||
update = {}
|
||||
current: dict[str, int] = {"x": 10, "y": 20}
|
||||
update: dict[str, int] = {}
|
||||
|
||||
result = merge_state(current, update)
|
||||
|
||||
@@ -43,8 +43,8 @@ def test_merge_state_empty_update():
|
||||
|
||||
def test_merge_state_empty_current():
|
||||
"""Test merging with empty current state."""
|
||||
current = {}
|
||||
update = {"a": 1, "b": 2}
|
||||
current: dict[str, int] = {}
|
||||
update: dict[str, int] = {"a": 1, "b": 2}
|
||||
|
||||
result = merge_state(current, update)
|
||||
|
||||
@@ -53,8 +53,8 @@ def test_merge_state_empty_current():
|
||||
|
||||
def test_merge_state_deep_copy():
|
||||
"""Test that merge_state creates a deep copy preventing mutation of original."""
|
||||
current = {"recipe": {"name": "Cake", "ingredients": ["flour", "sugar"]}}
|
||||
update = {"other": "value"}
|
||||
current: dict[str, dict[str, object]] = {"recipe": {"name": "Cake", "ingredients": ["flour", "sugar"]}}
|
||||
update: dict[str, str] = {"other": "value"}
|
||||
|
||||
result = merge_state(current, update)
|
||||
|
||||
|
||||
@@ -876,7 +876,11 @@ class ChatAgent(BaseAgent):
|
||||
)
|
||||
# Filter chat_options from kwargs to prevent duplicate keyword argument
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
|
||||
response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **filtered_kwargs)
|
||||
response = await self.chat_client.get_response(
|
||||
messages=thread_messages,
|
||||
chat_options=co,
|
||||
**filtered_kwargs,
|
||||
)
|
||||
|
||||
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
|
||||
|
||||
@@ -1013,7 +1017,9 @@ class ChatAgent(BaseAgent):
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
|
||||
response_updates: list[ChatResponseUpdate] = []
|
||||
async for update in self.chat_client.get_streaming_response(
|
||||
messages=thread_messages, chat_options=co, **filtered_kwargs
|
||||
messages=thread_messages,
|
||||
chat_options=co,
|
||||
**filtered_kwargs,
|
||||
):
|
||||
response_updates.append(update)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user