Python: Add tau2 benchmark integration with comprehensive testing and documentation (#817)

* first commit to tau2-bench

* tau2-bench agent

* tau2 agent

* add condition

* checkpoint

* bug fix

* add tests

* fix tests

* add comments

* add comments

* minor fix

* fix

* batch test script

* .

* init.bak -> init.py

* fix mypy

* update readme

* fix env

* remove temp files

* setup tests

* fix gaia tasks

* fix tau2 tests

* fix coverage

* fix default version

* update cookiecutter template

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Yuge Zhang
2025-09-21 16:08:45 -07:00
committed by GitHub
Unverified
parent 52790b9f6a
commit 205cd700c8
34 changed files with 3675 additions and 9 deletions
@@ -36,6 +36,7 @@ requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
package-dir = {"" = "src"}
packages = ["agent_framework_lab_{{ cookiecutter.package_name }}", "agent_framework.lab.{{ cookiecutter.package_name }}"]
[tool.setuptools.package-data]
@@ -66,7 +67,7 @@ warn_return_any = true
warn_unreachable = true
show_error_codes = true
implicit_reexport = true
packages = ["agent_framework_lab_{{ cookiecutter.package_name }}"]
packages = ["src.agent_framework_lab_{{ cookiecutter.package_name }}"]
{% if cookiecutter.within_microsoft_agent_framework_repo == "y" %}
[tool.poe]
@@ -86,7 +87,7 @@ mypy = "mypy agent_framework_lab_{{ cookiecutter.package_name }}"
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]
pythonpath = ["src"]
addopts = "--strict-markers --strict-config"
markers = [
"unit: marks tests as unit tests",
@@ -4,13 +4,18 @@
{{ cookiecutter.package_description }}
"""
import importlib.metadata
# Import your main exports here
# from .main_module import MainClass, main_function
try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
__all__ = [
# List your exports here
# "MainClass",
# "main_function",
]
__version__ = "{{ cookiecutter.version }}"
@@ -4,9 +4,16 @@
GAIA benchmark module for Agent Framework.
"""
import importlib.metadata
from ._types import Evaluation, Evaluator, Prediction, Task, TaskResult, TaskRunner
from .gaia import GAIA, GAIATelemetryConfig, gaia_scorer, viewer_main
try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
__all__ = [
"GAIA",
"GAIATelemetryConfig",
@@ -19,5 +26,3 @@ __all__ = [
"TaskRunner",
"Evaluator",
]
__version__ = "0.1.0b1"
+2
View File
@@ -72,7 +72,9 @@ packages = ["agent_framework_lab_gaia"]
[tool.poe]
executor.type = "uv"
include = "../../../shared_tasks.toml"
[tool.poe.tasks]
test = "pytest --cov=agent_framework_lab_gaia --cov-report=term-missing:skip-covered tests"
[tool.pytest.ini_options]
testpaths = ["tests"]
@@ -4,6 +4,13 @@
RL Module for Microsoft Agent Framework
"""
import importlib.metadata
try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
# Import your main exports here
# from .main_module import MainClass, main_function
@@ -12,5 +19,3 @@ __all__ = [
# "MainClass",
# "main_function",
]
__version__ = "0.1.0b1"
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+190
View File
@@ -0,0 +1,190 @@
# Agent Framework Lab - τ²-bench
τ²-bench implements a simulation framework for evaluating customer service agents across various domains.
The framework orchestrates conversations between two AI agents:
- **Customer Service Agent**: Follows domain-specific policies and has access to tools (e.g., booking systems, databases)
- **User Simulator**: Simulates realistic customer behavior with specific goals and scenarios
Each evaluation runs a multi-turn conversation where the user simulator presents a customer service scenario, and the agent must resolve it following the domain policy while using available tools appropriately. The results are evaluated using τ²'s comprehensive evaluation system.
## Supported Domains
| Domain | Status | Description |
|--------|--------|-------------|
| **airline** | ✅ Supported | Customer service for airline booking, changes, and support |
| **retail** | 🚧 In Development | E-commerce customer support scenarios |
| **telecom** | 🚧 In Development | Telecommunications service support |
*Note: Currently only the airline domain is fully supported.*
## Installation
```bash
pip install agent-framework-lab-tau2
```
Download data from [Tau2-Bench](https://github.com/sierra-research/tau2-bench):
```bash
git clone https://github.com/sierra-research/tau2-bench.git
mv tau2-bench/data/ .
rm -rf tau2-bench
```
Set the `TAU2_DATA_DIR` environment variable to the path of the downloaded data directory:
```bash
export TAU2_DATA_DIR="data"
```
## Quick Start
### Running a Single Task
```python
import asyncio
from agent_framework.openai import OpenAIChatClient
from agent_framework_lab_tau2 import TaskRunner
from tau2.domains.airline.environment import get_tasks
async def run_single_task():
# Initialize the task runner
runner = TaskRunner(max_steps=50)
# Set up your LLM clients
assistant_client = OpenAIChatClient(
base_url="https://api.openai.com/v1",
api_key="your-api-key",
ai_model_id="gpt-4o"
)
user_client = OpenAIChatClient(
base_url="https://api.openai.com/v1",
api_key="your-api-key",
ai_model_id="gpt-4o-mini"
)
# Get a task and run it
tasks = get_tasks()
task = tasks[0] # Run the first task
conversation = await runner.run(task, assistant_client, user_client)
reward = runner.evaluate(task, conversation, runner.termination_reason)
print(f"Task completed with reward: {reward}")
# Run the example
asyncio.run(run_single_task())
```
### Running the Full Benchmark
Use the provided script to run the complete benchmark:
```bash
# Run with default models (gpt-4.1 for both agent and user)
python samples/run_benchmark.py
# Use custom models
python samples/run_benchmark.py --assistant gpt-4o --user gpt-4o-mini
# Debug a specific task
python samples/run_benchmark.py --debug-task-id task_001 --assistant gpt-4o
# Limit conversation length
python samples/run_benchmark.py --max-steps 20
```
## Results (on Airline Domain)
The following results were produced by our implementation of τ²-bench using `samples/run_benchmark.py`. They show the average success rate over the airline dataset of 50 tasks.
| Agent Model | User Model | Success Rate |
|-------------|------------|----------|
| gpt-5 | gpt-4.1 | 62.0% |
| gpt-5-mini | gpt-4.1 | 52.0% |
| gpt-4.1 | gpt-4.1 | 60.0% |
| gpt-4.1-mini | gpt-4.1 | 50.0% |
| gpt-4.1 | gpt-4o-mini | 42.0% |
| gpt-4o | gpt-4.1 | 42.0% |
| gpt-4o-mini | gpt-4.1 | 26.0% |
## Advanced Usage
### Environment Configuration
Set required environment variables:
```bash
export OPENAI_BASE_URL="https://api.openai.com/v1"
export OPENAI_API_KEY="your-api-key"
# Optional: for custom endpoints
export OPENAI_BASE_URL="https://your-custom-endpoint.com/v1"
```
### Custom Agent Implementation
```python
from agent_framework_lab_tau2 import TaskRunner
from agent_framework import ChatAgent
class CustomTaskRunner(TaskRunner):
def assistant_agent(self, assistant_chat_client):
# Override to customize the assistant agent
return ChatAgent(
chat_client=assistant_chat_client,
instructions="Your custom system prompt here",
# Add custom tools, temperature, etc.
)
def user_simulator(self, user_chat_client, task):
# Override to customize the user simulator
return ChatAgent(
chat_client=user_chat_client,
instructions="Custom user simulator prompt",
)
```
### Custom Workflow Integration
```python
from agent_framework._workflow import WorkflowBuilder, AgentExecutor
from agent_framework_lab_tau2 import TaskRunner
class WorkflowTaskRunner(TaskRunner):
def build_conversation_workflow(self, assistant_agent, user_simulator_agent):
# Build a custom workflow
builder = WorkflowBuilder()
# Create agent executors
assistant_executor = AgentExecutor(assistant_agent, id="assistant_agent")
user_executor = AgentExecutor(user_simulator_agent, id="user_simulator")
# Add workflow edges and conditions
builder.set_start_executor(assistant_executor)
builder.add_edge(assistant_executor, user_executor)
builder.add_edge(user_executor, assistant_executor, condition=self.should_not_stop)
return builder.build()
```
### Utility Functions
```python
from agent_framework_lab_tau2 import patch_env_set_state, unpatch_env_set_state
# Enable compatibility patches for τ²-bench integration
patch_env_set_state()
# Disable patches when done
unpatch_env_set_state()
```
## Contributing
This package is part of the Microsoft Agent Framework Lab. Please see the main repository for contribution guidelines.
## License
This project is licensed under the MIT License - see the LICENSE file for details.
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft. All rights reserved.
"""
Tau2 Benchmark for Agent Framework.
"""
import importlib.metadata
from ._tau2_utils import patch_env_set_state, unpatch_env_set_state
from .runner import TaskRunner
# Resolve the package version from the installed distribution metadata so it
# does not have to be duplicated in source.
try:
    __version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
    # No distribution metadata is available under this name.
    __version__ = "0.0.0"  # Fallback for development mode
__all__ = [
"TaskRunner",
"patch_env_set_state",
"unpatch_env_set_state",
]
@@ -0,0 +1,112 @@
# Copyright (c) Microsoft. All rights reserved.
from agent_framework._types import ChatMessage, Contents, Role
from loguru import logger
def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
    """Swap the assistant and user roles of a conversation.

    Intended for role-play simulations in which one agent's output is fed to
    the other agent as input.

    Args:
        messages: The conversation to flip. The list and its messages are not modified.

    Returns:
        A new list where:
        - assistant messages become user messages with their ``function_call``
          contents removed (an assistant message left with no contents is dropped);
        - user messages become assistant messages with contents unchanged;
        - tool messages are dropped;
        - messages with any other role are passed through as the same object.
    """

    def without_function_calls(contents: list[Contents]) -> list[Contents]:
        """Return ``contents`` minus every item whose type is ``function_call``."""
        return [item for item in contents if item.type != "function_call"]

    result: list[ChatMessage] = []
    for message in messages:
        role = message.role
        if role == Role.ASSISTANT:
            # Function calls are stripped because they cause a 400 when the role is user.
            kept = without_function_calls(message.contents)
            if not kept:
                continue
            result.append(
                ChatMessage(
                    role=Role.USER,
                    contents=kept,
                    author_name=message.author_name,
                    message_id=message.message_id,
                )
            )
        elif role == Role.USER:
            result.append(
                ChatMessage(
                    role=Role.ASSISTANT,
                    contents=message.contents,
                    author_name=message.author_name,
                    message_id=message.message_id,
                )
            )
        elif role == Role.TOOL:
            # Tool messages are skipped entirely.
            continue
        else:
            # Any remaining role is kept as-is.
            result.append(message)
    return result
def log_messages(messages: list[ChatMessage]) -> None:
    """Log each message through loguru with a color tag chosen by role and content type.

    Every ``<`` in the logged text is escaped as ``\\<`` so that message content
    is not interpreted as loguru color markup.

    Args:
        messages: The messages to log. Messages with non-empty ``contents`` are
            logged one line per content item; otherwise the message ``text`` is
            logged if present, falling back to ``str(message)``.
    """
    colored = logger.opt(colors=True)

    def escape(raw: str) -> str:
        """Escape the markup opener so loguru prints the text literally."""
        return raw.replace("<", r"\<")

    def role_tag(role: Role) -> str:
        """Return the colored ``[ROLE]`` prefix for a message role."""
        if role == Role.SYSTEM:
            return "<cyan>[SYSTEM]</cyan>"
        if role == Role.USER:
            return "<green>[USER]</green>"
        if role == Role.ASSISTANT:
            return "<blue>[ASSISTANT]</blue>"
        if role == Role.TOOL:
            return "<yellow>[TOOL]</yellow>"
        return f"<magenta>[{role.value.upper()}]</magenta>"

    for msg in messages:
        if hasattr(msg, "contents") and msg.contents:
            for content in msg.contents:
                if not hasattr(content, "type"):
                    # Content without a type: log its string form under the role tag.
                    body = escape(str(content))
                    colored.info(f"{role_tag(msg.role)} {body}")
                    continue

                kind = content.type
                if kind == "text":
                    body = escape(content.text)
                    colored.info(f"{role_tag(msg.role)} {body}")
                elif kind == "function_call":
                    body = escape(f"{content.name}({content.arguments})")
                    colored.info(f"<yellow>[TOOL_CALL]</yellow> 🔧 {body}")
                elif kind == "function_result":
                    body = escape(f"ID:{content.call_id} -> {content.result}")
                    colored.info(f"<yellow>[TOOL_RESULT]</yellow> 🔨 {body}")
                else:
                    # Unknown content type: include the type name in the tag.
                    body = escape(str(content))
                    colored.info(f"<magenta>[{msg.role.value.upper()}] ({kind})</magenta> {body}")
        elif hasattr(msg, "text") and msg.text:
            # Simple text message without structured contents.
            body = escape(msg.text)
            colored.info(f"{role_tag(msg.role)} {body}")
        else:
            # Last resort: log whatever str() gives for the message.
            body = escape(str(msg))
            colored.info(f"<magenta>[{msg.role.value.upper()}]</magenta> {body}")
@@ -0,0 +1,131 @@
# Copyright (c) Microsoft. All rights reserved.
import json
from collections.abc import Sequence
from typing import Any
import tiktoken
from agent_framework._threads import ChatMessageList
from agent_framework._types import ChatMessage, Role
from loguru import logger
class SlidingWindowChatMessageList(ChatMessageList):
    """A token-aware sliding window implementation of ChatMessageList.

    Maintains two message lists: the complete history (``self._messages``) and
    a truncated window (``self._truncated_messages``). Whenever messages are
    added, the window is rebuilt from the full history and the oldest messages
    are dropped until the estimated token count fits within ``max_tokens``.
    Leading tool messages are then removed so the window never starts with a
    tool result.
    """

    def __init__(
        self,
        messages: Sequence[ChatMessage] | None = None,
        max_tokens: int = 3800,
        system_message: str | None = None,
        tool_definitions: Any | None = None,
    ):
        """Initialize the store.

        Args:
            messages: Optional initial messages, forwarded to ``ChatMessageList``.
            max_tokens: Estimated-token budget for the truncated window.
            system_message: Optional system prompt whose tokens count against the budget.
            tool_definitions: Stored on the instance; not included in the token estimate.
        """
        super().__init__(messages)
        self._truncated_messages = self._messages.copy()  # Separate truncated view
        self.max_tokens = max_tokens
        self.system_message = system_message  # Included in token count
        self.tool_definitions = tool_definitions
        # An estimation based on a commonly used vocab table
        self.encoding = tiktoken.get_encoding("o200k_base")

    async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
        """Append messages to the full history and rebuild the truncated window."""
        await super().add_messages(messages)
        self._truncated_messages = self._messages.copy()
        self.truncate_messages()

    async def list_messages(self) -> list[ChatMessage]:
        """Get the current list of messages, which may be truncated."""
        return self._truncated_messages

    async def list_all_messages(self) -> list[ChatMessage]:
        """Get all messages from the store including the truncated ones."""
        return self._messages

    def truncate_messages(self) -> None:
        """Shrink the truncated window in place until it fits the token budget.

        Drops the oldest message while the estimated token count exceeds
        ``max_tokens``, then drops any leading tool messages.
        """
        while len(self._truncated_messages) > 0 and self.get_token_count() > self.max_tokens:
            logger.warning("Messages exceed max tokens. Truncating oldest message.")
            self._truncated_messages.pop(0)
        # Remove leading tool messages
        while len(self._truncated_messages) > 0 and self._truncated_messages[0].role == Role.TOOL:
            logger.warning("Removing leading tool message because tool result cannot be the first message.")
            self._truncated_messages.pop(0)

    def get_token_count(self) -> int:
        """Estimate the token count of the truncated window using tiktoken.

        The estimate covers ``self.system_message`` (when set) plus every
        message in ``self._truncated_messages``, with a fixed 4-token overhead
        per message and per function call/result. A warning is logged when the
        estimate exceeds the budget or half of it.

        Returns:
            Estimated token count.
        """
        total_tokens = 0

        # Add system message tokens if provided
        if self.system_message:
            total_tokens += len(self.encoding.encode(self.system_message))
            total_tokens += 4  # Extra tokens for system message formatting

        for msg in self._truncated_messages:
            # Add 4 tokens per message for role, formatting, etc.
            total_tokens += 4

            # Handle different content types
            if hasattr(msg, "contents") and msg.contents:
                for content in msg.contents:
                    if hasattr(content, "type"):
                        if content.type == "text":
                            total_tokens += len(self.encoding.encode(content.text))
                        elif content.type == "function_call":
                            total_tokens += 4
                            # Serialize function call and count tokens
                            func_call_data = {
                                "name": content.name,
                                "arguments": content.arguments,
                            }
                            total_tokens += self.estimate_any_object_token_count(func_call_data)
                        elif content.type == "function_result":
                            total_tokens += 4
                            # Serialize function result and count tokens
                            func_result_data = {
                                "call_id": content.call_id,
                                "result": content.result,
                            }
                            total_tokens += self.estimate_any_object_token_count(func_result_data)
                        else:
                            # For other content types, serialize the whole content
                            total_tokens += self.estimate_any_object_token_count(content)
                    else:
                        # Content without type, treat as text
                        total_tokens += self.estimate_any_object_token_count(content)
            elif hasattr(msg, "text") and msg.text:
                # Simple text message
                total_tokens += self.estimate_any_object_token_count(msg.text)
            else:
                # Nothing countable on this message beyond the per-message overhead
                pass

        # Check the hard limit first: anything above max_tokens is also above
        # the 50% mark, so testing the softer threshold first would make the
        # over-limit warning unreachable.
        if total_tokens > self.max_tokens:
            logger.opt(colors=True).warning(
                f"<red>Total tokens {total_tokens} is over max tokens {self.max_tokens}. Will truncate messages.</red>"
            )
        elif total_tokens > self.max_tokens / 2:
            logger.opt(colors=True).warning(
                f"<yellow>Total tokens {total_tokens} is "
                f"{total_tokens / self.max_tokens * 100:.0f}% "
                f"of max tokens {self.max_tokens}</yellow>"
            )

        return total_tokens

    def estimate_any_object_token_count(self, obj: Any) -> int:
        """Estimate tokens for an arbitrary object from its serialized form.

        The object is JSON-encoded when possible; otherwise ``str(obj)`` is
        used as a best-effort fallback.
        """
        try:
            serialized = json.dumps(obj)
        except Exception:
            # Deliberate best-effort: anything json cannot encode is counted via str().
            serialized = str(obj)
        return len(self.encoding.encode(serialized))
@@ -0,0 +1,244 @@
# Copyright (c) Microsoft. All rights reserved.
import json
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
import numpy as np
from agent_framework._tools import AIFunction
from agent_framework._types import ChatMessage
from loguru import logger
from pydantic import BaseModel
from tau2.data_model.message import ( # type: ignore[import-untyped]
AssistantMessage,
Message,
SystemMessage,
ToolCall,
ToolMessage,
UserMessage,
)
from tau2.data_model.tasks import EnvFunctionCall, InitializationData # type: ignore[import-untyped]
from tau2.environment.environment import Environment # type: ignore[import-untyped]
from tau2.environment.tool import Tool # type: ignore[import-untyped]
# Reference to the unpatched method, captured at import time so that
# unpatch_env_set_state() can restore it after patch_env_set_state().
_original_set_state = Environment.set_state
def convert_tau2_tool_to_ai_function(tau2_tool: Tool) -> AIFunction[Any, Any]:
    """Wrap a tau2 ``Tool`` as an agent framework ``AIFunction``.

    The resulting function keeps the tool's name, description and parameter
    model. Each call invokes the tool with the given keyword arguments and
    returns a deep copy of its result, so callers cannot mutate data still
    held by the tool.

    Args:
        tau2_tool: The tau2 tool to expose.

    Returns:
        An ``AIFunction`` that delegates to ``tau2_tool``.
    """

    def wrapped_func(**kwargs: Any) -> Any:
        output = tau2_tool(**kwargs)
        # Pydantic models are copied with their own deep-copy API; everything else with deepcopy.
        if isinstance(output, BaseModel):
            return output.model_copy(deep=True)
        return deepcopy(output)

    tool_name = tau2_tool.name
    tool_description = tau2_tool._get_description()
    return AIFunction(
        name=tool_name,
        description=tool_description,
        func=wrapped_func,
        input_model=tau2_tool.params,
    )
def convert_agent_framework_messages_to_tau2_messages(messages: list[ChatMessage]) -> list[Message]:
    """Translate agent framework ``ChatMessage`` objects into tau2 messages.

    For each input message, in order:
    - text contents are joined with single spaces into the message content
      (``None`` when there is no text);
    - function-call contents become ``ToolCall`` objects attached to the
      user/assistant message;
    - a system, user or assistant message is emitted according to the role
      (a ``tool`` role produces no message of its own);
    - every function-result content is emitted afterwards as a separate
      ``ToolMessage``.

    Args:
        messages: Agent framework messages to convert.

    Returns:
        The equivalent tau2 messages.
    """
    converted: list[Message] = []

    for message in messages:
        role = str(message.role)
        typed_items = [item for item in message.contents if hasattr(item, "type")]

        # Text content: joined with spaces, or None when the message carries no text.
        text_parts = [item.text for item in typed_items if item.type == "text" and hasattr(item, "text")]
        text = " ".join(text_parts) if text_parts else None

        # Function calls -> ToolCall objects (None rather than an empty list when absent).
        call_items = [item for item in typed_items if item.type == "function_call"]
        tool_calls = None
        if call_items:
            requestor = "assistant" if role == "assistant" else "user"
            tool_calls = [
                ToolCall(
                    id=call.call_id,
                    name=call.name,
                    arguments=call.parse_arguments() or {},
                    requestor=requestor,
                )
                for call in call_items
            ]

        # Function results are emitted after the main message, as ToolMessages.
        result_items = [item for item in typed_items if item.type == "function_result"]

        if role == "system":
            converted.append(SystemMessage(role="system", content=text))
        elif role == "user":
            converted.append(UserMessage(role="user", content=text, tool_calls=tool_calls))
        elif role == "assistant":
            converted.append(AssistantMessage(role="assistant", content=text, tool_calls=tool_calls))
        # A "tool" role has no main message; its payload is the function results below.

        for outcome in result_items:
            payload = _dump_function_result(outcome.result)
            if not isinstance(payload, str):
                payload = json.dumps(payload)
            converted.append(
                ToolMessage(
                    id=outcome.call_id,
                    role="tool",
                    content=payload,
                    requestor="assistant",  # Most tool calls originate from assistant
                    error=outcome.exception is not None,
                )
            )

    return converted
def patch_env_set_state() -> None:
    """Replace ``Environment.set_state`` with a variant that tolerates replayed tool errors.

    The replacement applies the initialization data and actions, then replays
    every tool call found in the message history and compares each replayed
    response with the recorded one (both JSON-deserialized recursively).
    When they differ and the replayed content is a string starting with
    ``"Error:"``, a warning is logged and the difference is ignored; any other
    mismatch raises ``ValueError``.

    Call ``unpatch_env_set_state()`` to restore the original method.
    """

    def set_state(
        self: Any,
        initialization_data: InitializationData | None,
        initialization_actions: list[EnvFunctionCall] | None,
        message_history: list[Message],
    ) -> None:
        if self.solo_mode and any(isinstance(entry, UserMessage) for entry in message_history):
            raise ValueError("User messages are not allowed in solo mode")

        def get_actions_from_messages(
            messages: list[Message],
        ) -> list[tuple[ToolCall, ToolMessage]]:
            """Pair every tool call in ``messages`` with the tool message that follows it."""
            # One shared iterator over a deep copy: the outer loop consumes
            # ordinary messages, the inner loop consumes the tool replies.
            pending = iter(deepcopy(messages))
            exhausted = object()
            pairs: list[tuple[ToolCall, ToolMessage]] = []
            for message in pending:
                if isinstance(message, ToolMessage):
                    raise ValueError("Tool message not expected. Tool messages should always follow a tool call.")
                if not (isinstance(message, (AssistantMessage, UserMessage)) and message.is_tool_call()):
                    continue
                calls = message.tool_calls
                if calls is None:
                    raise ValueError("Tool message expected. Got None.")
                for call in calls:
                    reply = next(pending, exhausted)
                    if reply is exhausted:
                        raise ValueError("Tool message expected. Got None.")
                    if not isinstance(reply, ToolMessage):
                        raise ValueError(f"Tool message expected. Got {type(reply)}")
                    if call.id != reply.id:
                        raise ValueError(f"Tool call id mismatch. Got {call.id} and {reply.id}")
                    pairs.append((call, reply))
            return pairs

        if initialization_data is not None:
            if initialization_data.agent_data is not None:
                self.tools.update_db(initialization_data.agent_data)
            if initialization_data.user_data is not None:
                self.user_tools.update_db(initialization_data.user_data)

        if initialization_actions is not None:
            for action in initialization_actions:
                self.run_env_function_call(action)

        for tool_call, expected_response in get_actions_from_messages(message_history):
            response = self.get_response(tool_call)
            content = _recursive_json_deserialize(response.content)
            expected_content = _recursive_json_deserialize(expected_response.content)
            if content != expected_content:
                diff = f"Tool call:\n{tool_call}\n\nReturned:\n{response}\n\nExpected:\n{expected_response}"
                replay_errored = isinstance(content, str) and content.startswith("Error:")
                if not replay_errored:
                    raise ValueError(diff)
                # If the tool call resulted in an error, the difference can be ignored
                logger.warning(f"Tool call resulted in an error. Ignoring the difference.\n{diff}")

        self.sync_tools()

    Environment.set_state = set_state
def unpatch_env_set_state() -> None:
    """Restore ``Environment.set_state`` to the method captured when this module was imported."""
    Environment.set_state = _original_set_state
def _dump_function_result(result: Any) -> Any:
    """Make a function result dumpable by serializing any pydantic models inside it.

    Pydantic models are replaced by their JSON string; lists and dicts are
    processed recursively (dict keys are kept unchanged); every other value,
    including ``None``, is returned unchanged.
    """
    if isinstance(result, BaseModel):
        return result.model_dump_json()
    if isinstance(result, list):
        return [_dump_function_result(element) for element in result]
    if isinstance(result, dict):
        return {key: _dump_function_result(value) for key, value in result.items()}
    # Scalars, None and any other type pass through untouched.
    return result
def _to_native(obj: Any) -> Any:
"""Convert data retrieved from Panquet to data usable in AGL server."""
# 1) Arrays -> list (then recurse)
if isinstance(obj, np.ndarray):
return _to_native(obj.tolist())
# 2) NumPy scalar types -> Python scalars
if isinstance(obj, np.generic):
return _to_native(obj.item())
# 3) Dict-like -> dict
if isinstance(obj, Mapping):
return {_to_native(k): _to_native(v) for k, v in obj.items()}
# 4) Lists/Tuples/Sets -> list
if isinstance(obj, (list, tuple, set)):
return [_to_native(x) for x in obj]
# 5) Anything else: leave as-is
return obj
def _recursive_json_deserialize(obj: Any) -> Any:
"""
Recursively deserialize a JSON object.
"""
if isinstance(obj, str):
try:
deserialized = json.loads(obj)
return _recursive_json_deserialize(deserialized)
except (json.JSONDecodeError, TypeError):
return obj
elif isinstance(obj, list):
return [_recursive_json_deserialize(item) for item in obj]
elif isinstance(obj, dict):
return {k: _recursive_json_deserialize(v) for k, v in obj.items()}
else:
return obj
@@ -0,0 +1,424 @@
# Copyright (c) Microsoft. All rights reserved.
import uuid
from typing import cast
from agent_framework._agents import ChatAgent
from agent_framework._types import AgentRunResponse, ChatMessage, Role
from agent_framework._workflow import (
AgentExecutor,
AgentExecutorRequest,
AgentExecutorResponse,
FunctionExecutor,
Workflow,
WorkflowBuilder,
WorkflowContext,
)
from agent_framework.openai import OpenAIChatClient
from loguru import logger
from tau2.data_model.simulation import SimulationRun, TerminationReason # type: ignore[import-untyped]
from tau2.data_model.tasks import Task # type: ignore[import-untyped]
from tau2.domains.airline.environment import get_environment # type: ignore[import-untyped]
from tau2.evaluator.evaluator import EvaluationType, RewardInfo, evaluate_simulation # type: ignore[import-untyped]
from tau2.user.user_simulator import ( # type: ignore[import-untyped]
OUT_OF_SCOPE,
STOP,
TRANSFER,
get_global_user_sim_guidelines,
)
from tau2.utils.utils import get_now # type: ignore[import-untyped]
from ._message_utils import flip_messages, log_messages
from ._sliding_window import SlidingWindowChatMessageList
from ._tau2_utils import convert_agent_framework_messages_to_tau2_messages, convert_tau2_tool_to_ai_function
# Agent instructions matching tau2's LLMAgent.
# The surrounding whitespace of the literal is removed by .strip().
ASSISTANT_AGENT_INSTRUCTION = """
You are a customer service agent that helps the user according to the <policy> provided below.
In each turn you can either:
- Send a message to the user.
- Make a tool call.
You cannot do both at the same time.
Try to be helpful and always follow the policy. Always make sure you generate valid JSON only.
""".strip()
# Default first message from agent (matching tau2)
DEFAULT_FIRST_AGENT_MESSAGE = "Hi! How can I help you today?"
# Constants of Agent executor IDs.
# should_not_stop() compares response.executor_id against the first two to
# tell which side produced a response.
ASSISTANT_AGENT_ID = "assistant_agent"
USER_SIMULATOR_ID = "user_simulator"
ORCHESTRATOR_ID = "orchestrator"
class TaskRunner:
"""Orchestrates task execution using agent framework workflows for tau2 benchmarks.
Manages conversation flow between assistant agents and user simulators,
handles termination conditions, and evaluates performance using tau2 metrics.
Only "airline" domain is supported for now.
"""
# State tracking
step_count: int
full_conversation: list[ChatMessage]
termination_reason: TerminationReason | None
full_reward_info: RewardInfo | None
_final_user_message: list[ChatMessage] | None
_assistant_executor: AgentExecutor | None
_user_executor: AgentExecutor | None
# Configuration
max_steps: int
assistant_sampling_temperature: float
assistant_window_size: int
def __init__(self, max_steps: int, assistant_sampling_temperature: float = 0.0, assistant_window_size: int = 32768):
    """Store the run configuration and reset all per-task state.

    Args:
        max_steps: The maximum number of steps to run.
        assistant_sampling_temperature: The sampling temperature for the assistant agent.
        assistant_window_size: The window size for the assistant agent.
    """
    # Configuration (kept across reinit() calls)
    self.max_steps = max_steps
    self.assistant_window_size = assistant_window_size
    self.assistant_sampling_temperature = assistant_sampling_temperature
    # Per-task state
    self.reinit()
def reinit(self) -> "TaskRunner":
    """Clear all per-task state so the runner can be reused for another task.

    Configuration attributes are left untouched.

    Returns:
        This runner, to allow call chaining.
    """
    # Progress counters and the recorded conversation
    self.step_count = 0
    self.full_conversation = []
    # Outcome of the previous run
    self.termination_reason = None
    self.full_reward_info = None
    self._final_user_message = None
    # Executors of the previous run
    self._assistant_executor = self._user_executor = None
    logger.info("TaskRunner has been re-initialized.")
    return self
def __repr__(self) -> str:
return (
f"TaskRunner(max_steps={self.max_steps}, step_count={self.step_count}, "
f"full_conversation_length={len(self.full_conversation)}, "
f"termination_reason={self.termination_reason}, full_reward_info={self.full_reward_info})"
)
def should_not_stop(self, response: AgentExecutorResponse) -> bool:
    """Decide whether the conversation should continue after ``response``.

    Each call counts as one step: it increments ``step_count`` and logs the
    response. When the conversation must end, ``termination_reason`` is set
    (and, for a user stop, the flipped final user message is stored for
    evaluation).

    Args:
        response: The executor response that was just produced.

    Returns:
        ``True`` to keep routing messages, ``False`` to terminate.
    """
    # The executor id tells which side produced the response.
    sender_is_assistant = response.executor_id == ASSISTANT_AGENT_ID
    sender_is_user = response.executor_id == USER_SIMULATOR_ID

    self.step_count += 1

    if sender_is_assistant:
        source, target = "<blue>assistant</blue>", "<green>user</green>"
    else:
        source, target = "<green>user</green>", "<blue>assistant</blue>"
    logger.opt(colors=True).info(
        f"<bold>[Step {self.step_count}] Received the following response from {source}</bold>, routing to {target}:"
    )
    log_messages(response.agent_run_response.messages)

    # The step budget is checked before any stop markers.
    if self.step_count >= self.max_steps:
        logger.info(f"Max steps ({self.max_steps}) reached - terminating conversation")
        self.termination_reason = TerminationReason.MAX_STEPS
        return False

    reply_text = response.agent_run_response.text

    if sender_is_assistant and self._is_agent_stop(reply_text):
        logger.info("Agent requested stop - terminating conversation")
        self.termination_reason = TerminationReason.AGENT_STOP
        return False

    if sender_is_user and self._is_user_stop(reply_text):
        logger.info(f"User requested stop with message: '{reply_text}' - terminating conversation")
        self.termination_reason = TerminationReason.USER_STOP
        # The final user message never reaches the assistant's message store,
        # so keep it here: it is needed for evaluation.
        self._final_user_message = flip_messages(response.agent_run_response.messages)
        return False

    return True
def _is_agent_stop(self, _: str) -> bool:
    """Report whether the assistant asked to end the conversation.

    The response text is accepted but ignored: this implementation always returns False.
    """
    # Extension point: inspect the text here if the assistant ever emits a stop token.
    return False
def _is_user_stop(self, text: str) -> bool:
    """Return True when ``text`` contains the STOP, TRANSFER or OUT_OF_SCOPE marker."""
    stop_markers = (STOP, TRANSFER, OUT_OF_SCOPE)
    return any(marker in text for marker in stop_markers)
def assistant_agent(self, assistant_chat_client: OpenAIChatClient) -> ChatAgent:
    """Create an assistant agent.

    Users can override this method to provide a custom assistant agent.

    Args:
        assistant_chat_client: The chat client for the assistant agent.

    Returns:
        The assistant agent.
    """
    # Pull the domain tools and policy out of the tau2 environment.
    env = get_environment()
    tools = env.get_tools()  # Available actions the assistant can take
    policy = env.get_policy()  # Guidelines the assistant must follow

    # Reuse the list fetched above instead of asking the environment for it again.
    logger.info(f"Environment has {len(tools)} tools: {', '.join(tool.name for tool in tools)}")

    # Convert tau2 tools to the AIFunction format the agent framework expects.
    ai_functions = [convert_tau2_tool_to_ai_function(tool) for tool in tools]

    # System prompt = general assistant instructions followed by the domain policy.
    assistant_system_prompt = f"""<instructions>
{ASSISTANT_AGENT_INSTRUCTION}
</instructions>
<policy>
{policy}
</policy>"""

    # The assistant gets:
    # - all domain tools
    # - a sliding-window message store bounded by self.assistant_window_size tokens
    # - the configured sampling temperature
    return ChatAgent(
        chat_client=assistant_chat_client,
        instructions=assistant_system_prompt,
        tools=ai_functions,  # type: ignore
        temperature=self.assistant_sampling_temperature,
        # The schema list is built inside the factory so every store gets its own list.
        chat_message_store_factory=lambda: SlidingWindowChatMessageList(
            system_message=assistant_system_prompt,
            tool_definitions=[tool.openai_schema for tool in tools],
            max_tokens=self.assistant_window_size,
        ),
    )
def user_simulator(self, user_simuator_chat_client: OpenAIChatClient, task: Task) -> ChatAgent:
    """Build the agent that plays the customer side of the conversation.

    Subclasses may override this to supply a different user simulator.

    Args:
        user_simuator_chat_client: Chat client backing the user simulator.
        task: Task whose ``user_scenario.instructions`` describe what the simulated user wants.

    Returns:
        A ChatAgent with no tools, temperature 0.0 and the default message store.
    """
    # tau2's generic user-simulation guidelines, in the variant without tool usage.
    guidelines = get_global_user_sim_guidelines(use_tools=False)
    scenario = task.user_scenario.instructions

    # Generic guidelines first, then the task-specific scenario.
    simulator_prompt = f"""{guidelines}
<scenario>
{scenario}
</scenario>"""

    # No sliding window is configured here, so the simulator keeps the full conversation.
    # TODO(yuge): Consider adding user tools in future for more realistic scenarios
    return ChatAgent(
        chat_client=user_simuator_chat_client,
        instructions=simulator_prompt,
        temperature=0.0,
    )
async def conversation_orchestrator(
    self, response: AgentExecutorResponse, ctx: WorkflowContext[AgentExecutorRequest]
) -> None:
    """Forward one side's response to the other side of the conversation.

    The messages in ``response`` have their roles flipped and are then sent, as a request
    that expects a reply, to whichever agent did not produce them.

    Args:
        response: Response produced by the assistant or by the user simulator.
        ctx: Workflow context used to send the request onwards.
    """
    # One side's output is the other side's input, hence the role flip.
    swapped = flip_messages(response.agent_run_response.messages)

    if response.executor_id == ASSISTANT_AGENT_ID:
        recipient_id = USER_SIMULATOR_ID
    else:
        recipient_id = ASSISTANT_AGENT_ID

    # An explicit target is required: without it the request would be broadcast to both agents.
    request = AgentExecutorRequest(messages=swapped, should_respond=True)
    await ctx.send_message(request, target_id=recipient_id)
def build_conversation_workflow(self, assistant_agent: ChatAgent, user_simulator_agent: ChatAgent) -> Workflow:
    """Wire the two agents and the orchestrator into a cyclic workflow.

    Subclasses may override this to supply a different workflow.

    Args:
        assistant_agent: The assistant agent.
        user_simulator_agent: The user simulator agent.

    Returns:
        The built workflow, with the orchestrator as its start executor.
    """
    # The agent executors are kept on self; the orchestrator is only needed locally.
    self._assistant_executor = AgentExecutor(assistant_agent, id=ASSISTANT_AGENT_ID)
    self._user_executor = AgentExecutor(user_simulator_agent, id=USER_SIMULATOR_ID)
    router = FunctionExecutor(func=self.conversation_orchestrator, id=ORCHESTRATOR_ID)

    # max_iterations is a high ceiling; termination is driven by should_not_stop.
    builder = WorkflowBuilder(max_iterations=10000)
    builder = builder.set_start_executor(router)

    # For each agent: router -> agent always, agent -> router only while should_not_stop
    # allows it. Edges are added for the assistant first, then for the user simulator.
    for agent_executor in (self._assistant_executor, self._user_executor):
        builder = builder.add_edge(router, agent_executor)
        builder = builder.add_edge(agent_executor, router, condition=self.should_not_stop)

    return builder.build()
async def run(
    self,
    task: Task,
    assistant_chat_client: OpenAIChatClient,
    user_simuator_chat_client: OpenAIChatClient,
) -> list[ChatMessage]:
    """Run one tau2 task as a workflow between the assistant and the user simulator.

    The method creates both agents, builds the conversation workflow, seeds it with the
    default assistant greeting, runs it until a termination condition fires, and then
    reconstructs the complete conversation.

    Args:
        task: The tau2 task to run.
        assistant_chat_client: LLM client for the assistant agent.
        user_simuator_chat_client: LLM client for the user simulator.

    Returns:
        The complete conversation as a list of ChatMessage, for evaluation.
    """
    logger.info(f"Starting workflow agent for task {task.id}: {task.description.purpose}")  # type: ignore[unused-ignore]
    logger.info(f"Assistant chat client: {assistant_chat_client}")
    logger.info(f"User simulator chat client: {user_simuator_chat_client}")

    # Agents first, then the workflow that connects them.
    assistant = self.assistant_agent(assistant_chat_client)
    simulated_user = self.user_simulator(user_simuator_chat_client, task)
    workflow = self.build_conversation_workflow(assistant, simulated_user)

    # The conversation opens with a fixed assistant greeting, presented to the workflow
    # as if the assistant executor had just produced it.
    logger.info(f"Starting workflow with hardcoded greeting: '{DEFAULT_FIRST_AGENT_MESSAGE}'")
    first_message = ChatMessage(Role.ASSISTANT, text=DEFAULT_FIRST_AGENT_MESSAGE)
    initial_greeting = AgentExecutorResponse(
        executor_id=ASSISTANT_AGENT_ID,
        agent_run_response=AgentRunResponse(messages=[first_message]),
        full_conversation=[ChatMessage(Role.ASSISTANT, text=DEFAULT_FIRST_AGENT_MESSAGE)],
    )

    # Runs until should_not_stop ends the loop.
    await workflow.run(initial_greeting)

    # Assemble the history needed for evaluation from three sources:
    #   1. the initial greeting,
    #   2. everything in the assistant's message store (not only the truncated window),
    #   3. the final user message, when the user ended the conversation.
    assistant_executor = cast(AgentExecutor, self._assistant_executor)
    message_store = cast(SlidingWindowChatMessageList, assistant_executor._agent_thread.message_store)
    stored_messages = await message_store.list_all_messages()
    full_conversation = [first_message] + stored_messages
    if self._final_user_message is not None:
        full_conversation.extend(self._final_user_message)

    logger.opt(colors=True).info(
        f"<green>WORKFLOW COMPLETED WITH {len(full_conversation)} MESSAGES. "
        f"Termination reason: {self.termination_reason}.</green>"
    )
    log_messages(full_conversation)
    return full_conversation
def evaluate(
    self, task_input: Task, conversation: list[ChatMessage], termination_reason: TerminationReason | None
) -> float:
    """Score a finished conversation with tau2's evaluation pipeline.

    Args:
        task_input: The tau2 task the conversation was run for.
        conversation: Complete conversation history produced by the workflow.
        termination_reason: Why the conversation ended; ``None`` is treated as
            ``TerminationReason.TOO_MANY_ERRORS``.

    Returns:
        The ``reward`` field of tau2's evaluation result.

    Side Effects:
        Stores the full evaluation result in ``self.full_reward_info``.
    """
    # A missing reason means the workflow ended without recording one.
    effective_reason = TerminationReason.TOO_MANY_ERRORS if termination_reason is None else termination_reason

    # tau2 evaluates its own message type, so convert first.
    tau2_messages = convert_agent_framework_messages_to_tau2_messages(conversation)

    # Wrap the conversation in a SimulationRun record. Evaluation happens after the fact,
    # so start and end are both "now" and the duration is reported as zero.
    simulation = SimulationRun(
        id=str(uuid.uuid4()),
        task_id=task_input.id,
        start_time=get_now(),
        end_time=get_now(),
        duration=0.0,
        termination_reason=effective_reason,
        messages=tau2_messages,
    )

    # Run every evaluation type for a two-party (non-solo) conversation in the airline domain.
    reward_info = evaluate_simulation(
        simulation=simulation,
        task=task_input,
        evaluation_type=EvaluationType.ALL,
        solo_mode=False,
        domain="airline",
    )
    self.full_reward_info = reward_info

    logger.info(f"Evaluation completed - Reward: {self.full_reward_info.reward}, Info: {self.full_reward_info}")
    return self.full_reward_info.reward  # type: ignore[no-any-return]
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft. All rights reserved.
# This makes agent_framework a pkgutil-style namespace package: extend_path() adds to
# __path__ the matching agent_framework directories found on sys.path, so sub-packages
# of agent_framework can be loaded from more than one installed location.
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft. All rights reserved.
# This makes agent_framework.lab a pkgutil-style namespace package: extend_path() adds to
# __path__ the matching agent_framework/lab directories found on sys.path, so sub-packages
# of agent_framework.lab can be loaded from more than one installed location.
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft. All rights reserved.
# Import and re-export from the actual implementation
from agent_framework_lab_tau2 import * # noqa: F403, F401
+99
View File
@@ -0,0 +1,99 @@
[project]
name = "agent-framework-lab-tau2"
description = "Tau2 Benchmark for Agent Framework."
authors = [{ name = "Microsoft", email = "SK-Support@microsoft.com"}]
readme = "README.md"
requires-python = ">=3.10"
version = "0.1.0b1"
license-files = ["LICENSE"]
urls.homepage = "https://learn.microsoft.com/en-us/semantic-kernel/overview/"
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
urls.issues = "https://github.com/microsoft/agent-framework/issues"
classifiers = [
"License :: OSI Approved :: MIT License",
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
# Runtime dependencies. tau2 is installed directly from the tau2-bench GitHub repository,
# pinned to a single commit, rather than from a package index.
dependencies = [
"agent-framework",
"pydantic>=2.0.0",
"tiktoken>=0.11.0",
"loguru>=0.7.3",
"tau2@git+https://github.com/sierra-research/tau2-bench@5ba9e3e56db57c5e4114bf7f901291f09b2c5619",
]
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"
# Two packages are shipped: agent_framework_lab_tau2 (sources in the directory of the same
# name) and agent_framework.lab.tau2, whose sources live under namespace/agent_framework/lab/tau2.
[tool.setuptools]
packages = ["agent_framework_lab_tau2", "agent_framework.lab.tau2"]
# Maps each package name above to the directory that holds its sources.
[tool.setuptools.package-dir]
"agent_framework.lab.tau2" = "namespace/agent_framework/lab/tau2"
"agent_framework_lab_tau2" = "agent_framework_lab_tau2"
[tool.setuptools.package-data]
agent_framework_lab_tau2 = ["py.typed"]
[tool.ruff]
line-length = 120
target-version = "py310"
extend-exclude = ["tests", "__pycache__"]
[tool.ruff.lint]
select = ["E", "F", "I", "W", "UP", "C4", "N"]
ignore = ["N803", "N806", "N999", "UP007"]
[tool.ruff.format]
quote-style = "double"
[tool.mypy]
python_version = "3.10"
strict = true
check_untyped_defs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
disallow_untyped_decorators = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_return_any = true
warn_unreachable = true
show_error_codes = true
implicit_reexport = true
packages = ["agent_framework_lab_tau2"]
exclude = [
"data",
]
[tool.pyright]
exclude = ["**/data"]
[tool.poe]
executor.type = "uv"
include = "../../../shared_tasks.toml"
[tool.poe.tasks]
test = "pytest --cov=agent_framework_lab_tau2 --cov-report=term-missing:skip-covered tests"
mypy = "mypy agent_framework_lab_tau2"
setup-data = "python tests/setup_data.py"
purge-data = "python tests/purge_data.py"
[tool.pytest.ini_options]
testpaths = ["tests"]
# Puts the project root on sys.path so tests can import the package without installing it.
pythonpath = ["."]
# Unregistered markers and unknown configuration keys are errors, not warnings.
addopts = "--strict-markers --strict-config"
markers = [
"unit: marks tests as unit tests",
"integration: marks tests as integration tests",
]
# NOTE(review): `env` is not a built-in pytest ini option; presumably it is provided by a
# plugin such as pytest-env. Confirm that plugin is installed for test runs, otherwise
# --strict-config will reject this key.
env = [
"TAU2_DATA_DIR=data",
]
@@ -0,0 +1,4 @@
TAU2_DATA_DIR=/path/to/your/data
OPENAI_API_KEY=dummy
OPENAI_BASE_URL=http://127.0.0.1:12345/
@@ -0,0 +1,240 @@
# Copyright (c) Microsoft. All rights reserved.
import argparse
import asyncio
import json
import os
import traceback
from datetime import datetime
from typing import Any
from agent_framework.openai import OpenAIChatClient
from loguru import logger
from tau2.domains.airline.environment import get_tasks
from agent_framework_lab_tau2 import TaskRunner, patch_env_set_state
def to_dumpable(result: dict[str, Any]) -> dict[str, Any]:
    """Flatten one benchmark result into a plain dictionary for a JSONL record.

    A result carrying an ``"error"`` key yields a reduced record with a reward of 0.0.
    Any other result yields the full record: evaluation, termination reason and messages.
    Objects exposing ``model_dump()`` are dumped through it, and the termination reason is
    reduced to its ``.value``.

    Args:
        result: Result dictionary built by the benchmark loop for a single task.

    Returns:
        A dictionary whose key order is the order in which the fields are listed below.
    """
    if "error" not in result:
        # Successful run: complete record.
        return {
            "id": result["task"].id,
            "evaluation": result["evaluation"].model_dump(),
            "config": result["config"],
            "termination_reason": result["termination_reason"].value,
            "messages": [message.model_dump() for message in result["messages"]],
            "task": result["task"].model_dump(),
        }

    # Failed run: no evaluation or messages exist, so record the error and a zero reward.
    return {
        "id": result["task"].id,
        "error": result["error"],
        "evaluation": {
            "reward": 0.0,
        },
        "config": result["config"],
        "task": result["task"].model_dump(),
    }
async def run_benchmark(assistant_model: str, user_model: str, debug_task_id: str | None, max_steps: int) -> None:
    """Run the tau2 benchmark with the agent framework and report the aggregate score.

    The function:
    1. Validates the OpenAI configuration taken from the environment
    2. Sets up output handling (results file for a full run, none in debug mode)
    3. Loads the tau2 tasks and creates one chat client per role
    4. Runs and evaluates every task, recording one JSONL line per task
    5. Logs the overall accuracy

    Args:
        assistant_model: Model ID for the customer service agent (e.g., "gpt-4o")
        user_model: Model ID for the user simulator (e.g., "gpt-4o")
        debug_task_id: Optional specific task ID to run (disables result file creation)
        max_steps: Maximum conversation steps before forced termination

    Raises:
        ValueError: If OPENAI_BASE_URL or OPENAI_API_KEY is not set.

    Output:
        In full benchmark mode, a timestamped JSONL file under ``results/``.
        Summary statistics are logged with colored output.
    """

    def escape_markup(value: object) -> str:
        """Escape "<" so loguru's color-markup parser treats the text literally."""
        return str(value).replace("<", r"\<")

    # STEP 1: Validate required OpenAI configuration.
    # Done before any file is created so a misconfigured run leaves no empty results file.
    # Both models use the same endpoint but can be different model types.
    openai_base_url = os.getenv("OPENAI_BASE_URL")
    if openai_base_url is None:
        raise ValueError("OPENAI_BASE_URL must be set")
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if openai_api_key is None:
        raise ValueError("OPENAI_API_KEY must be set")

    # STEP 2: Configure output handling based on execution mode
    result_fp = None
    if debug_task_id is None:
        # Full benchmark mode: create timestamped results file
        timestamp = datetime.now().strftime("%m%d%H%M")  # Format: MMDDHHMM
        result_filename = f"results/{assistant_model}_user-{user_model}_{timestamp}.jsonl"
        os.makedirs("results", exist_ok=True)
        result_fp = open(result_filename, "a", encoding="utf-8")
        logger.info(f"Results will be saved to: {result_filename}")
    else:
        # Debug mode: single task, no file output
        logger.info(f"Debug mode: targeting task ID {debug_task_id}")

    # Everything below runs under try/finally so the results file is closed on every exit
    # path, including exceptions and the early return for an unknown debug task ID.
    try:
        # STEP 3: Load tau2 dataset
        tasks = get_tasks()
        logger.info(f"Found {len(tasks)} tasks in the dataset")

        _logger = logger.opt(colors=True)  # Enable colored console output

        # STEP 4: Initialize LLM clients for both agent roles
        assistant_chat_client = OpenAIChatClient(
            base_url=openai_base_url,
            api_key=openai_api_key,
            ai_model_id=assistant_model,
        )
        user_simulator_chat_client = OpenAIChatClient(
            base_url=openai_base_url,
            api_key=openai_api_key,
            ai_model_id=user_model,
        )

        # STEP 5: Filter task set for debug mode
        if debug_task_id is not None:
            tasks = [task for task in tasks if task.id == debug_task_id]
            if not tasks:
                logger.error(f"Task ID {debug_task_id} not found in dataset")
                return

        # STEP 6: Initialize evaluation tracking
        all_rewards: list[float] = []  # One reward per task, in task order
        task_runner = TaskRunner(max_steps=max_steps)  # Reused across tasks, reset after each

        # STEP 7: Execute benchmark across all tasks
        for task in tasks:
            _logger.info(f"<red>Testing task #{task.id}</red>")
            # Free text is escaped before it reaches a colors=True logger: an unescaped
            # "<word>" in it would be parsed as a markup tag.
            _logger.info(f"<cyan>Purpose:</cyan> {escape_markup(task.description.purpose)}")  # type: ignore

            # Result structure for this task
            result: dict[str, Any] = {
                "config": {
                    "assistant": assistant_chat_client.ai_model_id,
                    "user": user_simulator_chat_client.ai_model_id,
                },
                "task": task,
            }

            # Log user scenario context
            if task.user_scenario and task.user_scenario.instructions:
                reason_for_call = escape_markup(task.user_scenario.instructions.reason_for_call)  # type: ignore
                _logger.info(f"<cyan>User scenario:</cyan> {reason_for_call}")

            try:
                # Run the assistant / user simulator conversation
                conversation = await task_runner.run(task, assistant_chat_client, user_simulator_chat_client)

                # Score it with tau2's evaluation
                reward_value = task_runner.evaluate(task, conversation, task_runner.termination_reason)

                # Store detailed results for analysis
                result["evaluation"] = task_runner.full_reward_info
                result["messages"] = conversation
                result["termination_reason"] = task_runner.termination_reason

                _logger.info(f"<cyan>Final evaluation:</cyan> {escape_markup(task_runner.full_reward_info)}")

            except Exception as e:
                # Best-effort per task: record the failure and move on to the next task.
                # The message is escaped so that logging it cannot itself raise here.
                _logger.error(f"<red>Error testing task #{task.id}:</red> {escape_markup(e)}")
                result["error"] = traceback.format_exc()  # Full stack trace for debugging
                traceback.print_exc()  # Console output for immediate debugging
                reward_value = 0.0  # Zero score for failed runs

            # STEP 8: Persist results incrementally
            if result_fp is not None:
                result_fp.write(json.dumps(to_dumpable(result), default=str) + "\n")
                # Flush so the record is on disk even if a later task kills the process.
                result_fp.flush()

            all_rewards.append(reward_value)

            # Reset runner state for next task
            task_runner.reinit()
    finally:
        if result_fp is not None:
            result_fp.close()

    # STEP 9: Report aggregate results
    all_accuracy = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0

    _logger.info("<green>Final Results:</green>")
    _logger.info(f"<cyan>All tasks accuracy:</cyan> {all_accuracy:.2f} ({int(sum(all_rewards))}/{len(tasks)})")
if __name__ == "__main__":
    # Command-line interface for tau2 benchmark execution.
    #
    # Modes:
    #   - Full benchmark: runs all tasks and writes a timestamped results file
    #   - Debug mode: runs a single task, without a results file
    #   - Environment patching: applied by default, can be switched off
    #
    # Usage examples:
    #   python run_benchmark.py
    #   python run_benchmark.py --assistant gpt-4o --user gpt-4o-mini
    #   python run_benchmark.py --debug-task-id task_123
    #   python run_benchmark.py --disable-env-patch
    cli = argparse.ArgumentParser(description="Run tau2-agent-framework model test")

    # Models
    cli.add_argument("--assistant", type=str, default="gpt-4.1", help="Assistant model id, e.g., gpt-4.1-mini")
    cli.add_argument("--user", type=str, default="gpt-4.1", help="User model id")

    # Execution mode
    cli.add_argument(
        "--debug-task-id", type=str, default=None, help="Debug a specific task ID (disables result file creation)"
    )
    cli.add_argument("--max-steps", type=int, default=100, help="Maximum number of steps to run")

    # Environment
    cli.add_argument("--disable-env-patch", action="store_true", help="Disable patching tau2-bench environment")

    options = cli.parse_args()

    # The tau2-bench environment patch is applied unless explicitly disabled.
    if not options.disable_env_patch:
        patch_env_set_state()

    benchmark = run_benchmark(
        assistant_model=options.assistant,
        user_model=options.user,
        debug_task_id=options.debug_task_id,
        max_steps=options.max_steps,
    )
    asyncio.run(benchmark)
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
@@ -0,0 +1,20 @@
# Copyright (c) Microsoft. All rights reserved.
import shutil
from pathlib import Path
def purge_tau2_data():
    """Delete the ``data`` directory under the current working directory, if present."""
    target = Path.cwd() / "data"

    if not target.exists():
        print("Data directory not found. Skipping purge.")
        return

    shutil.rmtree(target)
    print(f"Data directory at {target} has been purged.")


if __name__ == "__main__":
    purge_tau2_data()
@@ -0,0 +1,62 @@
# Copyright (c) Microsoft. All rights reserved.
import shutil
import subprocess
from pathlib import Path
def setup_tau2_data() -> str:
    """Make sure ``./data`` (relative to the current working directory) holds the tau2 data.

    If the directory already exists nothing is downloaded. Otherwise the tau2-bench
    repository is cloned into a scratch directory, its ``data`` directory is moved to
    ``./data``, and the scratch directory is removed whether or not the setup succeeded.

    Returns:
        The data directory path as a string.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` fails.
        FileNotFoundError: If the cloned repository contains no ``data`` directory.
    """
    # Function-scope import: tempfile is not among this script's top-level imports.
    import tempfile

    # The data directory is resolved against the current working directory.
    data_dir = Path.cwd() / "data"
    print("Setting up tau2 data directory...")

    if data_dir.exists():
        print(f"Data directory already exists at {data_dir}")
    else:
        print("Data directory not found. Cloning tau2-bench repository...")
        try:
            # Clone inside a uniquely named scratch directory:
            # - it cannot collide with an existing ./tau2-bench,
            # - it is removed on every exit path, so a failed run leaves no clone behind,
            # - it is created in the cwd so the final move stays on one filesystem.
            with tempfile.TemporaryDirectory(dir=Path.cwd(), ignore_cleanup_errors=True) as scratch:
                tau2_bench_dir = Path(scratch) / "tau2-bench"

                print("Cloning https://github.com/sierra-research/tau2-bench.git...")
                subprocess.run(
                    [
                        "git",
                        "clone",
                        "--depth",  # Only the current tree is needed, not the history
                        "1",
                        "https://github.com/sierra-research/tau2-bench.git",
                        str(tau2_bench_dir),
                    ],
                    check=True,
                    capture_output=True,
                    text=True,
                )

                print("Moving data directory...")
                tau2_data_dir = tau2_bench_dir / "data"
                if not tau2_data_dir.exists():
                    raise FileNotFoundError(f"Data directory not found in cloned repository: {tau2_data_dir}")
                shutil.move(str(tau2_data_dir), str(data_dir))

                print("Cleaning up cloned repository...")

            print("Data directory setup completed successfully!")

        except subprocess.CalledProcessError as e:
            print(f"ERROR: Failed to clone repository: {e}")
            # git's own message was captured rather than shown; surface it for diagnosis.
            if e.stderr:
                print(e.stderr.strip())
            raise
        except Exception as e:
            print(f"ERROR: Failed to set up data directory: {e}")
            raise

    print(f"TAU2_DATA_DIR should be set to: {data_dir}")
    return str(data_dir)


if __name__ == "__main__":
    setup_tau2_data()
@@ -0,0 +1,265 @@
# Copyright (c) Microsoft. All rights reserved.
from unittest.mock import patch
from agent_framework._types import ChatMessage, Role, TextContent, FunctionCallContent, FunctionResultContent
from agent_framework_lab_tau2._message_utils import flip_messages, log_messages
def test_flip_messages_user_to_assistant():
    """A user message comes back as an assistant message with text and metadata intact."""
    original = ChatMessage(
        role=Role.USER,
        contents=[TextContent(text="Hello assistant")],
        author_name="User1",
        message_id="msg_001",
    )

    flipped = flip_messages([original])

    assert len(flipped) == 1
    result = flipped[0]
    assert result.role == Role.ASSISTANT
    assert result.text == "Hello assistant"
    assert result.author_name == "User1"
    assert result.message_id == "msg_001"
def test_flip_messages_assistant_to_user():
    """An assistant message comes back as a user message with text and metadata intact."""
    original = ChatMessage(
        role=Role.ASSISTANT,
        contents=[TextContent(text="Hello user")],
        author_name="Assistant1",
        message_id="msg_002",
    )

    flipped = flip_messages([original])

    assert len(flipped) == 1
    result = flipped[0]
    assert result.role == Role.USER
    assert result.text == "Hello user"
    assert result.author_name == "Assistant1"
    assert result.message_id == "msg_002"
def test_flip_messages_assistant_with_function_calls_filtered():
    """Function-call contents are dropped when an assistant message is flipped to user."""
    call = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"})
    original = ChatMessage(
        role=Role.ASSISTANT,
        contents=[TextContent(text="I'll call a function"), call, TextContent(text="After the call")],
        message_id="msg_003",
    )

    flipped = flip_messages([original])

    assert len(flipped) == 1
    result = flipped[0]
    assert result.role == Role.USER
    # Only the two text parts survive; the function call in between is gone.
    assert len(result.contents) == 2
    assert all(content.type == "text" for content in result.contents)
    assert "I'll call a function" in result.text
    assert "After the call" in result.text
def test_flip_messages_assistant_with_only_function_calls_skipped():
    """An assistant message holding nothing but a function call produces no output."""
    call = FunctionCallContent(call_id="call_456", name="another_function", arguments={"key": "value"})
    # No text content at all, only the function call.
    original = ChatMessage(role=Role.ASSISTANT, contents=[call], message_id="msg_004")

    flipped = flip_messages([original])

    # Nothing remains after filtering, so the whole message is dropped.
    assert len(flipped) == 0
def test_flip_messages_tool_messages_skipped():
    """Messages with the TOOL role are dropped by flip_messages."""
    tool_output = FunctionResultContent(call_id="call_789", result={"success": True})
    tool_message = ChatMessage(role=Role.TOOL, contents=[tool_output])

    flipped = flip_messages([tool_message])

    assert len(flipped) == 0
def test_flip_messages_system_messages_preserved():
    """System messages pass through flip_messages with role, text and id unchanged."""
    system_message = ChatMessage(
        role=Role.SYSTEM, contents=[TextContent(text="System instruction")], message_id="sys_001"
    )

    flipped = flip_messages([system_message])

    assert len(flipped) == 1
    result = flipped[0]
    assert result.role == Role.SYSTEM
    assert result.text == "System instruction"
    assert result.message_id == "sys_001"
def test_flip_messages_mixed_conversation():
    """A conversation mixing every role is flipped message by message."""
    call = FunctionCallContent(call_id="call_mixed", name="mixed_function", arguments={})
    call_result = FunctionResultContent(call_id="call_mixed", result="function result")
    conversation = [
        ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System prompt")]),
        ChatMessage(role=Role.USER, contents=[TextContent(text="User question")]),
        ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Assistant response"), call]),
        ChatMessage(role=Role.TOOL, contents=[call_result]),
        ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Final response")]),
    ]

    flipped = flip_messages(conversation)

    # Five messages in, four out: the tool message is skipped.
    assert len(flipped) == 4
    assert [(message.role, message.text) for message in flipped] == [
        (Role.SYSTEM, "System prompt"),  # system: unchanged
        (Role.ASSISTANT, "User question"),  # user -> assistant
        (Role.USER, "Assistant response"),  # assistant -> user, function call filtered out
        (Role.USER, "Final response"),  # final assistant -> user
    ]
def test_flip_messages_empty_list():
    """Flipping an empty list yields an empty result."""
    flipped = flip_messages([])

    assert len(flipped) == 0
def test_flip_messages_preserves_metadata():
    """author_name and message_id are carried over to the flipped message."""
    original = ChatMessage(
        role=Role.USER,
        contents=[TextContent(text="Test message")],
        author_name="TestUser",
        message_id="test_123",
    )

    flipped = flip_messages([original])

    assert len(flipped) == 1
    result = flipped[0]
    assert result.author_name == "TestUser"
    assert result.message_id == "test_123"
@patch("agent_framework_lab_tau2._message_utils.logger")
def test_log_messages_text_content(mock_logger):
    """Each plain-text message produces exactly one info log call."""
    conversation = [
        ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]),
        ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]),
    ]

    log_messages(conversation)

    info = mock_logger.opt.return_value.info
    assert info.call_count == 2
