Mirror of https://github.com/microsoft/agent-framework.git
Synced 2026-06-16 21:04:09 +08:00
Commit 88968da0bd:
* fix(ag-ui): execute tools after approval in human-in-the-loop flow
* Fix shared state bug
* Bug fix finalized
* Refactoring to clean up code
* Code cleanup
* More fixes
* More code cleanup
* Add version detection in __init__.py to ruff ignore list
278 lines · 8.5 KiB · Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Tests for AG-UI orchestrators."""
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from types import SimpleNamespace
|
|
from typing import Any
|
|
|
|
from agent_framework import AgentRunResponseUpdate, TextContent, ai_function
|
|
from agent_framework._tools import FunctionInvocationConfiguration
|
|
|
|
from agent_framework_ag_ui._agent import AgentConfig
|
|
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext
|
|
|
|
|
|
@ai_function
def server_tool() -> str:
    """Server-executable tool."""
    # NOTE(review): the docstring above is kept verbatim on purpose; @ai_function
    # presumably uses it as the tool description — confirm before editing it.
    # The visible tests only check that a tool named "server_tool" reaches the
    # agent; this constant payload is never asserted on.
    payload = "server"
    return payload
|
|
|
|
|
|
class DummyAgent:
    """Agent test double that remembers the ``tools`` argument of ``run_stream``.

    It carries ``server_tool`` as its only server-side tool, exposed both via
    ``chat_options.tools`` and ``tools``, plus a ``chat_client`` holding a
    fresh ``FunctionInvocationConfiguration`` that tests can inspect afterwards.
    """

    def __init__(self) -> None:
        # Filled in by run_stream; remains None until the stream is iterated.
        self.seen_tools: list[Any] | None = None

        # Two separate list objects (not aliases of one another).
        self.chat_options = SimpleNamespace(tools=[server_tool], response_format=None)
        self.tools = [server_tool]

        # One configuration object per stub, so assertions on it only reflect
        # what happened to this particular agent.
        invocation_config = FunctionInvocationConfiguration()
        self.chat_client = SimpleNamespace(
            function_invocation_configuration=invocation_config,
        )

    async def run_stream(
        self,
        messages: list[Any],
        *,
        thread: Any,
        tools: list[Any] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[AgentRunResponseUpdate, None]:
        """Capture ``tools`` and stream back one assistant update saying "ok"."""
        self.seen_tools = tools
        reply = TextContent(text="ok")
        yield AgentRunResponseUpdate(contents=[reply], role="assistant")
|
|
|
|
|
|
class RecordingAgent:
    """Agent test double that remembers the ``messages`` argument of ``run_stream``.

    It exposes no tools of its own (``chat_options.tools`` and ``tools`` are
    both empty, separate lists) and a ``chat_client`` with a fresh
    ``FunctionInvocationConfiguration``.
    """

    def __init__(self) -> None:
        # Filled in by run_stream; remains None until the stream is iterated.
        self.seen_messages: list[Any] | None = None

        self.chat_options = SimpleNamespace(tools=[], response_format=None)
        self.tools: list[Any] = []

        invocation_config = FunctionInvocationConfiguration()
        self.chat_client = SimpleNamespace(
            function_invocation_configuration=invocation_config,
        )

    async def run_stream(
        self,
        messages: list[Any],
        *,
        thread: Any,
        tools: list[Any] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[AgentRunResponseUpdate, None]:
        """Capture ``messages`` and stream back one assistant update saying "ok"."""
        self.seen_messages = messages
        reply = TextContent(text="ok")
        yield AgentRunResponseUpdate(contents=[reply], role="assistant")
|
|
|
|
|
|
async def test_default_orchestrator_merges_client_tools() -> None:
    """Client tool declarations are merged with server tools before running agent.

    The request declares a client-side ``get_weather`` tool. The agent must be
    run with both that declaration and its own ``server_tool``. The chat
    client's function-invocation configuration must end up with a non-empty
    ``additional_tools``.
    """
    agent = DummyAgent()
    orchestrator = DefaultOrchestrator()

    input_data = {
        "messages": [
            {
                "role": "user",
                "content": [{"type": "input_text", "text": "Hello"}],
            }
        ],
        "tools": [
            {
                "name": "get_weather",
                "description": "Client weather lookup.",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            }
        ],
    }

    context = ExecutionContext(
        input_data=input_data,
        agent=agent,
        config=AgentConfig(),
    )

    # Drain the stream: DummyAgent only records its tools once run_stream is
    # iterated. The events themselves are not inspected here.
    async for _event in orchestrator.run(context):
        pass

    assert agent.seen_tools is not None
    tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools]
    assert "server_tool" in tool_names
    assert "get_weather" in tool_names
    assert agent.chat_client.function_invocation_configuration.additional_tools
|
|
|
|
|
|
async def test_default_orchestrator_with_camel_case_ids() -> None:
    """Orchestrator reads camelCase ``runId``/``threadId`` from the request.

    Both IDs must be carried on the final event the orchestrator emits.
    """
    agent = DummyAgent()
    orchestrator = DefaultOrchestrator()

    input_data = {
        "runId": "test-camelcase-runid",
        "threadId": "test-camelcase-threadid",
        "messages": [
            {
                "role": "user",
                "content": [{"type": "input_text", "text": "Hello"}],
            }
        ],
        "tools": [],
    }

    context = ExecutionContext(
        input_data=input_data,
        agent=agent,
        config=AgentConfig(),
    )

    events = [event async for event in orchestrator.run(context)]

    # Fail with a clear message instead of an IndexError if nothing was emitted.
    assert events, "orchestrator emitted no events"

    # The last event must carry the run_id and thread_id taken from the request.
    last_event = events[-1]
    assert last_event.run_id == "test-camelcase-runid"
    assert last_event.thread_id == "test-camelcase-threadid"
|
|
|
|
|
|
async def test_default_orchestrator_with_snake_case_ids() -> None:
    """Orchestrator reads snake_case ``run_id``/``thread_id`` from the request.

    Both IDs must be carried on the final event the orchestrator emits.
    """
    agent = DummyAgent()
    orchestrator = DefaultOrchestrator()

    input_data = {
        "run_id": "test-snakecase-runid",
        "thread_id": "test-snakecase-threadid",
        "messages": [
            {
                "role": "user",
                "content": [{"type": "input_text", "text": "Hello"}],
            }
        ],
        "tools": [],
    }

    context = ExecutionContext(
        input_data=input_data,
        agent=agent,
        config=AgentConfig(),
    )

    events = [event async for event in orchestrator.run(context)]

    # Fail with a clear message instead of an IndexError if nothing was emitted.
    assert events, "orchestrator emitted no events"

    # The last event must carry the run_id and thread_id taken from the request.
    last_event = events[-1]
    assert last_event.run_id == "test-snakecase-runid"
    assert last_event.thread_id == "test-snakecase-threadid"
|
|
|
|
|
|
async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
    """State context should be injected when current state differs from tool call args.

    The request's ``state`` holds a recipe with a "Vegetarian" preference. The
    last ``update_recipe`` call carried an older recipe without it. The agent
    must therefore receive a system message that starts with
    "Current state of the application:" and mentions "Vegetarian".
    """
    agent = RecordingAgent()
    orchestrator = DefaultOrchestrator()

    stale_recipe = {"title": "Salad", "special_preferences": []}
    live_recipe = {"title": "Salad", "special_preferences": ["Vegetarian"]}

    prior_tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "update_recipe", "arguments": {"recipe": stale_recipe}},
    }
    input_data = {
        "state": {"recipe": live_recipe},
        "messages": [
            {"role": "system", "content": "Instructions"},
            {"role": "assistant", "tool_calls": [prior_tool_call]},
            {"role": "user", "content": "What are the dietary preferences?"},
        ],
    }

    config = AgentConfig(
        state_schema={"recipe": {"type": "object"}},
        predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}},
        require_confirmation=False,
    )
    context = ExecutionContext(input_data=input_data, agent=agent, config=config)

    # Drain the stream so RecordingAgent.run_stream runs and records its messages.
    async for _ in orchestrator.run(context):
        pass

    assert agent.seen_messages is not None

    def role_of(message: Any) -> str:
        # Roles may be enum-like (exposing .value) or plain strings.
        role = message.role
        return role.value if hasattr(role, "value") else str(role)

    marker = "Current state of the application:"
    state_messages = [
        content.text
        for message in agent.seen_messages
        if role_of(message) == "system"
        for content in (message.contents or [])
        if isinstance(content, TextContent) and content.text.startswith(marker)
    ]
    assert state_messages
    assert "Vegetarian" in state_messages[0]
|
|
|
|
|
|
async def test_state_context_not_injected_when_tool_call_matches_state() -> None:
    """State context should be skipped when tool call args match current state.

    The last ``update_recipe`` call carried an empty recipe. No system message
    starting with "Current state of the application:" may reach the agent.
    """
    agent = RecordingAgent()
    orchestrator = DefaultOrchestrator()

    prior_tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "update_recipe", "arguments": {"recipe": {}}},
    }
    # NOTE(review): no explicit "state" key is sent, so the "match" is against
    # whatever initial state the orchestrator derives on its own (presumably an
    # empty recipe from state_schema). Confirm this, or consider passing
    # "state": {"recipe": {}} explicitly.
    input_data = {
        "messages": [
            {"role": "system", "content": "Instructions"},
            {"role": "assistant", "tool_calls": [prior_tool_call]},
            {"role": "user", "content": "What are the dietary preferences?"},
        ],
    }

    config = AgentConfig(
        state_schema={"recipe": {"type": "object"}},
        predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}},
        require_confirmation=False,
    )
    context = ExecutionContext(input_data=input_data, agent=agent, config=config)

    # Drain the stream so RecordingAgent.run_stream runs and records its messages.
    async for _ in orchestrator.run(context):
        pass

    assert agent.seen_messages is not None

    def role_of(message: Any) -> str:
        # Roles may be enum-like (exposing .value) or plain strings.
        role = message.role
        return role.value if hasattr(role, "value") else str(role)

    marker = "Current state of the application:"
    state_messages = [
        content.text
        for message in agent.seen_messages
        if role_of(message) == "system"
        for content in (message.contents or [])
        if isinstance(content, TextContent) and content.text.startswith(marker)
    ]
    assert not state_messages
|