Files
Evan Mattson 88968da0bd Python: fix(ag-ui): Execute tools with approval_mode, fix shared state, code cleanup (#3079)
* fix(ag-ui): execute tools after approval in human-in-the-loop flow

* Fix shared state bug

* Bug fix finalized

* Refactoring to clean up code

* Code cleanup

* More fixes

* More code cleanup

* Add version detection in __init__.py to ruff ignore list
2026-01-09 03:08:05 +00:00

278 lines
8.5 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for AG-UI orchestrators."""
from collections.abc import AsyncGenerator
from types import SimpleNamespace
from typing import Any
from agent_framework import AgentRunResponseUpdate, TextContent, ai_function
from agent_framework._tools import FunctionInvocationConfiguration
from agent_framework_ag_ui._agent import AgentConfig
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext
@ai_function
def server_tool() -> str:
    """Server-executable tool."""
    # NOTE(review): the docstring above is presumably picked up by @ai_function
    # as the tool description -- confirm before rewording it.
    # The fixed return value only needs to be a recognisable string; the tests
    # in this module look the tool up by its name, not by its result.
    return "server"
class DummyAgent:
    """Bare-bones agent double that remembers the ``tools`` handed to ``run_stream``."""

    def __init__(self) -> None:
        # Stays None until run_stream has been invoked.
        self.seen_tools: list[Any] | None = None

        invocation_config = FunctionInvocationConfiguration()
        self.chat_client = SimpleNamespace(function_invocation_configuration=invocation_config)

        # Two separate list objects on purpose: mutating one must not leak into the other.
        self.tools = [server_tool]
        self.chat_options = SimpleNamespace(tools=[server_tool], response_format=None)

    async def run_stream(
        self,
        messages: list[Any],
        *,
        thread: Any,
        tools: list[Any] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[AgentRunResponseUpdate, None]:
        """Store ``tools`` on the instance, then emit one assistant update with text "ok"."""
        self.seen_tools = tools
        reply = AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
        yield reply
class RecordingAgent:
    """Agent double that remembers the ``messages`` handed to ``run_stream``."""

    def __init__(self) -> None:
        # Stays None until run_stream has been invoked.
        self.seen_messages: list[Any] | None = None

        invocation_config = FunctionInvocationConfiguration()
        self.chat_client = SimpleNamespace(function_invocation_configuration=invocation_config)

        # This stub exposes no server-side tools at all.
        self.tools: list[Any] = []
        self.chat_options = SimpleNamespace(tools=[], response_format=None)

    async def run_stream(
        self,
        messages: list[Any],
        *,
        thread: Any,
        tools: list[Any] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[AgentRunResponseUpdate, None]:
        """Store ``messages`` on the instance, then emit one assistant update with text "ok"."""
        self.seen_messages = messages
        reply = AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
        yield reply
async def test_default_orchestrator_merges_client_tools() -> None:
    """Client tool declarations are merged with server tools before running agent.

    The request declares a ``get_weather`` client tool while the agent stub
    already carries ``server_tool``. After the run, both names must be present
    in the tools passed to ``run_stream`` and the chat client's function
    invocation configuration must hold a non-empty ``additional_tools``.
    """
    agent = DummyAgent()
    orchestrator = DefaultOrchestrator()

    input_data = {
        "messages": [
            {
                "role": "user",
                "content": [{"type": "input_text", "text": "Hello"}],
            }
        ],
        "tools": [
            {
                "name": "get_weather",
                "description": "Client weather lookup.",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            }
        ],
    }

    context = ExecutionContext(
        input_data=input_data,
        agent=agent,
        config=AgentConfig(),
    )

    # Only the side effects on the agent stub are asserted below, so the event
    # stream is drained rather than collected into a list nobody reads.
    async for _event in orchestrator.run(context):
        pass

    assert agent.seen_tools is not None
    # Tools without a ``name`` attribute show up as "?" and simply never match.
    tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools]
    assert "server_tool" in tool_names
    assert "get_weather" in tool_names
    assert agent.chat_client.function_invocation_configuration.additional_tools
async def test_default_orchestrator_with_camel_case_ids() -> None:
    """Run and thread IDs given as camelCase keys are reported on the final event."""
    agent = DummyAgent()
    orchestrator = DefaultOrchestrator()

    user_message = {
        "role": "user",
        "content": [{"type": "input_text", "text": "Hello"}],
    }
    payload = {
        "runId": "test-camelcase-runid",
        "threadId": "test-camelcase-threadid",
        "messages": [user_message],
        "tools": [],
    }
    context = ExecutionContext(input_data=payload, agent=agent, config=AgentConfig())

    emitted = [event async for event in orchestrator.run(context)]

    # The closing event of the run is the one expected to echo both identifiers.
    final_event = emitted[-1]
    assert final_event.run_id == "test-camelcase-runid"
    assert final_event.thread_id == "test-camelcase-threadid"
async def test_default_orchestrator_with_snake_case_ids() -> None:
    """Run and thread IDs given as snake_case keys are reported on the final event."""
    agent = DummyAgent()
    orchestrator = DefaultOrchestrator()

    user_message = {
        "role": "user",
        "content": [{"type": "input_text", "text": "Hello"}],
    }
    payload = {
        "run_id": "test-snakecase-runid",
        "thread_id": "test-snakecase-threadid",
        "messages": [user_message],
        "tools": [],
    }
    context = ExecutionContext(input_data=payload, agent=agent, config=AgentConfig())

    emitted = [event async for event in orchestrator.run(context)]

    # The closing event of the run is the one expected to echo both identifiers.
    final_event = emitted[-1]
    assert final_event.run_id == "test-snakecase-runid"
    assert final_event.thread_id == "test-snakecase-threadid"
async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
    """A state-context system message must reach the agent when state and tool call args differ."""
    agent = RecordingAgent()
    orchestrator = DefaultOrchestrator()

    # The recorded tool call carries no preferences; the live state has gained one.
    tool_recipe = {"title": "Salad", "special_preferences": []}
    current_recipe = {"title": "Salad", "special_preferences": ["Vegetarian"]}

    recorded_tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "update_recipe", "arguments": {"recipe": tool_recipe}},
    }
    input_data = {
        "state": {"recipe": current_recipe},
        "messages": [
            {"role": "system", "content": "Instructions"},
            {"role": "assistant", "tool_calls": [recorded_tool_call]},
            {"role": "user", "content": "What are the dietary preferences?"},
        ],
    }
    config = AgentConfig(
        state_schema={"recipe": {"type": "object"}},
        predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}},
        require_confirmation=False,
    )
    context = ExecutionContext(input_data=input_data, agent=agent, config=config)

    async for _event in orchestrator.run(context):
        pass

    assert agent.seen_messages is not None

    # Gather every system-role text part that announces the application state.
    prefix = "Current state of the application:"
    state_messages = []
    for message in agent.seen_messages:
        role = message.role
        # The role may be an enum-like object or something only usable via str().
        role_name = role.value if hasattr(role, "value") else str(role)
        if role_name == "system":
            state_messages.extend(
                part.text
                for part in message.contents or []
                if isinstance(part, TextContent) and part.text.startswith(prefix)
            )

    assert state_messages
    assert "Vegetarian" in state_messages[0]
async def test_state_context_not_injected_when_tool_call_matches_state() -> None:
    """No state-context system message should reach the agent when tool call args match the state."""
    agent = RecordingAgent()
    orchestrator = DefaultOrchestrator()

    # The request carries no "state" key, and the recorded tool call has an empty recipe.
    recorded_tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "update_recipe", "arguments": {"recipe": {}}},
    }
    input_data = {
        "messages": [
            {"role": "system", "content": "Instructions"},
            {"role": "assistant", "tool_calls": [recorded_tool_call]},
            {"role": "user", "content": "What are the dietary preferences?"},
        ],
    }
    config = AgentConfig(
        state_schema={"recipe": {"type": "object"}},
        predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}},
        require_confirmation=False,
    )
    context = ExecutionContext(input_data=input_data, agent=agent, config=config)

    async for _event in orchestrator.run(context):
        pass

    assert agent.seen_messages is not None

    # Gather every system-role text part that announces the application state.
    prefix = "Current state of the application:"
    state_messages = []
    for message in agent.seen_messages:
        role = message.role
        # The role may be an enum-like object or something only usable via str().
        role_name = role.value if hasattr(role, "value") else str(role)
        if role_name == "system":
            state_messages.extend(
                part.text
                for part in message.contents or []
                if isinstance(part, TextContent) and part.text.startswith(prefix)
            )

    assert not state_messages