@patch("agent_framework_lab_tau2._message_utils.logger")
def test_log_messages_function_call(mock_logger):
    """A function call is logged with a TOOL_CALL marker and the function name."""
    call = FunctionCallContent(call_id="call_log", name="log_function", arguments={"param": "value"})

    log_messages([ChatMessage(role=Role.ASSISTANT, contents=[call])])

    info = mock_logger.opt.return_value.info
    info.assert_called()
    # Inspect the first positional argument of the most recent info() call.
    logged_text = info.call_args[0][0]
    assert "TOOL_CALL" in logged_text
    assert "log_function" in logged_text
@patch("agent_framework_lab_tau2._message_utils.logger")
def test_log_messages_function_result(mock_logger):
    """A function result is logged with a TOOL_RESULT marker."""
    outcome = FunctionResultContent(call_id="call_result", result="success")

    log_messages([ChatMessage(role=Role.TOOL, contents=[outcome])])

    info = mock_logger.opt.return_value.info
    info.assert_called()
    # Inspect the first positional argument of the most recent info() call.
    logged_text = info.call_args[0][0]
    assert "TOOL_RESULT" in logged_text
@patch("agent_framework_lab_tau2._message_utils.logger")
def test_log_messages_different_roles(mock_logger):
    """Each role's log line carries either its color tag or its role name."""
    conversation = [
        ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System")]),
        ChatMessage(role=Role.USER, contents=[TextContent(text="User")]),
        ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Assistant")]),
        ChatMessage(role=Role.TOOL, contents=[TextContent(text="Tool")]),
    ]

    log_messages(conversation)

    info = mock_logger.opt.return_value.info
    # One log call per message, in message order.
    assert info.call_count == 4

    expected_markers = [
        ("cyan", "SYSTEM"),
        ("green", "USER"),
        ("blue", "ASSISTANT"),
        ("yellow", "TOOL"),
    ]
    for logged_call, (color, role_name) in zip(info.call_args_list, expected_markers):
        logged_text = logged_call[0][0]
        assert color in logged_text or role_name in logged_text
@patch("agent_framework_lab_tau2._message_utils.logger")
def test_log_messages_escapes_html(mock_logger):
    """Check that angle-bracket markup in message text is escaped before logging."""
    user_message = ChatMessage(role=Role.USER, contents=[TextContent(text="Message with <tag> content")])

    log_messages([user_message])

    info_mock = mock_logger.opt.return_value.info
    info_mock.assert_called()
    logged_text = info_mock.call_args.args[0]

    # Either a backslash-escaped "<" or HTML entities is accepted.
    accepted_forms = ("\\<tag>", "&lt;tag&gt;")
    assert any(form in logged_text for form in accepted_forms)
@patch("agent_framework_lab_tau2._message_utils.logger")
def test_log_messages_mixed_content_types(mock_logger):
    """Check that a message mixing text and a function call produces one log entry per content item."""
    call_content = FunctionCallContent(call_id="mixed_call", name="mixed_function", arguments={"key": "value"})
    parts = [
        TextContent(text="I'll call a function"),
        call_content,
        TextContent(text="Done!"),
    ]
    assistant_message = ChatMessage(role=Role.ASSISTANT, contents=parts)

    log_messages([assistant_message])

    # Three content items -> three info() calls.
    info_mock = mock_logger.opt.return_value.info
    assert info_mock.call_count == 3
@@ -0,0 +1,267 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for sliding window message list."""
import pytest
from unittest.mock import patch
from agent_framework._types import ChatMessage, Role, TextContent, FunctionCallContent, FunctionResultContent
from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageList
def test_initialization_empty():
    """A window built with only max_tokens has no system message, tools or messages."""
    window = SlidingWindowChatMessageList(max_tokens=1000)

    assert window.max_tokens == 1000
    assert window.system_message is None
    assert window.tool_definitions is None

    # Both the full history and the truncated view start empty.
    assert len(window._messages) == 0
    assert len(window._truncated_messages) == 0
def test_initialization_with_parameters():
    """Constructor arguments are exposed unchanged as attributes."""
    prompt = "You are a helpful assistant"
    tools = [{"name": "test_tool", "description": "A test tool"}]

    window = SlidingWindowChatMessageList(
        max_tokens=2000,
        system_message=prompt,
        tool_definitions=tools,
    )

    assert window.max_tokens == 2000
    assert window.system_message == prompt
    assert window.tool_definitions == tools
def test_initialization_with_messages():
    """Messages passed to the constructor populate both the full and the truncated lists."""
    user_turn = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
    assistant_turn = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")])

    window = SlidingWindowChatMessageList(messages=[user_turn, assistant_turn], max_tokens=1000)

    assert len(window._messages) == 2
    assert len(window._truncated_messages) == 2
@pytest.mark.asyncio
async def test_add_messages_simple():
    """Messages added under a generous token limit are all returned by list_messages."""
    window = SlidingWindowChatMessageList(max_tokens=10000)  # Large limit
    question = ChatMessage(role=Role.USER, contents=[TextContent(text="What's the weather?")])
    answer = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="I can help with that.")])

    await window.add_messages([question, answer])
    stored = await window.list_messages()

    assert len(stored) == 2
    first, second = stored[0], stored[1]
    assert first.text == "What's the weather?"
    assert second.text == "I can help with that."
@pytest.mark.asyncio
async def test_list_all_messages_vs_list_messages():
    """list_all_messages keeps the full history while list_messages returns the truncated view."""
    window = SlidingWindowChatMessageList(max_tokens=50)  # Small limit to force truncation

    # Ten messages are expected to overflow the 50-token budget.
    batch = []
    for i in range(10):
        batch.append(ChatMessage(role=Role.USER, contents=[TextContent(text=f"Message {i} with some content")]))
    await window.add_messages(batch)

    visible = await window.list_messages()
    history = await window.list_all_messages()

    # The full history is untouched ...
    assert len(history) == 10
    # ... while the visible window is strictly shorter.
    assert len(visible) < len(history)
def test_get_token_count_basic():
    """A single short user message yields a positive token count."""
    window = SlidingWindowChatMessageList(max_tokens=1000)
    greeting = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
    window._truncated_messages = [greeting]

    # The exact number depends on the encoding, so only positivity is checked.
    assert window.get_token_count() > 0
def test_get_token_count_with_system_message():
    """The system message alone contributes tokens, and adding a message increases the count."""
    prompt = "You are a helpful assistant"
    window = SlidingWindowChatMessageList(max_tokens=1000, system_message=prompt)

    # Baseline: only the system message is present.
    baseline = window.get_token_count()

    # Now put one user message into the truncated view and count again.
    greeting = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
    window._truncated_messages = [greeting]
    with_greeting = window.get_token_count()

    assert with_greeting > baseline
    assert baseline > 0  # System message contributes tokens
def test_get_token_count_function_call():
    """An assistant message holding only a function call yields a positive token count."""
    call_content = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"})
    window = SlidingWindowChatMessageList(max_tokens=1000)
    window._truncated_messages = [ChatMessage(role=Role.ASSISTANT, contents=[call_content])]

    counted = window.get_token_count()

    assert counted > 0
def test_get_token_count_function_result():
    """A tool message holding only a function result yields a positive token count."""
    result_content = FunctionResultContent(call_id="call_123", result={"success": True, "data": "result"})
    window = SlidingWindowChatMessageList(max_tokens=1000)
    window._truncated_messages = [ChatMessage(role=Role.TOOL, contents=[result_content])]

    counted = window.get_token_count()

    assert counted > 0
@patch("agent_framework_lab_tau2._sliding_window.logger")
def test_truncate_messages_removes_old_messages(mock_logger):
    """truncate_messages drops messages and logs a warning when the token limit is exceeded."""
    window = SlidingWindowChatMessageList(max_tokens=20)  # Very small limit

    # Two long messages plus a short one, together meant to exceed the limit.
    long_user = ChatMessage(
        role=Role.USER,
        contents=[TextContent(text="This is a very long message that should exceed the token limit")],
    )
    long_assistant = ChatMessage(
        role=Role.ASSISTANT,
        contents=[TextContent(text="This is another very long message that should also exceed the token limit")],
    )
    short_user = ChatMessage(role=Role.USER, contents=[TextContent(text="Short msg")])
    original = [long_user, long_assistant, short_user]

    # Work on a copy so the original list length stays available for comparison.
    window._truncated_messages = list(original)
    window.truncate_messages()

    # Fewer messages must remain after truncation.
    assert len(window._truncated_messages) < len(original)
    # A warning must have been emitted.
    assert mock_logger.warning.called
@patch("agent_framework_lab_tau2._sliding_window.logger")
def test_truncate_messages_removes_leading_tool_messages(mock_logger):
    """truncate_messages drops a tool message sitting at the head of the window and warns."""
    window = SlidingWindowChatMessageList(max_tokens=10000)  # Large limit

    # The window starts with a tool result followed by a user message.
    leading_tool = ChatMessage(role=Role.TOOL, contents=[FunctionResultContent(call_id="call_123", result="result")])
    following_user = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
    window._truncated_messages = [leading_tool, following_user]

    window.truncate_messages()

    # Only the user message should be left.
    remaining = window._truncated_messages
    assert len(remaining) == 1
    assert remaining[0].role == Role.USER

    # Removing the tool message should have produced a warning.
    mock_logger.warning.assert_called()
def test_estimate_any_object_token_count_dict():
    """A small dictionary yields a positive token estimate."""
    window = SlidingWindowChatMessageList(max_tokens=1000)
    payload = {"key": "value", "number": 42}

    estimate = window.estimate_any_object_token_count(payload)

    assert estimate > 0
def test_estimate_any_object_token_count_string():
    """A plain string yields a positive token estimate."""
    window = SlidingWindowChatMessageList(max_tokens=1000)
    sample = "This is a test string"

    estimate = window.estimate_any_object_token_count(sample)

    assert estimate > 0
def test_estimate_any_object_token_count_non_serializable():
    """An object that is not JSON-serializable still yields a positive token estimate."""

    # A plain class instance cannot be JSON-serialized; it only defines __str__.
    class CustomObject:
        def __str__(self):
            return "CustomObject instance"

    window = SlidingWindowChatMessageList(max_tokens=1000)

    estimate = window.estimate_any_object_token_count(CustomObject())

    # The estimate is expected to come from a fallback (string representation).
    assert estimate > 0
@pytest.mark.asyncio
async def test_real_world_scenario():
    """Run a six-turn conversation through a small window and check history and token budget."""
    window = SlidingWindowChatMessageList(
        max_tokens=30,  # Moderate limit
        system_message="You are a helpful assistant",
    )

    # Alternating user / assistant turns, in conversation order.
    turns = [
        (Role.USER, "Hello, how are you?"),
        (Role.ASSISTANT, "I'm doing well, thank you! How can I help you today?"),
        (Role.USER, "Can you tell me about the weather?"),
        (
            Role.ASSISTANT,
            "I'd be happy to help with weather information, but I don't have access to current weather data.",
        ),
        (Role.USER, "What about telling me a joke instead?"),
        (Role.ASSISTANT, "Sure! Why don't scientists trust atoms? Because they make up everything!"),
    ]
    dialogue = [ChatMessage(role=role, contents=[TextContent(text=text)]) for role, text in turns]

    await window.add_messages(dialogue)
    visible = await window.list_messages()
    history = await window.list_all_messages()

    # The complete history keeps every turn.
    assert len(history) == 6
    # The visible window may have been truncated.
    assert len(visible) <= 6

    # The token count may overshoot the limit by at most 10%.
    used_tokens = window.get_token_count()
    assert used_tokens <= window.max_tokens * 1.1
@@ -0,0 +1,189 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for tau2 utils module."""
from typing import Any, cast
from pydantic import BaseModel
from agent_framework._tools import AIFunction
from agent_framework._types import ChatMessage, Role, TextContent, FunctionCallContent, FunctionResultContent
from agent_framework_lab_tau2._tau2_utils import (
convert_tau2_tool_to_ai_function,
convert_agent_framework_messages_to_tau2_messages,
)
from tau2.data_model.message import SystemMessage, UserMessage, AssistantMessage, ToolMessage, ToolCall
from tau2.domains.airline.environment import get_environment
def test_convert_tau2_tool_to_ai_function_basic():
    """Convert the first airline-environment tool and compare it with its source tool."""
    # Use real tools from the tau2 airline environment.
    available_tools = get_environment().get_tools()
    assert len(available_tools) > 0, "No tools available in environment"
    source_tool = available_tools[0]

    converted = convert_tau2_tool_to_ai_function(source_tool)

    # Name, description and parameter model must carry over.
    assert isinstance(converted, AIFunction)
    assert converted.name == source_tool.name
    assert converted.description == source_tool._get_description()
    assert converted.input_model == source_tool.params

    # Only callability is checked; the tool is not invoked, to avoid side effects.
    assert callable(converted.func)
def test_convert_tau2_tool_to_ai_function_multiple_tools():
    """Convert the first three airline-environment tools and check each against its source."""
    # Use real tools from the tau2 airline environment; only the first three are exercised.
    sample_tools = get_environment().get_tools()[:3]

    # Convert everything up front, then verify pairwise.
    converted_tools = []
    for source_tool in sample_tools:
        converted_tools.append(convert_tau2_tool_to_ai_function(source_tool))

    for converted, source_tool in zip(converted_tools, sample_tools):
        assert isinstance(converted, AIFunction)
        assert converted.name == source_tool.name
        assert converted.description == source_tool._get_description()
        assert converted.input_model == source_tool.params
        assert callable(converted.func)
def test_convert_agent_framework_messages_to_tau2_messages_system():
    """A system ChatMessage converts to a single tau2 SystemMessage with the same text."""
    system_input = ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System instruction")])

    converted = convert_agent_framework_messages_to_tau2_messages([system_input])

    assert len(converted) == 1
    system_output = converted[0]
    assert isinstance(system_output, SystemMessage)
    assert system_output.role == "system"
    assert system_output.content == "System instruction"
def test_convert_agent_framework_messages_to_tau2_messages_user():
    """A user ChatMessage converts to a single tau2 UserMessage with no tool calls."""
    user_input = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello assistant")])

    converted = convert_agent_framework_messages_to_tau2_messages([user_input])

    assert len(converted) == 1
    user_output = converted[0]
    assert isinstance(user_output, UserMessage)
    assert user_output.role == "user"
    assert user_output.content == "Hello assistant"
    assert user_output.tool_calls is None
def test_convert_agent_framework_messages_to_tau2_messages_assistant():
    """A text-only assistant ChatMessage converts to a tau2 AssistantMessage with no tool calls."""
    assistant_input = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hello user")])

    converted = convert_agent_framework_messages_to_tau2_messages([assistant_input])

    assert len(converted) == 1
    assistant_output = converted[0]
    assert isinstance(assistant_output, AssistantMessage)
    assert assistant_output.role == "assistant"
    assert assistant_output.content == "Hello user"
    assert assistant_output.tool_calls is None
def test_convert_agent_framework_messages_to_tau2_messages_with_function_call():
    """An assistant message with text and a function call converts to one message with one ToolCall."""
    call_content = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"})
    assistant_input = ChatMessage(
        role=Role.ASSISTANT,
        contents=[TextContent(text="I'll call a function"), call_content],
    )

    converted = convert_agent_framework_messages_to_tau2_messages([assistant_input])

    assert len(converted) == 1
    assistant_output = converted[0]
    assert isinstance(assistant_output, AssistantMessage)
    assert assistant_output.content == "I'll call a function"
    assert assistant_output.tool_calls is not None
    assert len(assistant_output.tool_calls) == 1

    # The single tool call must mirror the original function call.
    converted_call = assistant_output.tool_calls[0]
    assert isinstance(converted_call, ToolCall)
    assert converted_call.id == "call_123"
    assert converted_call.name == "test_function"
    assert converted_call.arguments == {"param": "value"}
    assert converted_call.requestor == "assistant"
def test_convert_agent_framework_messages_to_tau2_messages_with_function_result():
    """A tool message with a dict result converts to a non-error ToolMessage with JSON content."""
    result_content = FunctionResultContent(call_id="call_123", result={"success": True, "data": "result data"})
    tool_input = ChatMessage(role=Role.TOOL, contents=[result_content])

    converted = convert_agent_framework_messages_to_tau2_messages([tool_input])

    assert len(converted) == 1
    tool_output = converted[0]
    assert isinstance(tool_output, ToolMessage)
    assert tool_output.id == "call_123"
    assert tool_output.role == "tool"

    # The dict result is expected to appear JSON-encoded inside the content.
    assert tool_output.content is not None
    assert '{"success": true, "data": "result data"}' in tool_output.content

    assert tool_output.requestor == "assistant"
    assert tool_output.error is False
def test_convert_agent_framework_messages_to_tau2_messages_with_error():
    """A function result carrying an exception converts to a ToolMessage flagged as an error."""
    failed_result = FunctionResultContent(
        call_id="call_456",
        result="Error occurred",
        exception=Exception("Test error"),
    )
    tool_input = ChatMessage(role=Role.TOOL, contents=[failed_result])

    converted = convert_agent_framework_messages_to_tau2_messages([tool_input])

    assert len(converted) == 1
    tool_output = converted[0]
    assert isinstance(tool_output, ToolMessage)
    assert tool_output.error is True
def test_convert_agent_framework_messages_to_tau2_messages_multiple_text_contents():
    """Several text contents in one message are joined with a space in the converted content."""
    text_parts = [TextContent(text="First part"), TextContent(text="Second part")]
    user_input = ChatMessage(role=Role.USER, contents=text_parts)

    converted = convert_agent_framework_messages_to_tau2_messages([user_input])

    assert len(converted) == 1
    user_output = converted[0]
    assert isinstance(user_output, UserMessage)
    assert user_output.content == "First part Second part"
def test_convert_agent_framework_messages_to_tau2_messages_complex_scenario():
    """A five-message exchange with a tool round-trip converts to the matching tau2 message types."""
    # Note: the call arguments are supplied as a JSON string here, not a dict.
    call_content = FunctionCallContent(call_id="call_789", name="complex_tool", arguments='{"key": "value"}')
    result_content = FunctionResultContent(call_id="call_789", result={"output": "tool result"})

    exchange = [
        ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System prompt")]),
        ChatMessage(role=Role.USER, contents=[TextContent(text="User request")]),
        ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="I'll help you"), call_content]),
        ChatMessage(role=Role.TOOL, contents=[result_content]),
        ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Based on the result...")]),
    ]

    converted = convert_agent_framework_messages_to_tau2_messages(exchange)

    assert len(converted) == 5

    # Each position must map to the tau2 type for its role.
    expected_types = [SystemMessage, UserMessage, AssistantMessage, ToolMessage, AssistantMessage]
    for message, expected_type in zip(converted, expected_types):
        assert isinstance(message, expected_type)

    # The first assistant message carries the single tool call.
    calling_message = converted[2]
    assert calling_message.tool_calls is not None
    assert len(calling_message.tool_calls) == 1
    assert calling_message.tool_calls[0].name == "complex_tool"
+2 -1
View File
@@ -20,6 +20,7 @@ dev = [
"pytest>=8.4.1",
"pytest-asyncio>=1.0.0",
"pytest-cov>=6.2.1",
"pytest-env>=1.1.5",
"pytest-xdist[psutil]>=3.8.0",
"pytest-timeout>=2.3.1",
"mypy>=1.16.1",
@@ -164,7 +165,7 @@ exclude_dirs = ["tests", "./run_tasks_in_packages_if_exists.py", "./check_md_cod
executor.type = "uv"
[tool.poe.tasks]
markdown-code-lint = """uv run python check_md_code_blocks.py README.md ./packages/**/README.md ./samples/**/*.md --exclude cookiecutter-agent-framework-lab"""
markdown-code-lint = """uv run python check_md_code_blocks.py README.md ./packages/**/README.md ./samples/**/*.md --exclude cookiecutter-agent-framework-lab --exclude tau2"""
samples-code-check = """pyright ./samples"""
docs-install = "uv sync --all-packages --all-extras --dev -U --prerelease=if-necessary-or-explicit --group=docs"
docs-clean = "rm -rf docs/build"
+1311
View File
File diff suppressed because it is too large Load Diff